[发明专利]基于知识蒸馏的模型训练方法、装置、设备及存储介质在审
申请号: | 202210816261.5 | 申请日: | 2022-07-12 |
公开(公告)号: | CN115062769A | 公开(公告)日: | 2022-09-16 |
发明(设计)人: | 张楠;王健宗;瞿晓阳 | 申请(专利权)人: | 平安科技(深圳)有限公司 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08 |
代理公司: | 深圳众鼎专利商标代理事务所(普通合伙) 44325 | 代理人: | 姚章国 |
地址: | 518000 广东省深圳市福田区福*** | 国省代码: | 广东;44 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 基于 知识 蒸馏 模型 训练 方法 装置 设备 存储 介质 | ||
1.一种基于知识蒸馏的模型训练方法,其特征在于,包括:
获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;
根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型;
计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
根据所述目标相似度,构建相似度损失函数,并将所述相似度损失函数与所述优化损失函数的和作为目标损失函数;
使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。
2.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型,包括:
将带有原始标签的第一训练样本输入至所述第一模型中,以所述第一模型输出的新标签更新所述第一训练样本对应的原始标签,得到第二训练样本;
利用所述第一训练样本与所述第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;
通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数;
根据所述优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型。
3.如权利要求2所述的基于知识蒸馏的模型训练方法,其特征在于,所述通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数,包括:
对所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数设置不同的初始参数,得到初始蒸馏损失函数;
使用梯度下降算法对所述初始蒸馏损失函数进行参数更新,得到目标参数,使用所述目标参数更新初始蒸馏损失函数,得到优化损失函数。
4.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,包括:
分别获取所述第一模型中M个网络层与所述更新后的第二模型中N个网络层中每个网络层的特征矩阵;
计算所述第一模型中M个网络层的特征矩阵分别与所述更新后的第二模型中N个网络层的特征矩阵的表征相似度,得到所述第一模型中每个网络层对应的表征相似度序列;
从所述第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度。
5.如权利要求4所述的基于知识蒸馏的模型训练方法,其特征在于,所述从所述第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度,包括:
从所述第一模型中每个网络层对应的表征相似度序列中获取表征相似度最大值,将所述表征相似度最大值作为目标相似度。
6.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述目标相似度,构建相似度损失函数,包括:
根据所述目标相似度,计算所述第二模型中每个网络层的损失值;
基于所述第二模型中每个网络层的损失值,构建相似度损失函数。
7.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型,包括:
根据所述训练集中的正负样本,构建样本对;所述样本对至少包括一个正样本与一个负样本;
基于所述样本对,对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于平安科技(深圳)有限公司,未经平安科技(深圳)有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202210816261.5/1.html,转载请声明来源钻瓜专利网。