[发明专利]基于桥接知识蒸馏卷积神经网络的图像分类方法在审
| 申请号: | 202110107120.1 | 申请日: | 2021-01-27 |
| 公开(公告)号: | CN112784964A | 公开(公告)日: | 2021-05-11 |
| 发明(设计)人: | 杜兰;王震;宋佳伦 | 申请(专利权)人: | 西安电子科技大学 |
| 主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08;G06K9/62 |
| 代理公司: | 陕西电子工业专利中心 61205 | 代理人: | 田文英;王品华 |
| 地址: | 710071*** | 国省代码: | 陕西;61 |
| 权利要求书: | 查看更多 | 说明书: | 查看更多 |
| 摘要: | |||
| 搜索关键词: | 基于 知识 蒸馏 卷积 神经网络 图像 分类 方法 | ||
1.一种基于桥接知识蒸馏卷积神经网络的图像分类方法,其特征在于,在教师网络与学生网络之间构建桥接结构,根据KL散度损失函数与交叉熵损失函数训练学生网络,该方法包括如下步骤:
(1)构建教师网络与学生网络:
(1a)搭建结构相同的14层的教师网络和14层的学生网络,其结构依次为:输入层,第一卷积层,第一激活层,第一最大池化层,第二卷积层,第二激活层,第二最大池化层,第三卷积层,第三激活层,第三最大池化层,第四卷积层,第四激活层,第五卷积层,输出层;
(1b)设置教师网络各层参数如下:
将第一至第五卷积层特征映射图数目分别设置为16、32、64、128、10,卷积核大小分别设置为5×5、5×5、6×6、5×5、3×3;
将第一至第三最大池化层的池化窗口均设置为2×2,步长均设置为2;
将第一至第四激活层的激活函数均设置为ReLU激活函数;
(1c)设置学生网络各层参数如下:
将第一至第五卷积层特征映射图数目分别设置为9、10、31、8、10,卷积核大小分别设置为5×5、5×5、6×6、5×5、3×3;
将第一至第三最大池化层的池化窗口均设置为2×2,步长均设置为2;
将第一至第四激活层的激活函数均设置为ReLU激活函数;
(2)生成训练集:
选取至少为2种类别、每种类别至少为200个图像组成训练集;
(3)训练教师网络:
将训练集输入到教师网络中,得到每张训练图像的预测类别概率,利用交叉熵损失函数,计算每张图像的预测类别概率与该图像对应的类别标签间的损失,通过反向传播算法迭代更新教师网络参数,直到交叉熵损失函数收敛为止,得到训练好的教师网络;
(4)构建桥接结构:
将训练好的教师网络的第四卷积层与学生网络的第四卷积层相连后得到桥接结构;
(5)训练学生网络:
(5a)将训练集同时输入到学生网络、训练好的教师网络中,得到学生网络的输出,教师网络的输出,以及桥接结构的输出;
(5b)利用KL散度损失函数,计算教师网络的输出与桥接结构的输出之间的KL散度损失值;
(5c)利用交叉熵损失函数,计算学生网络的输出与训练图像的类别标签之间的交叉熵损失值;
(5d)将KL散度损失值与交叉熵损失值之和作为总损失值,通过反向传播算法迭代更新学生网络的参数,直到总损失值收敛为止,得到训练好的学生网络;
(6)对待分类图像进行分类:
将待分类图像输入到训练好的学生网络中,得到学生网络对于待分类图像的预测类别概率,选择预测类别概率中值最高的概率所对应的类别作为对该图像的分类结果。
2.根据权利要求1所述的基于桥接知识蒸馏卷积神经网络的图像分类方法,其特征在于,步骤(3)、步骤(5c)中所述的交叉熵损失函数如下:
其中,J表示交叉熵损失函数,N表示训练集中图像的总数,Σ表示求和操作,i表示训练集中图像的序号,Yi表示训练集中第i张图像对应的类别标签,log表示以2为底的对数操作,Pi表示将训练集中第i张图像输入网络中得到的预测类别概率。
3.根据权利要求2所述的基于桥接知识蒸馏卷积神经网络的图像分类方法,其特征在于,步骤(5b)中所述的KL散度损失函数如下:
其中,表示KL散度损失函数,Qi表示将训练集中第i张图像输入教师网络中得到的预测类别概率,Bi表示将训练集中第i张图像输入桥接结构中得到的预测类别概率。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于西安电子科技大学,未经西安电子科技大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110107120.1/1.html,转载请声明来源钻瓜专利网。





