[发明专利]基于混合损失与图注意力的小样本SAR目标分类方法有效
申请号: | 202110408623.2 | 申请日: | 2021-04-16 |
公开(公告)号: | CN113095416B | 公开(公告)日: | 2023-08-18 |
发明(设计)人: | 白雪茹;杨敏佳;孟昭晗;周峰 | 申请(专利权)人: | 西安电子科技大学 |
主分类号: | G06V10/764 | 分类号: | G06V10/764;G06V10/774;G06V10/82;G06N3/0464;G06N3/08;G06N3/045;G06N3/042 |
代理公司: | 陕西电子工业专利中心 61205 | 代理人: | 陈宏社;王品华 |
地址: | 710071*** | 国省代码: | 陕西;61 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 基于 混合 损失 注意力 样本 sar 目标 分类 方法 | ||
1.一种基于混合损失与图注意力的小样本SAR目标分类方法,其特征在于,包括如下步骤:
(1)获取训练样本集和测试样本集
(1a)获取包含C个不同目标类别的多幅合成孔径雷达SAR图像,每个目标类别对应M幅大小为h×h的SAR图像,每幅SAR图像包含1个目标,其中C≥10,M≥200,h=128;
(1b)对每幅SAR图像中的目标类别进行标记,并将随机选取的包含Ctrain个目标类别的总共Ctrain×M幅SAR图像及每幅SAR图像的标签作为训练样本集将其余Ctest个目标类别的总共Ctest×M幅SAR图像及每幅SAR图像的标签作为测试样本集其中Ctrain+Ctest=C,3≤Ctest≤5;
(2)构建基于混合损失与图注意力的网络模型H:
构建包含顺次级联的数据增强模块D、嵌入网络模块E、节点特征初始化模块I和图注意力网络模块G的基于混合损失与图注意力的网络模型H,其中,嵌入网络模块E包括顺次级联的多个第一卷积模块EC和一个第二卷积模块EL,每个第一卷积模块EC包括依次层叠的第一卷积层、第一批归一化层、Mish激活层和最大池化层,第二卷积模块EL包括依次层叠的第二卷积层和第二批归一化层;图注意力网络模块G包括顺次级联的多个图更新层U,每个图更新层U包括依次层叠的边特征构建模块UE、注意力权重计算模块UW和节点特征更新模块UN,每个边特征构建模块UE包括多个依次层叠的第一全连接层,每个节点特征更新模块UN包括一个第二全连接层;
(3)对基于混合损失与图注意力的网络模型H进行迭代训练:
(3a)初始化迭代次数为n,最大迭代次数为N,N≥1000,并令n=0;
(3b)从训练样本集随机选取包含Ctest个目标类别总共Ctest×M幅SAR图像,并对每幅SAR图像的标签进行one-hot编码,得到每幅SAR图像的Ctest维标签向量,然后将从Ctest×M幅SAR图像中随机选取每个目标类别包含的K幅SAR图像及对应的标签向量作为训练支撑样本集将剩余的Ctest(M-K)幅SAR图像及对应的标签向量作为训练查询样本集其中,one-hot编码后每幅SAR图像的标签向量中第c维的元素表示该SAR图像中的目标属于Ctest个目标类别中第c个目标类别的概率,表示由SAR图像及其对应的标签向量组成的第a个训练支撑样本,表示由SAR图像及其对应的标签向量组成的第b个训练查询样本,1≤K≤10;
(3c)将训练支撑样本集与每个训练查询样本组合成训练任务得到训练任务集并将作为基于混合损失与图注意力的网络模型H的输入进行前向传播:
(3c1)数据增强模块D对训练任务集中的每幅SAR图像进行数据增强:对每幅SAR图像进行幂次变换,并对幂次变换后的SAR图像添加噪声,再对添加噪声后的SAR图像进行翻转变换,然后对翻转变换后的SAR图像进行旋转变换,得到增强训练任务集
其中,表示训练任务对应的增强训练任务,表示训练任务中的训练支撑样本对应的增强训练支撑样本,表示训练查询样本对应的增强训练查询样本;
(3c2)嵌入网络模块E对增强训练任务集中的每个增强训练任务包含的每幅SAR图像进行映射,得到训练嵌入向量组集合并采用嵌入损失函数LE,通过训练嵌入向量组集合计算训练任务集的嵌入损失值lE:
其中,表示增强训练任务对应的训练嵌入向量组,满足a≠CtestK+1的表示增强训练支撑样本对应的训练嵌入向量,表示增强训练查询样本对应的训练嵌入向量,log(·)表示以自然常数e为底的对数,exp(·)表示以自然常数e为底的指数,∑表示连续求和,表示对训练任务中的训练支撑样本集包括的第c个目标类别的每幅SAR图像对应的每个训练嵌入向量求均值得到的第c个目标类别的类中心,表示和训练任务中的训练查询样本包含的SAR图像中的目标属于同一个目标类别的类中心,d表示度量函数,d(p,q)=||p-q||2;
(3c3)节点特征初始化模块I构造一个虚拟标签向量并对每个训练嵌入向量组中满足a≠CtestK+1的每个训练嵌入向量与对应的SAR图像的标签向量进行拼接,同时对每个训练嵌入向量组中的训练嵌入向量与虚拟标签向量进行拼接,得到训练节点1层特征组集合
其中,表示Ctest维每个维度的元素值全为1的向量,表示训练嵌入向量组对应的训练节点1层特征组,表示训练嵌入向量对应的训练节点1层特征;
(3c4)图注意力网络模块G通过输入训练节点1层特征组集合对包括的每个训练节点1层特征组中的训练节点1层特征对应的训练查询样本包括的SAR图像中的目标进行类别预测,得到训练预测结果向量集合其中,表示训练节点1层特征对应的维数为Ctest的训练预测结果向量,第c维的元素表示训练节点1层特征对应的训练查询样本包括的SAR图像中的目标属于第c个目标类别的预测概率;
(3c5)采用分类损失函数LC,通过训练预测结果向量集合和训练查询样本集中的所有的标签向量计算训练任务集的分类损失值lC:
其中,为训练预测结果向量中第c维元素的值,yb,c为训练预测结果向量对应的SAR图像的标签向量中第c维元素的值;
(3d)对训练任务集的分类损失值lC和训练任务集的嵌入损失值lE求加权和,得到训练任务集的混合损失值l,l=λlC+(1-λ)lE,然后利用随机梯度下降算法,通过混合损失值l对嵌入网络模块E中所有第一卷积层和第二卷积层的参数、图注意力网络模块G中所有第一全连接层和第二全连接层的参数进行更新,其中,λ为权重,0.7≤λ<1;
(3e)判断n≥N是否成立,若是,得到训练好的基于混合损失与图注意力的网络模型H′,否则,令n=n+1,并执行步骤(3b);
(4)获取小样本SAR图像的目标分类结果:
(4a)对测试样本集包含的每幅SAR图像的标签进行one-hot编码,得到每幅SAR图像的Ctest维标签向量,然后从测试样本集的Ctest×M幅SAR图像中随机选取每个目标类别包含的K幅SAR图像及对应的标签向量作为测试支撑样本集将剩余的Ctest(M-K)幅SAR图像及对应的标签向量作为测试查询样本集其中,表示由SAR图像及其对应的标签向量组成的第e个测试支撑样本,表示由SAR图像及其对应的标签向量组成的第g个测试查询样本;
(4b)将测试支撑样本集与每个测试查询样本组合成测试任务得到测试任务集并将作为训练好的基于混合损失与图注意力的网络模型H′的输入进行前向传播:
(4b1)训练好的嵌入网络模块E′对测试任务集中的每个测试任务包含的每幅SAR图像进行映射,得到测试嵌入向量组集合
其中,表示测试任务对应的测试嵌入向量组,满足e≠CtestK+1的表示测试支撑样本对应的测试嵌入向量,表示测试查询样本对应的测试嵌入向量;
(4b2)节点特征初始化模块I构造一个虚拟标签向量并对每个测试嵌入向量组中满足e≠CtestK+1的每个测试嵌入向量与对应的SAR图像的标签向量进行拼接,同时对每个测试嵌入向量组中的测试嵌入向量与虚拟标签向量进行拼接,得到测试节点1层特征组集合
其中,表示测试嵌入向量组对应的测试点1层特征组,表测试嵌入向量对应的测试节点1层特征;
(4b3)训练好的图注意力网络模块G′通过输入测试节点1层特征组集合对包括的每个测试节点1层特征组中的测试节点1层特征对应的测试查询样本包括的SAR图像中的目标进行类别预测,得到测试预测结果向量集合每个测试预测结果向量中最大值对应的维数号即为对应的测试查询样本包括的SAR图像中目标的预测类别,其中,表示测试节点1层特征对应的维数为Ctest的测试预测结果向量,第c维的元素值表示测试节点1层特征对应的测试查询样本包括的SAR图像中的目标属于第c个目标类别的概率。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于西安电子科技大学,未经西安电子科技大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110408623.2/1.html,转载请声明来源钻瓜专利网。