示例#1
0
    def forward(self, source, target):
        """Run the shared feature extractor and classifier; add an MMD term in training.

        Args:
            source: Source-domain batch, passed through ``self.sharedNet``.
            target: Target-domain batch; only used while ``self.training``
                is true.

        Returns:
            A ``(source_pred, loss)`` tuple. ``source_pred`` is
            ``self.cls_fc(self.sharedNet(source))``. ``loss`` is the
            ``mmd.mmd_rbf_noaccelerate`` value on the two feature batches
            while training, and the integer ``0`` otherwise.
        """
        loss = 0
        source = self.sharedNet(source)
        # The target branch only feeds the alignment loss, so its forward
        # pass is skipped outside training.
        if self.training:
            target = self.sharedNet(target)
            loss += mmd.mmd_rbf_noaccelerate(source, target)

        source = self.cls_fc(source)

        return source, loss
示例#2
0
    def forward(self, source, target):
        """Classify the source batch and, while training, compute an MMD loss.

        Args:
            source: Source-domain batch, passed through ``self.sharedNet``.
            target: Target-domain batch; only used while ``self.training``
                is true.

        Returns:
            A ``(source_pred, loss)`` tuple. ``source_pred`` is
            ``self.cls_fc(self.sharedNet(source))``. ``loss`` is the
            ``mmd.mmd_rbf_noaccelerate`` value between the shared features
            of both batches while training, and the integer ``0`` otherwise.
        """
        loss = 0
        source = self.sharedNet(source)
        if self.training:
            # Target features are only needed for the domain-alignment term.
            target = self.sharedNet(target)
            loss += mmd.mmd_rbf_noaccelerate(source, target)

        source = self.cls_fc(source)

        return source, loss
示例#3
0
 def forward(self, source, target, s_label, mu):
     """Predict on the source batch and blend the Inception and MMD losses.

     Both batches go through ``self.sharedNet``; ``self.Inception`` receives
     the two feature maps plus ``s_label`` and returns ``(source_pred,
     cmmd_loss)``. While training, both feature maps are average-pooled and
     reshaped to ``(size(0), -1)`` before ``mmd.mmd_rbf_noaccelerate`` is
     applied; otherwise that MMD term is ``0``.

     Args:
         source: Source-domain input.
         target: Target-domain input.
         s_label: Source labels, forwarded to ``self.Inception``.
         mu: Mixing weight; the MMD term gets ``mu`` and the Inception loss
             gets ``1 - mu``.

     Returns:
         ``(source_pred, (1 - mu) * cmmd_loss + mu * mmd_loss)``.
     """
     src_feat = self.sharedNet(source)
     tgt_feat = self.sharedNet(target)
     source_pred, cmmd_loss = self.Inception(src_feat, tgt_feat, s_label)

     def pooled_flat(feat):
         # Average-pool, then keep dim 0 and flatten all remaining dims.
         pooled = self.avgpool(feat)
         return pooled.view(pooled.size(0), -1)

     mmd_loss = 0
     if self.training:
         mmd_loss += mmd.mmd_rbf_noaccelerate(pooled_flat(src_feat), pooled_flat(tgt_feat))

     weighted_cmmd = (1 - mu) * cmmd_loss
     return source_pred, weighted_cmmd + mu * mmd_loss
示例#4
0
def loss_adapt_fun(X, Y, loss_tran_type):
    """Return the domain-adaptation loss between feature batches ``X`` and ``Y``.

    Args:
        X: Features from the first domain.
        Y: Features from the second domain.
        loss_tran_type: ``'mk_mmd'`` selects ``mmd.mmd_rbf_noaccelerate``
            with ``kernel_num=5``; ``'mmd'`` selects it with ``kernel_num=1``.
            Any other value disables the adaptation loss.

    Returns:
        The ``mmd.mmd_rbf_noaccelerate`` result for a recognised type,
        otherwise the integer ``0``.
    """
    # Both supported types share kernel_mul=2.0 and fix_sigma=None; they
    # differ only in the number of kernels. Checked in order with ``==``.
    variants = (('mk_mmd', 5), ('mmd', 1))
    for type_name, n_kernels in variants:
        if loss_tran_type == type_name:
            return mmd.mmd_rbf_noaccelerate(X,
                                            Y,
                                            kernel_mul=2.0,
                                            kernel_num=n_kernels,
                                            fix_sigma=None)
    return 0
示例#5
0
    def forward(self, source, target):
        """Classify the source batch and, while training, compute an MMD loss.

        Args:
            source: Source-domain batch.
            target: Target-domain batch; only used while ``self.training``
                is true.

        Returns:
            A ``(source_pred, loss)`` tuple. ``source_pred`` is
            ``self.cls_fc(self.sharedNet(source))``. ``loss`` is the
            ``mmd.mmd_rbf_noaccelerate`` value while training, and the
            integer ``0`` otherwise.
        """
        loss = 0
        # Extract features with the shared backbone (the original author's
        # note describes it as a pretrained ResNet; not verifiable here).
        source = self.sharedNet(source)
        if self.training:
            target = self.sharedNet(target)
            # Loss term: MMD distance between source- and target-domain features.
            loss += mmd.mmd_rbf_noaccelerate(source, target)
        # Map source features through the classifier head to get the source
        # predictions (the original note describes cls_fc as a linear layer).
        source = self.cls_fc(source)

        return source, loss
示例#6
0
    def forward(self, source, target):
        """Classify the source batch and, while training, compute an MMD loss.

        Args:
            source: Source-domain batch.
            target: Target-domain batch; only used while ``self.training``
                is true.

        Returns:
            A ``(source_pred, loss)`` tuple. ``source_pred`` is
            ``self.cls_fc(self.sharedNet(source))``. ``loss`` is the
            ``mmd.mmd_rbf_noaccelerate`` value while training, and the
            integer ``0`` otherwise.

        Side effects:
            While training, prints the MMD term if ``self.print_mmd_loss``
            is truthy.
        """
        loss = 0
        source = self.sharedNet(source)
        if self.training:
            target = self.sharedNet(target)
            mmd_loss = mmd.mmd_rbf_noaccelerate(source, target)
            loss += mmd_loss
            if self.print_mmd_loss:
                # Assumes mmd_loss is a one-element tensor. ``.item()`` returns
                # its Python number for both 0-dim and shape-(1,) tensors. The
                # old ``.data[0]`` raises IndexError on 0-dim tensors
                # (PyTorch >= 0.4).
                print('mmd loss:', mmd_loss.item())

        source = self.cls_fc(source)
        return source, loss
示例#7
0
    def forward(self, src, tar):
        """Classify ``src`` and, while training, measure its MMD to ``tar``.

        Both inputs go through ``self.featureCap`` and are reshaped to
        ``(size(0), -1)``. The MMD is computed on those flattened features.
        The flattened source features then pass through ``self.fc1`` and
        ``self.fc2``.

        Args:
            src: Source-domain input.
            tar: Target-domain input; only used while ``self.training`` is
                true.

        Returns:
            ``(y_src, loss)`` where ``loss`` is the
            ``mmd.mmd_rbf_noaccelerate`` value while training, else ``0``.
        """
        loss = 0
        x_src = self.featureCap(src)

        # Keep dim 0 and flatten the rest. x_src_mmd.size(1) is the number of
        # input features fc1 must be built with; print it when sizing that
        # layer.
        x_src_mmd = x_src.view(x_src.size(0), -1)

        if self.training:
            x_tar = self.featureCap(tar)
            x_tar_mmd = x_tar.view(x_tar.size(0), -1)
            loss += mmd.mmd_rbf_noaccelerate(x_src_mmd, x_tar_mmd)

        y_src = self.fc1(x_src_mmd)
        y_src = self.fc2(y_src)

        return y_src, loss
示例#8
0
    def forward(self, source, target):
        """Classify the source batch; while training, add MMD and reconstructions.

        Args:
            source: Source-domain batch.
            target: Target-domain batch; only used while ``self.training``
                is true.

        Returns:
            ``(source_pred, loss, img_rec_s, img_rec_t)``. ``source_pred`` is
            ``self.cls_fc(self.sharedNet(source))``. While training, ``loss``
            is the ``mmd.mmd_rbf_noaccelerate`` value and ``img_rec_s`` /
            ``img_rec_t`` are the decoded source / target features. Outside
            training, the last three values are the integer ``0``.
        """
        loss = 0
        img_rec_s = 0
        img_rec_t = 0
        source = self.sharedNet(source)
        if self.training:
            target = self.sharedNet(target)
            loss += mmd.mmd_rbf_noaccelerate(source, target)

            # Shape fed to rec_feat. rec_dense's output must hold a multiple
            # of 512 * 7 * 7 elements for this view to succeed.
            rec_shape = (-1, 512, 7, 7)

            def reconstruct(features):
                # Project with rec_dense, reshape, then decode with rec_feat.
                feat_encode = self.rec_dense(features)
                feat_encode = feat_encode.view(*rec_shape)
                return self.rec_feat(feat_encode)

            img_rec_s = reconstruct(source)
            img_rec_t = reconstruct(target)

        source = self.cls_fc(source)

        return source, loss, img_rec_s, img_rec_t