[发明专利]一种机器学习方法以及装置在审
申请号: | 202011375881.7 | 申请日: | 2020-11-30 |
公开(公告)号: | CN112801265A | 公开(公告)日: | 2021-05-14 |
发明(设计)人: | 杨扩;叶翰嘉;洪蓝青;胡海林 | 申请(专利权)人: | 华为技术有限公司 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08;G06K9/62 |
代理公司: | 深圳市深佳知识产权代理事务所(普通合伙) 44285 | 代理人: | 陈松浩 |
地址: | 518129 广东*** | 国省代码: | 广东;44 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 机器 学习方法 以及 装置 | ||
本申请公开了人工智能领域的一种机器学习方法以及装置,用于在小样本学习中,通过构建数据集对应的概念空间,基于样本在空间中的度量来训练模型,得到输出精度更高的模型。该方法包括:支撑集和查询集中样本包括的多种类型的信息对应多个维度的概念空间;目标神经网络中包括元模型和概念空间映射模型,对目标神经网络的任意一次更新过程包括:元模型提取支撑样本以及查询样本的特征得到第一特征向量和第二特征向量;概念空间映射模型将元模型输出的特征向量映射至概念空间,并计算概念空间中查询样本和支撑样本之间的距离;基于该距离得到查询样本的预测标签,随后计算损失值并更新目标神经网络,得到当前次迭代更新后的目标神经网络。
技术领域
本申请涉及人工智能领域,尤其涉及一种机器学习方法以及装置。
背景技术
现有机器学习算法在样本量充足的情况下,有着十分优秀的表现。但在实际应用中,由于人工标注成本高、标注不可得等问题,开始关注小样本学习(Few-shot Learning)的解决方案——要求机器学习算法在训练样本有限的情况下,也可以给出合理的预测结果。
元学习(Meta-learning)是小样本学习问题的一种解决范式。元学习通过在样本充足的训练集中随机采样大量与目标小样本任务相似的任务,训练一个有较好泛化性能的元模型 (meta-model),该元模型在目标任务的少量训练样本上进行学习,最终得到适合该目标小样本任务的预测模型。
然而,在进行小样本学习时,对元模型的每次更新基于当前小样本任务涉及的类别,可能因相同或者相似样本之间的区别,或者,因相同或者不同类别的样本之间的区别等,导致训练得到的模型输出精度较低。
发明内容
本申请提供一种机器学习方法以及装置,用于在小样本学习中,通过构建数据集对应的概念空间,基于样本在空间中的度量来训练模型,得到输出精度更高的模型。
有鉴于此,第一方面,本申请提供一种机器学习方法,其特征在于,包括:
获取支撑集和查询集,支撑集和查询集中样本包括的实际标签(label)包括多种类型的信息,多种类型的信息对应多个维度的概念空间;随后,使用支撑集和查询集对目标神经网络进行至少一次迭代更新,得到更新后的目标神经网络,其中,目标神经网络中包括元模型和概念空间映射模型,至少一次迭代更新中的任意一次更新包括:将支撑集中的至少一个支撑样本作为元模型的输入,得到至少一组第一特征向量,以及将查询集中的至少一个查询样本作为元模型的输入,得到至少一组第二特征向量,元模型用于提取输入的样本的特征;通过概念空间映射模型,将至少一组第一特征向量映射至多个维度的概念空间中,得到至少一组第三特征向量,以及将至少一组第二特征向量映射至多个维度的概念空间中,得到至少一组第四特征向量;根据至少一组第三特征向量和至少一组第四特征向量,得到在多个维度的概念空间中,至少一个查询样本和至少一个支撑样本之间的距离;根据至少一个查询样本和至少一个支撑样本之间的距离,得到至少一个查询样本得到预测标签;根据至少一个查询样本的预测标签获取至少一个查询样本的损失值;根据至少一个查询样本的损失值更新目标神经网络,得到当前次迭代更新后的目标神经网络。
因此,在本申请实施方式中,可以基于支撑集和查询集包括的样本来构建概念空间,并将样本映射至每个维度的概念空间中,然后可以使用样本在该概念空间之间的距离,来训练目标神经网络。该距离可以表示样本之间的关联程度,从而使训练目标神经网络的过程中,可以基于样本之间的关联程度的关联度进行训练,从而使最终得到的目标神经网络的输出准确率更高。
在一种可能的实施方式中,目标神经网络还包括概率预测模型,概率预测模型用于计算输入的向量对应的样本与多个维度之间的关联度,上述方法还可以包括:将至少一组第一特征向量作为概率预测模型的输入,输出至少一组第一概率向量,以及将至少一组第二特征向量作为概率预测模型的输入,输出至少一组第二概率向量,概率预测模型用于计算输入的向量对应的样本与多个维度的关联度,至少一组第一概率向量和至少一组第二概率向量用于得到至少一个查询样本得到预测标签。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于华为技术有限公司,未经华为技术有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011375881.7/2.html,转载请声明来源钻瓜专利网。