Example #1
File: train.py Project: yiyang-wang/AFN
            source_loader_iter = iter(source_loader)
            s_imgs, s_labels = next(source_loader_iter)

        if s_imgs.size(0) != args.batch_size or t_imgs.size(0) != args.batch_size:
            continue

        s_imgs = Variable(s_imgs.cuda())
        s_labels = Variable(s_labels.cuda())     
        t_imgs = Variable(t_imgs.cuda())
        
        opt_g.zero_grad()
        opt_f.zero_grad()

        s_bottleneck = netG(s_imgs)
        t_bottleneck = netG(t_imgs)
        s_fc2_emb, s_logit = netF(s_bottleneck)
        t_fc2_emb, t_logit = netF(t_bottleneck)

        s_cls_loss = get_cls_loss(s_logit, s_labels)
        s_fc2_L2norm_loss = get_L2norm_loss_self_driven(s_fc2_emb)
        t_fc2_L2norm_loss = get_L2norm_loss_self_driven(t_fc2_emb)

        loss = s_cls_loss + s_fc2_L2norm_loss + t_fc2_L2norm_loss
        loss.backward()

        opt_g.step()
        opt_f.step()
    if epoch % 10 == 0:   
        torch.save(netG.state_dict(), os.path.join(args.snapshot, "Office31_HAFN_" + args.task + "_netG_" + args.post + "." + args.repeat + "_" + str(epoch) + ".pth"))
        torch.save(netF.state_dict(), os.path.join(args.snapshot, "Office31_HAFN_" + args.task + "_netF_" + args.post + "." + args.repeat + "_" + str(epoch) + ".pth"))
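
The excerpt calls a helper, get_L2norm_loss_self_driven, that is not shown. For the hard norm-alignment (HAFN) variant trained here, a minimal sketch could pull the batch-mean feature norm toward a fixed radius; the radius and weight_L2norm parameters below are illustrative placeholders, not values taken from the project:

import torch

def get_L2norm_loss_self_driven(x, radius=25.0, weight_L2norm=0.05):
    # Penalize the squared gap between the batch-mean L2 norm of the
    # bottleneck features and a fixed target radius (hard AFN).
    # radius and weight_L2norm are illustrative hyperparameters.
    l = (x.norm(p=2, dim=1).mean() - radius) ** 2
    return weight_L2norm * l
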
Example #2
        opt_g.zero_grad()
        opt_f.zero_grad()

        s_bottleneck = netG(s_imgs)
        t_bottleneck = netG(t_imgs)
        s_fc2_emb, s_logit = netF(s_bottleneck)
        t_fc2_emb, t_logit = netF(t_bottleneck)

        s_cls_loss = get_cls_loss(s_logit, s_labels)
        s_fc2_ring_loss = get_L2norm_loss_self_driven(s_fc2_emb)
        t_fc2_ring_loss = get_L2norm_loss_self_driven(t_fc2_emb)

        loss = s_cls_loss + s_fc2_ring_loss + t_fc2_ring_loss
        loss.backward()

        opt_g.step()
        opt_f.step()

    if epoch % 10 == 0:
        torch.save(
            netG.state_dict(),
            os.path.join(
                args.snapshot, "VisDA_HAFN_" + args.model + "_netG_" +
                args.post + '.' + args.repeat + '_' + str(epoch) + ".pth"))
        torch.save(
            netF.state_dict(),
            os.path.join(
                args.snapshot, "VisDA_HAFN_" + args.model + "_netF_" +
                args.post + '.' + args.repeat + '_' + str(epoch) + ".pth"))
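
get_cls_loss is also not defined in the excerpts. Assuming it is a plain cross-entropy between the source logits and their ground-truth labels, a minimal sketch would be:

import torch.nn.functional as F

def get_cls_loss(logit, labels):
    # Standard supervised classification loss on the labeled source batch.
    return F.cross_entropy(logit, labels)
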
Example #3
File: train.py Project: redhat12345/AFN
    target_loader_iter = iter(target_loader)
    print('>>training epoch : ' + str(epoch))
    
    for i, (t_imgs, _) in tqdm.tqdm(enumerate(target_loader_iter)):
        s_imgs, s_labels = next(source_loader_iter)
        s_imgs = Variable(s_imgs.cuda())
        s_labels = Variable(s_labels.cuda())
        t_imgs = Variable(t_imgs.cuda())
        
        opt_g.zero_grad()
        opt_f.zero_grad()

        s_bottleneck = netG(s_imgs)
        t_bottleneck = netG(t_imgs)
        s_fc2_emb, s_logit = netF(s_bottleneck)
        t_fc2_emb, _ = netF(t_bottleneck)

        s_cls_loss = get_cls_loss(s_logit, s_labels)
        s_fc2_L2norm_loss = get_L2norm_loss_self_driven(s_fc2_emb)
        t_fc2_L2norm_loss = get_L2norm_loss_self_driven(t_fc2_emb)

        loss = s_cls_loss + s_fc2_L2norm_loss + t_fc2_L2norm_loss
        loss.backward()

        opt_g.step()
        opt_f.step()
        
    if epoch % 10 == 0:   
        torch.save(netG.state_dict(), os.path.join(args.snapshot, "VisDA_IAFN_"+ args.model + "_netG_" + args.post + '.' + args.repeat + '_' + str(epoch) + ".pth"))
        torch.save(netF.state_dict(), os.path.join(args.snapshot, "VisDA_IAFN_"+ args.model + "_netF_" + args.post + '.' + args.repeat + '_' + str(epoch) + ".pth"))
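
Example #3 trains the incremental variant (IAFN, also called stepwise AFN): instead of matching a fixed radius, each feature's norm is encouraged to grow by a small step over its current, detached value. A minimal sketch, with an illustrative step size delta_r and weight weight_L2norm:

import torch

def get_L2norm_loss_self_driven(x, delta_r=1.0, weight_L2norm=0.05):
    # Incremental AFN: the per-sample target norm is the current norm
    # (detached from the graph) plus delta_r, so norms grow gradually.
    # delta_r and weight_L2norm are illustrative hyperparameters.
    target = x.norm(p=2, dim=1).detach() + delta_r
    l = ((x.norm(p=2, dim=1) - target) ** 2).mean()
    return weight_L2norm * l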