Example #1
def train_classifier(classifier, vae, datasets, dataloaders, args, optimizer_cls, scheduler_cls):
    device = args.device
    vae.eval()
    # accuracy trackers per epoch (allocated here but not updated in this snippet)
    acc_per_class = np.zeros((args.num_epochs_cls, datasets['train'].num_class))
    acc = np.zeros((args.num_epochs_cls,))
    # total number of optimizer steps, used by the lr schedule below
    max_iter = args.num_epochs_cls * len(datasets['train']) / args.batch_size
    iter_num = 0
    for epoch in range(args.num_epochs_cls):
        classifier.train()
        #print(f'Classifier training epoch {epoch:d}/{args.num_epochs_cls:d}')
        #print(optimizer_cls.param_groups[0]['lr'])
        for iteration, (xS,xT,yS,yT) in enumerate(dataloaders['train']):
            lr_scheduler(optimizer_cls, iter_num=iter_num, max_iter=max_iter)
            iter_num += 1
            xS,xT,yS,yT = xS.to(device), xT.to(device), yS.to(device), yT.to(device)
            recon_xS,recon_xT = generate_z(xS,xT,vae,device)
            # keep only target samples that actually carry a label (yT == -1 marks unlabeled)
            mask = yT != -1
            xT = xT[mask, :]
            yT = yT[mask]
            recon_xT = recon_xT[mask, :]
            xtrain = torch.cat((xS,xT,recon_xS,recon_xT),dim=0)
            ytrain = torch.cat((yS,yT,yS,yT),dim=0)
            output = classifier(xtrain)
            #loss_cls = classifier.lossfunction(output, y)
            loss_cls = loss.CrossEntropyLabelSmooth(num_classes=10, epsilon=args.smooth)(output, ytrain)
            optimizer_cls.zero_grad()
            loss_cls.backward()
            optimizer_cls.step()
        #scheduler_cls.step()
        #test_model(classifier,datasets['test'],dataloaders['test'], device,model_type='mlp')
    return classifier
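This and the later examples all call loss.CrossEntropyLabelSmooth, but the class itself is not listed. Below is a minimal sketch of a label-smoothed cross-entropy consistent with how it is called here (constructed with num_classes, epsilon and an optional reduction, then applied to logits and integer labels); the exact signature in the original loss module may differ.

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy where the target distribution mixes the one-hot label
    with a uniform distribution over num_classes (label smoothing)."""
    def __init__(self, num_classes, epsilon=0.1, reduction='mean'):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).sum(dim=1)
        return loss.mean() if self.reduction == 'mean' else loss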
Example #2
def train_target(args):
    dset_loaders = data_load(args)

    param_group = []
    model_resnet = network.Res50().cuda()
    for k, v in model_resnet.named_parameters():
        if 'fc' in k:
            # freeze the classifier head
            v.requires_grad = False
        else:
            # fine-tune the backbone layers
            param_group += [{'params': v, 'lr': args.lr}]

    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):

        model_resnet.eval()
        mem_label = obtain_label(dset_loaders['test'], model_resnet, args)
        mem_label = torch.from_numpy(mem_label).cuda()
        model_resnet.train()

        iter_test = iter(dset_loaders['target'])
        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test), leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()

            pred = mem_label[tar_idx]
            features_test, outputs_test = model_resnet(inputs_test)

            classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0)(outputs_test, pred)
            classifier_loss *= args.cls_par

            if args.ent:
                softmax_out = nn.Softmax(dim=1)(outputs_test)
                entropy_loss = torch.mean(loss.Entropy(softmax_out))
                if args.gent:
                    msoftmax = softmax_out.mean(dim=0)
                    gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                    entropy_loss -= gentropy_loss
                classifier_loss += entropy_loss * args.ent_par

            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        model_resnet.eval()
        acc, ment = cal_acc(dset_loaders['test'], model_resnet)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, epoch+1, args.max_epoch, acc*100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str+'\n')
    
    # torch.save(model_resnet.state_dict(), osp.join(args.output_dir, 'target.pt'))
    return model_resnet
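Example #2 combines the pseudo-label classification loss with an information-maximization term: the mean per-sample entropy (loss.Entropy) pushes individual predictions towards confident one-hot outputs, while the subtracted "gentropy" term (entropy of the batch-averaged prediction) keeps the predicted classes diverse. A minimal sketch of the entropy helper, assuming the conventional definition:

import torch

def Entropy(input_):
    # Per-sample Shannon entropy of softmax outputs: (batch, classes) -> (batch,)
    epsilon = 1e-5
    return -(input_ * torch.log(input_ + epsilon)).sum(dim=1)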
Example #3
def train_target(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.eval()
        netB.eval()
        mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
        mem_label = torch.from_numpy(mem_label).cuda()
        netF.train()
        netB.train()
        iter_test = iter(dset_loaders['target'])

        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test),
                                                 leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()

            pred = mem_label[tar_idx]
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)

            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)
            classifier_loss *= args.cls_par

            if args.ent:
                softmax_out = nn.Softmax(dim=1)(outputs_test)
                entropy_loss = torch.mean(loss.Entropy(softmax_out))
                if args.gent:
                    msoftmax = softmax_out.mean(dim=0)
                    gentropy_loss = torch.sum(
                        -msoftmax * torch.log(msoftmax + args.epsilon))
                    entropy_loss -= gentropy_loss
                classifier_loss += entropy_loss * args.ent_par

            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.name, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, 'target_F_' + args.savename + '.pt'))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, 'target_B_' + args.savename + '.pt'))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, 'target_C_' + args.savename + '.pt'))

    return netF, netB, netC
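Examples #2, #3 and #9 refresh pseudo-labels once per epoch (or interval) with obtain_label, which is not listed here. The sketch below is a hedged reconstruction of the usual centroid-based refinement, matching the five-argument form used in Examples #3 and #9; the cosine metric, the two refinement rounds, and the bias feature appended before normalization are assumptions.

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cdist

def obtain_label(loader, netF, netB, netC, args):
    # 1) collect features and network predictions for the whole target set
    netF.eval()
    netB.eval()
    feats, outs = [], []
    with torch.no_grad():
        for data in loader:
            inputs = data[0].cuda()
            f = netB(netF(inputs))
            feats.append(f.cpu())
            outs.append(netC(f).cpu())
    all_fea = torch.cat(feats, dim=0)
    all_out = nn.Softmax(dim=1)(torch.cat(outs, dim=0))

    # 2) append a bias feature, L2-normalize, then refine labels with class centroids
    all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), dim=1)
    all_fea = (all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)).numpy()
    aff = all_out.numpy()
    for _ in range(2):  # a couple of centroid / re-assignment rounds
        initc = aff.T.dot(all_fea) / (1e-8 + aff.sum(axis=0)[:, None])
        dist = cdist(all_fea, initc, 'cosine')
        pred_label = dist.argmin(axis=1)
        aff = np.eye(args.class_num)[pred_label]
    return pred_label.astype('int')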
Example #4
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    acc_init = 0
    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.train()
        netB.train()
        netC.train()
        iter_source = iter(dset_loaders['source_tr'])
        for _, (inputs_source, labels_source) in tqdm(enumerate(iter_source),
                                                      leave=False):
            if inputs_source.size(0) == 1:
                continue
            inputs_source, labels_source = inputs_source.cuda(
            ), labels_source.cuda()
            outputs_source = netC(netB(netF(inputs_source)))
            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num,
                epsilon=args.smooth)(outputs_source, labels_source)
            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        if (epoch + 1) % 5 == 0 and args.trte == 'full':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                  args.dset == 'visda17')
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            best_netF = netF.state_dict()
            best_netB = netB.state_dict()
            best_netC = netC.state_dict()
            torch.save(
                best_netF,
                osp.join(args.output_dir_src,
                         'source_F_' + str(epoch + 1) + '.pt'))
            torch.save(
                best_netB,
                osp.join(args.output_dir_src,
                         'source_B_' + str(epoch + 1) + '.pt'))
            torch.save(
                best_netC,
                osp.join(args.output_dir_src,
                         'source_C_' + str(epoch + 1) + '.pt'))

        if args.trte == 'val':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_tr * 100,
                acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te > acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

    torch.save(best_netF, osp.join(args.output_dir_src, 'source_F_val.pt'))
    torch.save(best_netB, osp.join(args.output_dir_src, 'source_B_val.pt'))
    torch.save(best_netC, osp.join(args.output_dir_src, 'source_C_val.pt'))
    return netF, netB, netC
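cal_acc is used throughout for evaluation but never defined in these listings. A minimal sketch consistent with the calls above: it returns overall accuracy as a fraction plus the mean prediction entropy, and when the last flag is set (e.g. for VisDA) a per-class breakdown instead; the details of that breakdown vary between variants and are an assumption here.

import torch
import torch.nn as nn

def cal_acc(loader, netF, netB, netC, per_class=False):
    # Run the pipeline netC(netB(netF(x))) over a loader and score it.
    all_output, all_label = [], []
    with torch.no_grad():
        for data in loader:
            inputs, labels = data[0].cuda(), data[1]
            all_output.append(netC(netB(netF(inputs))).float().cpu())
            all_label.append(labels.float())
    all_output = nn.Softmax(dim=1)(torch.cat(all_output, dim=0))
    all_label = torch.cat(all_label, dim=0)
    _, predict = torch.max(all_output, 1)
    accuracy = (predict.float() == all_label).float().mean().item()
    mean_ent = (-all_output * torch.log(all_output + 1e-5)).sum(dim=1).mean().item()
    if per_class:
        # per-class recall, averaged (VisDA-style); percentages per class as second value
        accs = [(predict[all_label == c].float() == c).float().mean().item() * 100
                for c in all_label.unique()]
        return sum(accs) / len(accs) / 100, accs
    return accuracy, mean_ent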
Example #5
def train_source(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            # restart the source loader once it is exhausted
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                             labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir, "source_C.pt"))

    return netF, netB, netC
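Example #5 wraps the optimizer with op_copy and decays the learning rate at every step with lr_scheduler; neither helper is listed. A minimal sketch of the pair as they are typically defined alongside this kind of code (the decay constants gamma=10, power=0.75 and the per-step weight_decay/momentum settings are assumptions):

def op_copy(optimizer):
    # Remember each parameter group's initial learning rate so the scheduler can scale it.
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    # Inverse power-law decay applied once per optimizer step.
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer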
Example #6
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*0.1}]#1
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*1}]#10
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*1}]#10
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    acc_init = 0
    netF.train()
    netB.train()
    netC.train()
    iter_num = 0
    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < args.max_epoch * len(dset_loaders["source_tr"]):
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            # restart the source loader once it is exhausted
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)
        if inputs_source.size(0) == 1:
            continue
        iter_num += 1
        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if (iter_num % int(args.interval*len(dset_loaders["source_tr"])) == 0):
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, args.dset=="visda17")
            log_str = 'Task: {}, Iter:{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, acc_s_te) + '\n' + str(acc_list)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n')
            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F_val.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B_val.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C_val.pt"))
    return netF, netB, netC
Example #7
def train(args, validate=False, label=None):
    ## set pre-process
    if validate:
        dset_loaders = data_load_y(args, label)
    else:
        dset_loaders = data_load(args)
    class_num = args.class_num
    class_weight_src = torch.ones(class_num, ).cuda()
    ##################################################################################################

    ## set base network
    if args.net == 'resnet101':
        netG = utils.ResBase101().cuda()
    elif args.net == 'resnet50':
        netG = utils.ResBase50().cuda()

    netF = utils.ResClassifier(class_num=class_num,
                               feature_dim=netG.in_features,
                               bottleneck_dim=args.bottleneck_dim).cuda()

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ad_flag = False
    if args.method in {'DANN', 'DANNE'}:
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim,
                                          1024,
                                          max_iter=args.max_iter).cuda()
        ad_flag = True
    if args.method in {'CDAN', 'CDANE'}:
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim * class_num,
                                          1024,
                                          max_iter=args.max_iter).cuda()
        random_layer = None
        ad_flag = True

    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    if ad_flag:
        optimizer_d = optim.SGD(ad_net.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)

    if args.pl.startswith('atdoc_na'):
        mem_fea = torch.rand(len(dset_loaders["target"].dataset),
                             args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(len(dset_loaders["target"].dataset),
                             class_num).cuda() / class_num

    if args.pl == 'atdoc_nc':
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    ####
    list_acc = []
    best_ent = 100

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        if ad_flag:
            lr_scheduler(optimizer_d,
                         init_lr=args.lr,
                         iter_num=iter_num,
                         max_iter=args.max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, idx = next(target_loader_iter)

        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()

        if args.method == 'srconly' and args.pl == 'none':
            features_source, outputs_source = base_network(inputs_source)
        else:
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)
            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)
            softmax_out = nn.Softmax(dim=1)(outputs)

        eff = utils.calc_coeff(iter_num, max_iter=args.max_iter)
        if args.method[-1] == 'E':
            entropy = loss.Entropy(softmax_out)
        else:
            entropy = None

        if args.method in {'CDAN', 'CDANE'}:
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      eff, random_layer)

        elif args.method in {'DANN', 'DANNE'}:
            transfer_loss = loss.DANN(features, ad_net, entropy, eff)

        elif args.method == 'DAN':
            transfer_loss = eff * loss.DAN(features_source, features_target)
        elif args.method == 'DAN_Linear':
            transfer_loss = eff * loss.DAN_Linear(features_source,
                                                  features_target)

        elif args.method == 'JAN':
            transfer_loss = eff * loss.JAN(
                [features_source, softmax_out[0:args.batch_size, :]],
                [features_target, softmax_out[args.batch_size::, :]])
        elif args.method == 'JAN_Linear':
            transfer_loss = eff * loss.JAN_Linear(
                [features_source, softmax_out[0:args.batch_size, :]],
                [features_target, softmax_out[args.batch_size::, :]])

        elif args.method == 'CORAL':
            transfer_loss = eff * loss.CORAL(features_source, features_target)
        elif args.method == 'DDC':
            transfer_loss = loss.MMD_loss()(features_source, features_target)

        elif args.method == 'srconly':
            transfer_loss = torch.tensor(0.0).cuda()
        else:
            raise ValueError('Method cannot be recognized.')

        src_ = loss.CrossEntropyLabelSmooth(reduction='none',
                                            num_classes=class_num,
                                            epsilon=args.smooth)(
                                                outputs_source, labels_source)
        weight_src = class_weight_src[labels_source].unsqueeze(0)
        classifier_loss = torch.sum(
            weight_src * src_) / (torch.sum(weight_src).item())
        total_loss = transfer_loss + classifier_loss

        eff = iter_num / args.max_iter

        if args.pl == 'none':
            pass

        elif args.pl == 'square':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            square_loss = -torch.sqrt((softmax_out**2).sum(dim=1)).mean()
            total_loss += args.tar_par * eff * square_loss

        elif args.pl == 'bsp':
            sigma_loss = bsp_loss(features)
            total_loss += args.tar_par * sigma_loss

        elif args.pl == 'bnm':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            bnm_loss = -torch.norm(softmax_out, 'nuc')
            cof = torch.tensor(
                np.sqrt(np.min(softmax_out.size())) / softmax_out.size(0))
            bnm_loss *= cof
            total_loss += args.tar_par * eff * bnm_loss

        elif args.pl == "mcc":
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_weight = 1 + torch.exp(-loss.Entropy(softmax_out)).detach()
            ent_weight /= ent_weight.sum()
            cov_tar = softmax_out.t().mm(
                torch.diag(softmax_out.size(0) * ent_weight)).mm(softmax_out)
            mcc_loss = (torch.diag(cov_tar) / cov_tar.sum(dim=1)).mean()
            total_loss -= args.tar_par * eff * mcc_loss

        elif args.pl == 'ent':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_loss = torch.mean(loss.Entropy(softmax_out))
            ent_loss /= torch.log(torch.tensor(class_num + 0.0))
            total_loss += args.tar_par * eff * ent_loss

        elif args.pl[0:3] == 'npl':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            weight_, pred = torch.max(softmax_out, 1)
            loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred)
            classifier_loss = torch.sum(
                weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl == 'atdoc_nc':
            mem_fea_norm = mem_fea / torch.norm(
                mem_fea, p=2, dim=1, keepdim=True)
            dis = torch.mm(features_target.detach(), mem_fea_norm.t())
            _, pred = torch.max(dis, dim=1)
            classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl.startswith('atdoc_na'):

            dis = -torch.mm(features_target.detach(), mem_fea.t())
            for di in range(dis.size(0)):
                dis[di, idx[di]] = torch.max(dis)
            _, p1 = torch.sort(dis, dim=1)

            w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda()
            for wi in range(w.size(0)):
                for wj in range(args.K):
                    w[wi][p1[wi, wj]] = 1 / args.K

            weight_, pred = torch.max(w.mm(mem_cls), 1)

            if args.pl == 'atdoc_na_now':
                classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            else:
                loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target,
                                                              pred)
                classifier_loss = torch.sum(
                    weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        if ad_flag:
            optimizer_d.zero_grad()
        total_loss.backward()
        optimizer_g.step()
        optimizer_f.step()
        if ad_flag:
            optimizer_d.step()

        if args.pl.startswith('atdoc_na'):
            base_network.eval()
            with torch.no_grad():
                features_target, outputs_target = base_network(inputs_target)
                features_target = features_target / torch.norm(
                    features_target, p=2, dim=1, keepdim=True)
                softmax_out = nn.Softmax(dim=1)(outputs_target)
                if args.pl == 'atdoc_na_nos':
                    outputs_target = softmax_out
                else:
                    outputs_target = softmax_out**2 / (
                        (softmax_out**2).sum(dim=0))

            mem_fea[idx] = (1.0 - args.momentum) * mem_fea[
                idx] + args.momentum * features_target.clone()
            mem_cls[idx] = (1.0 - args.momentum) * mem_cls[
                idx] + args.momentum * outputs_target.clone()

        if args.pl == 'atdoc_nc':
            base_network.eval()
            with torch.no_grad():
                features_target, outputs_target = base_network(inputs_target)
                softmax_t = nn.Softmax(dim=1)(outputs_target)
                _, pred_t = torch.max(softmax_t, 1)
                onehot_t = torch.eye(args.class_num)[pred_t].cuda()
                center_t = torch.mm(features_target.t(),
                                    onehot_t) / (onehot_t.sum(dim=0) + 1e-8)

            mem_fea = (1.0 - args.momentum
                       ) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % int(args.eval_epoch * max_len) == 0:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = utils.cal_acc_visda(
                    dset_loaders["test"], base_network)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()

                _ent = loss.Entropy(score)
                mean_ent = 0
                for ci in range(args.class_num):
                    mean_ent += _ent[py == ci].mean()
                mean_ent /= args.class_num

            else:
                acc, py, score, y = utils.cal_acc(dset_loaders["test"],
                                                  base_network)
                mean_ent = torch.mean(loss.Entropy(score))

            list_acc.append(acc * 100)
            if best_ent > mean_ent:
                best_ent = mean_ent
                val_acc = acc * 100
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.name, iter_num, args.max_iter, acc * 100, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return best_y.cpu().numpy().astype(np.int64)
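Examples #7 and #8 call lr_scheduler with an explicit init_lr keyword instead of relying on op_copy, so they assume a slightly different scheduler than the one sketched after Example #5. A hedged sketch of that variant (same inverse power-law decay; the extra momentum/weight-decay fields are assumptions):

def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    # Inverse power-law decay, rescaling every group from an explicit initial lr.
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer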
Example #8
def train(args):
    ## set pre-process
    dset_loaders = data_load(args)
    class_num = args.class_num
    class_weight_src = torch.ones(class_num, ).cuda()
    ##################################################################################################

    ## set base network
    if args.net == 'resnet34':
        netG = utils.ResBase34().cuda()
    elif args.net == 'vgg16':
        netG = utils.VGG16Base().cuda()

    netF = utils.ResClassifier(class_num=class_num,
                               feature_dim=netG.in_features,
                               bottleneck_dim=args.bottleneck_dim).cuda()

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ad_flag = False
    if args.method == 'DANN':
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim,
                                          1024,
                                          max_iter=args.max_iter).cuda()
        ad_flag = True
    if args.method == 'CDANE':
        ad_net = utils.AdversarialNetwork(args.bottleneck_dim * class_num,
                                          1024,
                                          max_iter=args.max_iter).cuda()
        random_layer = None
        ad_flag = True

    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    if ad_flag:
        optimizer_d = optim.SGD(ad_net.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)

    if args.pl.startswith('atdoc_na'):
        mem_fea = torch.rand(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset), args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset), class_num).cuda() / class_num

    if args.pl == 'atdoc_nc':
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])
    ltarget_loader_iter = iter(dset_loaders["ltarget"])

    # ###
    list_acc = []
    best_val_acc = 0

    for iter_num in range(1, args.max_iter + 1):
        # print(iter_num)
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        if ad_flag:
            lr_scheduler(optimizer_d,
                         init_lr=args.lr,
                         iter_num=iter_num,
                         max_iter=args.max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, idx = next(target_loader_iter)

        try:
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)
        except StopIteration:
            ltarget_loader_iter = iter(dset_loaders["ltarget"])
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)

        inputs_ltarget, labels_ltarget = inputs_ltarget.cuda(
        ), labels_ltarget.cuda()

        inputs_source, inputs_target, labels_source = inputs_source.cuda(
        ), inputs_target.cuda(), labels_source.cuda()

        if args.method == 'srconly' and args.pl == 'none':
            features_source, outputs_source = base_network(inputs_source)
            features_ltarget, outputs_ltarget = base_network(inputs_ltarget)
        else:
            features_ltarget, outputs_ltarget = base_network(inputs_ltarget)
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)

            features_target = torch.cat((features_ltarget, features_target),
                                        dim=0)
            outputs_target = torch.cat((outputs_ltarget, outputs_target),
                                       dim=0)

            features = torch.cat((features_source, features_target), dim=0)
            outputs = torch.cat((outputs_source, outputs_target), dim=0)
            softmax_out = nn.Softmax(dim=1)(outputs)

        eff = utils.calc_coeff(iter_num, max_iter=args.max_iter)

        if args.method[-1] == 'E':
            entropy = loss.Entropy(softmax_out)
        else:
            entropy = None

        if args.method == 'CDANE':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy,
                                      eff, random_layer)

        elif args.method == 'DANN':
            transfer_loss = loss.DANN(features, ad_net, entropy, eff)

        elif args.method == 'srconly':
            transfer_loss = torch.tensor(0.0).cuda()
        else:
            raise ValueError('Method cannot be recognized.')

        src_ = loss.CrossEntropyLabelSmooth(reduction='none',
                                            num_classes=class_num,
                                            epsilon=args.smooth)(
                                                outputs_source, labels_source)
        weight_src = class_weight_src[labels_source].unsqueeze(0)
        classifier_loss = torch.sum(
            weight_src * src_) / (torch.sum(weight_src).item())
        total_loss = transfer_loss + classifier_loss

        ltar_ = loss.CrossEntropyLabelSmooth(reduction='none',
                                             num_classes=class_num,
                                             epsilon=args.smooth)(
                                                 outputs_ltarget,
                                                 labels_ltarget)
        weight_src = class_weight_src[labels_ltarget].unsqueeze(0)
        ltar_classifier_loss = torch.sum(
            weight_src * ltar_) / (torch.sum(weight_src).item())
        total_loss += ltar_classifier_loss

        eff = iter_num / args.max_iter

        if not args.pl == 'none':
            outputs_target = outputs_target[-args.batch_size // 3:, :]
            features_target = features_target[-args.batch_size // 3:, :]

        if args.pl == 'none':
            pass

        elif args.pl == 'square':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            square_loss = -torch.sqrt((softmax_out**2).sum(dim=1)).mean()
            total_loss += args.tar_par * eff * square_loss

        elif args.pl == 'bsp':
            sigma_loss = bsp_loss(features)
            total_loss += args.tar_par * sigma_loss

        elif args.pl == 'ent':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_loss = torch.mean(loss.Entropy(softmax_out))
            ent_loss /= torch.log(torch.tensor(class_num + 0.0))
            total_loss += args.tar_par * eff * ent_loss

        elif args.pl == 'bnm':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            bnm_loss = -torch.norm(softmax_out, 'nuc')
            cof = torch.tensor(
                np.sqrt(np.min(softmax_out.size())) / softmax_out.size(0))
            bnm_loss *= cof
            total_loss += args.tar_par * eff * bnm_loss

        elif args.pl == 'mcc':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            ent_weight = 1 + torch.exp(-loss.Entropy(softmax_out)).detach()
            ent_weight /= ent_weight.sum()
            cov_tar = softmax_out.t().mm(
                torch.diag(softmax_out.size(0) * ent_weight)).mm(softmax_out)
            mcc_loss = (torch.diag(cov_tar) / cov_tar.sum(dim=1)).mean()
            total_loss -= args.tar_par * eff * mcc_loss

        elif args.pl == 'npl':
            softmax_out = nn.Softmax(dim=1)(outputs_target)
            softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            weight_, pred = torch.max(softmax_out, 1)
            loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target, pred)
            classifier_loss = torch.sum(
                weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl == 'atdoc_nc':
            mem_fea_norm = mem_fea / torch.norm(
                mem_fea, p=2, dim=1, keepdim=True)
            dis = torch.mm(features_target.detach(), mem_fea_norm.t())
            _, pred = torch.max(dis, dim=1)
            classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            total_loss += args.tar_par * eff * classifier_loss

        elif args.pl.startswith('atdoc_na'):

            dis = -torch.mm(features_target.detach(), mem_fea.t())
            for di in range(dis.size(0)):
                dis[di, idx[di]] = torch.max(dis)
            _, p1 = torch.sort(dis, dim=1)

            w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda()
            for wi in range(w.size(0)):
                for wj in range(args.K):
                    w[wi][p1[wi, wj]] = 1 / args.K

            weight_, pred = torch.max(w.mm(mem_cls), 1)

            if args.pl.startswith('atdoc_na_now'):
                classifier_loss = nn.CrossEntropyLoss()(outputs_target, pred)
            else:
                loss_ = nn.CrossEntropyLoss(reduction='none')(outputs_target,
                                                              pred)
                classifier_loss = torch.sum(
                    weight_ * loss_) / (torch.sum(weight_).item())
            total_loss += args.tar_par * eff * classifier_loss

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        if ad_flag:
            optimizer_d.zero_grad()
        total_loss.backward()
        optimizer_g.step()
        optimizer_f.step()
        if ad_flag:
            optimizer_d.step()

        if args.pl.startswith('atdoc_na'):
            base_network.eval()
            with torch.no_grad():
                features_target, outputs_target = base_network(inputs_target)
                features_target = features_target / torch.norm(
                    features_target, p=2, dim=1, keepdim=True)
                softmax_out = nn.Softmax(dim=1)(outputs_target)
                if args.pl.startswith('atdoc_na_nos'):
                    outputs_target = softmax_out
                else:
                    outputs_target = softmax_out**2 / (
                        (softmax_out**2).sum(dim=0))

            mem_fea[idx] = (1.0 - args.momentum) * mem_fea[
                idx] + args.momentum * features_target.clone()
            mem_cls[idx] = (1.0 - args.momentum) * mem_cls[
                idx] + args.momentum * outputs_target.clone()

            with torch.no_grad():
                features_ltarget, outputs_ltarget = base_network(
                    inputs_ltarget)
                features_ltarget = features_ltarget / torch.norm(
                    features_ltarget, p=2, dim=1, keepdim=True)
                softmax_out = nn.Softmax(dim=1)(outputs_ltarget)
                if args.pl.startswith('atdoc_na_nos'):
                    outputs_ltarget = softmax_out
                else:
                    outputs_ltarget = softmax_out**2 / (
                        (softmax_out**2).sum(dim=0))

            mem_fea[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_fea[lidx + len(dset_loaders["target"].dataset)] + args.momentum * features_ltarget.clone()
            mem_cls[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_cls[lidx + len(dset_loaders["target"].dataset)] + args.momentum * outputs_ltarget.clone()

        if args.pl == 'atdoc_nc':
            base_network.eval()
            with torch.no_grad():
                feat_u, outputs_target = base_network(inputs_target)
                softmax_t = nn.Softmax(dim=1)(outputs_target)
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tu = torch.eye(args.class_num)[pred_t].cuda()

                feat_l, outputs_target = base_network(inputs_ltarget)
                softmax_t = nn.Softmax(dim=1)(outputs_target)
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tl = torch.eye(args.class_num)[pred_t].cuda()

            center_t = ((torch.mm(feat_u.t(), onehot_tu) + torch.mm(
                feat_l.t(), onehot_tl))) / (onehot_tu.sum(dim=0) +
                                            onehot_tl.sum(dim=0) + 1e-8)
            mem_fea = (1.0 - args.momentum
                       ) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % int(args.eval_epoch * max_len) == 0:
            base_network.eval()
            acc, py, score, y = utils.cal_acc(dset_loaders["test"],
                                              base_network)
            val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"], base_network)

            list_acc.append(acc * 100)
            if best_val_acc <= val_acc:
                best_val_acc = val_acc
                best_acc = acc
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(
                args.name, iter_num, args.max_iter, acc * 100, val_acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    val_acc = best_acc * 100
    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
Example #9
def train_target(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'   
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'   
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'    
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    netF.train()
    netB.train()

    iter_num = 0
    iter_target = iter(dset_loaders["target"])
    while iter_num < args.max_epoch * len(dset_loaders["target"]):
        try:
            inputs_test, _, tar_idx = next(iter_target)
        except StopIteration:
            # restart the target loader once it is exhausted
            iter_target = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_target)
        if inputs_test.size(0) == 1:
            continue

        if iter_num % int(args.interval*len(dset_loaders["target"])) == 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        iter_num += 1
        inputs_test = inputs_test.cuda()
        
        pred = mem_label[tar_idx]
        
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)
        classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0)(outputs_test, pred)
        classifier_loss *= args.cls_par

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            classifier_loss += entropy_loss * args.ent_par

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()
    
        if iter_num % int(args.interval*len(dset_loaders["target"])) == 0:
            netF.eval()
            netB.eval()
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, args.dset=="visda17")
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, \
                args.max_epoch * len(dset_loaders["target"]), acc) + '\n' + acc_list

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n')

            netF.train()
            netB.train()