[发明专利]基于中间层特征提取增强的知识蒸馏实现图像分类的方法有效
申请号: | 202110887562.2 | 申请日: | 2021-08-03 |
公开(公告)号: | CN113610146B | 公开(公告)日: | 2023-08-04 |
发明(设计)人: | 陈泽仁;徐琪;张天魁;钟炜强 | 申请(专利权)人: | 江西鑫铂瑞科技有限公司 |
主分类号: | G06V10/764 | 分类号: | G06V10/764;G06V10/75;G06N5/02;G06N3/0464;G06N3/084 |
代理公司: | 温州名创知识产权代理有限公司 33258 | 代理人: | 程嘉炜 |
地址: | 335000 江西省鹰*** | 国省代码: | 江西;36 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 基于 中间层 特征 提取 增强 知识 蒸馏 实现 图像 分类 方法 | ||
1.一种基于中间层特征提取增强的知识蒸馏实现图像分类的方法,其特征在于,所述方法包括以下步骤:
获取待分类图像;
将所述待分类图像导入预先训练好的教师-学生网络中,得到相应的分类结果;其中,所述预先训练好的教师-学生网络是基于历史图像分别输入教师模型和学生模型中,并采用预设的跨层非局部模块分别提取学生模型和教师模型的多尺度像素间关系,且待计算出教师模型和学生模型间的多尺度像素间关系蒸馏损失之后,将蒸馏损失加入学生模型的损失函数中,进一步根据损失函数反向传播更新学生模型参数直至学生模型收敛,将收敛后的学生模型作为优化模型输出进行训练得到的;
所述跨层非局部模块采用如下公式进行计算:
R=(Xq,Xr1,…,Xrn)=Xq+∑Zri
其中,Xq为查询层特征;Xri为响应层i特征;Zri为响应层i与查询层的像素间关系,表示为为卷积运算;θ(·),和g(·)均为可学习嵌入式函数,使用1×1卷积实现;θ(Xq),gi(Xri)为可学习嵌入函数对输入的特征图做预处理,计算单个像素的表示;f(·,·为二维函数,使用点积实现;为计算对应位置像素间的相关程度;
所述跨层非局部模块提取学生模型或教师模型的多尺度像素间关系的具体步骤如下:
将历史图像作为学生模型或教师模型的输入,并输入相应模型的第一层;
若第一层是选定的响应层,将第一层的输出特征作为响应层输入其后的跨层非局部模块,并将跨层非局部模块的输出特征输入其后的第二层;或若第一层是选定的查询层,将第一层的输出特征作为查询层输入其后的跨层非局部模块;
用第二层更新第一层;
若第一层是最后一层,将最后一层的输出特征作为预测结果并输出。
2.如权利要求1所述的基于中间层特征提取增强的知识蒸馏实现图像分类的方法,其特征在于,计算教师模型和学生模型间的多尺度像素间关系蒸馏损失时,采用L2范式损失的形式如下:
L蒸馏=L2(RT,M(RS))
其中,M(RS)是可学习的匹配函数,使教师模型和学生模型的多尺度关系特征图在维度和尺寸上匹配;RS首先通过一个卷积层c(·),然后再通过一个上采样函数h(·)进行匹配,即M(RS)=h(c(RS))。
3.如权利要求1所述的基于中间层特征提取增强的知识蒸馏实现图像分类的方法,其特征在于,将蒸馏损失加入到学生模型的损失函数中时,采用如下公式进行运算:
L总=L分类+αL蒸馏
其中,L总为总损失函数;L分类为分类损失函数;L蒸馏为蒸馏损失函数;α为蒸馏损失函数在总损失函数中占的比例系数。
4.如权利要求3所述的基于中间层特征提取增强的知识蒸馏实现图像分类的方法,其特征在于,所述分类损失函数采用交叉熵形式计算,具体公式如下:
其中,y为图像的真实分类标签,为学生模型输出的预测结果。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于江西鑫铂瑞科技有限公司,未经江西鑫铂瑞科技有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110887562.2/1.html,转载请声明来源钻瓜专利网。
- 上一篇:一种传感器元件生产工艺及加工机构
- 下一篇:一种马蹄加工用便于晾晒的收集机