Example #1
    def train(class_dist_threshold_list):
        G.train()
        F1.train()
        optimizer_g = optim.SGD(params,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)
        optimizer_f = optim.SGD(list(F1.parameters()),
                                lr=1.0,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)

        def zero_grad_all():
            optimizer_g.zero_grad()
            optimizer_f.zero_grad()

        param_lr_g = []
        for param_group in optimizer_g.param_groups:
            param_lr_g.append(param_group["lr"])
        param_lr_f = []
        for param_group in optimizer_f.param_groups:
            param_lr_f.append(param_group["lr"])

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        elif args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=args.gamma).to(device)
        elif args.loss == 'CBFL':
            # Compute per-class weights from the effective number of samples,
            # as used by the class-balanced (CB) focal loss
            beta = args.beta
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=args.gamma).to(device)

        all_step = args.steps
        data_iter_s = iter(source_loader)
        data_iter_t = iter(target_loader)
        data_iter_t_unl = iter(target_loader_unl)
        len_train_source = len(source_loader)
        len_train_target = len(target_loader)
        len_train_target_semi = len(target_loader_unl)
        best_acc = 0
        counter = 0
        """
        x = torch.load("./freezed_models/alexnet_p2r.ckpt.best.pth.tar")
        G.load_state_dict(x['G_state_dict'])
        F1.load_state_dict(x['F1_state_dict'])
        optimizer_f.load_state_dict(x['optimizer_f'])
        optimizer_g.load_state_dict(x['optimizer_g'])
        """
        reg_weight = args.reg
        for step in range(all_step):
            optimizer_g = inv_lr_scheduler(param_lr_g,
                                           optimizer_g,
                                           step,
                                           init_lr=args.lr)
            optimizer_f = inv_lr_scheduler(param_lr_f,
                                           optimizer_f,
                                           step,
                                           init_lr=args.lr)
            lr = optimizer_f.param_groups[0]['lr']
            # Restart each data loader's iterator once a full pass has been consumed
            if step % len_train_target == 0:
                data_iter_t = iter(target_loader)
            if step % len_train_target_semi == 0:
                data_iter_t_unl = iter(target_loader_unl)
            if step % len_train_source == 0:
                data_iter_s = iter(source_loader)
            data_t = next(data_iter_t)
            data_t_unl = next(data_iter_t_unl)
            data_s = next(data_iter_s)
            with torch.no_grad():
                im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
                gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
                im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
                im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])

            zero_grad_all()
            if args.uda == 1:
                data = im_data_s
                target = gt_labels_s
            else:
                data = torch.cat((im_data_s, im_data_t), 0)
                target = torch.cat((gt_labels_s, gt_labels_t), 0)
            #print(data.shape)
            output = G(data)
            out1 = F1(output)
            if args.attribute is not None:
                if args.net == 'resnet34':
                    reg_loss = regularizer(F1.fc3.weight, att)
                    loss = criterion(out1, target) + reg_weight * reg_loss
                else:
                    reg_loss = regularizer(F1.fc2.weight, att)
                    loss = criterion(out1, target) + reg_weight * reg_loss
            else:
                reg_loss = torch.tensor(0)
                loss = criterion(out1, target)

            if args.attribute is not None:
                if step % args.save_interval == 0 and step != 0:
                    reg_weight = 0.5 * reg_weight
                    print("Reduced Reg weight to: ", reg_weight)

            loss.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_f.step()
            zero_grad_all()
            if not args.method == 'S+T':
                output = G(im_data_tu)
                if args.method == 'ENT':
                    loss_t = entropy(F1, output, args.lamda)
                    #print(loss_t.cpu().data.item())
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                elif args.method == 'MME':
                    loss_t = adentropy(F1, output, args.lamda,
                                       class_dist_threshold_list)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                else:
                    raise ValueError('Method cannot be recognized.')
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Reg: {:.6f} Loss T {:.6f} ' \
                            'Method {}\n'.format(args.source, args.target,
                                                step, lr, loss.data, reg_weight*reg_loss.data,
                                                -loss_t.data, args.method)
            else:
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Reg: {:.6f} Method {}\n'.\
                    format(args.source, args.target,
                        step, lr, loss.data, reg_weight * reg_loss.data,
                        args.method)
            G.zero_grad()
            F1.zero_grad()
            zero_grad_all()
            if step % args.log_interval == 0:
                print(log_train)
            if step % args.save_interval == 0 and step > 0:
                loss_val, acc_val = test(target_loader_val)
                loss_test, acc_test = test(target_loader_test)
                G.train()
                F1.train()
                if acc_val >= best_acc:
                    best_acc = acc_val
                    best_acc_test = acc_test
                    counter = 0
                else:
                    counter += 1
                if args.early:
                    if counter > args.patience:
                        break
                print('best acc test %f best acc val %f' %
                      (best_acc_test, acc_val))
                print('record %s' % record_file)
                with open(record_file, 'a') as f:
                    f.write('step %d best %f final %f \n' %
                            (step, best_acc_test, acc_val))
                G.train()
                F1.train()
                # Save a checkpoint dict with model weights, optimizer states, and metadata
                if args.save_check:
                    print('saving model')
                    is_best = True if counter == 0 else False
                    save_mymodel(
                        args, {
                            'step': step,
                            'arch': args.net,
                            'G_state_dict': G.state_dict(),
                            'F1_state_dict': F1.state_dict(),
                            'best_acc_test': best_acc_test,
                            'optimizer_g': optimizer_g.state_dict(),
                            'optimizer_f': optimizer_f.state_dict(),
                        }, is_best, time_stamp)
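Note: every example on this page updates the learning rate through an inv_lr_scheduler helper that is not included in the snippets. For orientation only, here is a minimal sketch of the inverse-decay schedule such MME-style codebases typically use; the gamma=0.0001 and power=0.75 defaults are assumptions, not values read from the code above.

def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001, power=0.75,
                     init_lr=0.001):
    """Decay lr as init_lr * (1 + gamma * iter_num) ** (-power).

    param_lr holds each param group's base lr recorded at step 0, so groups
    with different base lrs (e.g. backbone vs. classifier) keep their ratio
    while decaying together.
    """
    lr = init_lr * (1 + gamma * iter_num) ** (-power)
    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * param_lr[i]
    return optimizer
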
Example #2
def train(device, opt):
    G.train()
    F1.train()

    optimizer_g = optim.SGD(params, momentum=0.9, weight_decay=0.0005, nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()), lr=args.lr_f, momentum=0.9, weight_decay=0.0005, nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)

    best_acc = 0

    BCE = BCE_softlabels().to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    start_time = time.time()
    for step in range(all_step):

        rampup = sigmoid_rampup(step, args.rampup_length)
        w_cons = args.rampup_coef * rampup

        optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step,
                                       init_lr=args.lr)
        lr_f = optimizer_f.param_groups[0]['lr']
        lr_g = optimizer_g.param_groups[0]['lr']

        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)

        # load labeled source data
        x_s, target_s = data_s[0], data_s[1]
        im_data_s = x_s.to(device)
        gt_labels_s = target_s.to(device)

        # load labeled target data
        x_t, target_t = data_t[0], data_t[1]
        im_data_t = x_t.to(device)
        gt_labels_t = target_t.to(device)

        # load unlabeled target data
        x_tu, x_bar_tu, x_bar2_tu = data_t_unl[0], data_t_unl[3], data_t_unl[4]
        im_data_tu = x_tu.to(device)
        im_data_bar_tu = x_bar_tu.to(device)
        im_data_bar2_tu = x_bar2_tu.to(device)

        zero_grad_all()
        # construct losses for overall labeled data
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)   # backbone features: [batch_size, feat_dim]
        out1 = F1(output)  # classifier logits: [batch_size, num_classes]
        ce_loss = criterion(out1, target)

        ce_loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()

        # construct losses for unlabeled target data
        aac_loss, pl_loss, con_loss = get_losses_unlabeled(args, G, F1, im_data=im_data_tu, im_data_bar=im_data_bar_tu,
                                                           im_data_bar2=im_data_bar2_tu, target=None, BCE=BCE,
                                                           w_cons=w_cons, device=device)
        loss = aac_loss + pl_loss + con_loss

        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()

        if step % args.log_interval == 0:
            log_train = 'S {} T {} Train Ep: {} lr_f{:.6f} lr_g{:.6f}\n'.format(args.source, args.target,
                                                                                step, lr_f, lr_g)
            print(log_train)
            with open(opt["logs_file"], 'a') as f:
                f.write(log_train)

        if (step % args.save_interval == 0 and step > 0) or (step == all_step - 1):
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test

            cur_time = time.time() - start_time
            print('Current acc test %f best acc test %f best acc val %f time cost %f sec.' %
                  (acc_test, best_acc_test, acc_val, cur_time))

            with open(opt["logs_file"], 'a') as f:
                f.write('step %d current %f best %f val %f time cost %f sec.\n\n' %
                        (step, acc_test, best_acc_test, acc_val, cur_time))

            G.train()
            F1.train()
            if args.save_check:
                print('Saving model')
                torch.save(G.state_dict(), os.path.join(opt["checkpath"],
                                                        "G_iter_model_{}_to_{}_step_{}.pth.tar".format(args.source,
                                                                                                       args.target,
                                                                                                       step)))
                torch.save(F1.state_dict(), os.path.join(opt["checkpath"],
                                                         "F1_iter_model_{}_to_{}_step_{}.pth.tar".format(args.source,
                                                                                                         args.target,
                                                                                                         step)))
            start_time = time.time()
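Note: Example #2 ramps up the consistency weight with sigmoid_rampup(step, args.rampup_length), whose body is not shown. The sketch below is the conventional Mean Teacher-style rampup and is only an assumption about what this helper does.

import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped rampup from 0 to 1 over rampup_length steps."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

The caller then scales the result by args.rampup_coef, so w_cons grows smoothly from 0 toward the full consistency weight.
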
Example #3
def train():
    G.train()
    F1.train()
    F2.train()
    optimizer_g = optim.SGD(params,
                            lr=args.multi,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f1 = optim.SGD(list(F1.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
    optimizer_f2 = optim.SGD(list(F2.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f1 = []
    for param_group in optimizer_f1.param_groups:
        param_lr_f1.append(param_group["lr"])
    param_lr_f2 = []
    for param_group in optimizer_f2.param_groups:
        param_lr_f2.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f1 = inv_lr_scheduler(param_lr_f1,
                                        optimizer_f1,
                                        step,
                                        init_lr=args.lr)
        optimizer_f2 = inv_lr_scheduler(param_lr_f2,
                                        optimizer_f2,
                                        step,
                                        init_lr=args.lr)
        lr = optimizer_f1.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
        gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
        im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
        im_data_tu_weak.resize_(data_t_unl[0][0].size()).copy_(
            data_t_unl[0][0])
        im_data_tu_strong.resize_(data_t_unl[0][1].size()).copy_(
            data_t_unl[0][1])
        zero_grad_all()
        data = torch.cat(
            (im_data_s, im_data_t, im_data_tu_weak, im_data_tu_strong), 0)
        # target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        output_s = output[:len(im_data_s)]
        output_t = output[len(im_data_s):len(im_data_s) + len(im_data_t)]
        output_tu = output[len(im_data_s) + len(im_data_t):]
        output_tu_weak, output_tu_strong = output_tu.chunk(2)
        #out_1s = F1(output_s)
        out_2t = F2(output_t)

        out_2tu_weak = F2(output_tu_weak)
        out_2tu_strong = F2(output_tu_strong)
        pseudo_label_tu = torch.softmax(out_2tu_weak.detach_(), dim=-1)
        max_probs, targets_tu = torch.max(pseudo_label_tu, dim=-1)
        mask = max_probs.ge(args.threshold).float()

        #### Supervised Loss
        loss_x = criterion(out_2t,
                           gt_labels_t)  # + criterion(out_1s, gt_labels_s)

        ## Unsupervised Loss
        loss_u = (
            F.cross_entropy(out_2tu_strong, targets_tu, reduction='none') *
            mask).mean()

        loss = loss_x + 1.0 * loss_u

        loss.backward()
        optimizer_g.step()
        optimizer_f1.step()
        optimizer_f2.step()
        zero_grad_all()



        log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                    'Loss_x Classification: {:.6f} Loss_u Classification: {:.6f}\n'.format(args.source, args.target,
                                         step, lr, loss_x.data, loss_u.data)
        G.zero_grad()
        F1.zero_grad()
        F2.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            F2.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, best_acc))
            G.train()
            F1.train()
            F2.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F2.state_dict(),
                    os.path.join(
                        args.checkpath, "F2_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
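Note: Example #3 indexes each unlabeled target batch as a (weak, strong) pair via data_t_unl[0][0] and data_t_unl[0][1]. The data-loading side is not shown; a common FixMatch-style way to produce such pairs is a transform that returns two differently augmented views of the same image, sketched below. The weak_aug / strong_aug names are placeholders, not identifiers from the project.

class TwoViewTransform:
    """Return a (weakly augmented, strongly augmented) pair for one image."""

    def __init__(self, weak_aug, strong_aug):
        self.weak_aug = weak_aug      # e.g. random flip + crop
        self.strong_aug = strong_aug  # e.g. RandAugment + cutout

    def __call__(self, img):
        return self.weak_aug(img), self.strong_aug(img)

With a transform like this, the first element of every unlabeled batch is the tuple of views, which matches how the snippet unpacks it.
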
Example #4
def train():
    G.train()
    F1.train()
    F2.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f1 = optim.SGD(list(F1.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
    optimizer_f2 = optim.SGD(list(F2.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f1 = []
    for param_group in optimizer_f1.param_groups:
        param_lr_f1.append(param_group["lr"])
    param_lr_f2 = []
    for param_group in optimizer_f2.param_groups:
        param_lr_f2.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    end = time.time()
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f1 = inv_lr_scheduler(param_lr_f1,
                                        optimizer_f1,
                                        step,
                                        init_lr=args.lr)
        optimizer_f2 = inv_lr_scheduler(param_lr_f2,
                                        optimizer_f2,
                                        step,
                                        init_lr=args.lr)
        lr = optimizer_f1.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s = data_s[0].cuda()
        im_data_s = im_data_s.reshape(-1, im_data_s.shape[2],
                                      im_data_s.shape[3], im_data_s.shape[4])
        im_data_s_strong = data_s[1].cuda()
        gt_labels_s = data_s[2].cuda()
        gt_labels_s = torch.transpose(gt_labels_s, 1, 2)
        gt_labels_s = gt_labels_s.reshape(
            gt_labels_s.shape[0] * gt_labels_s.shape[1], gt_labels_s.shape[2])
        im_data_t = data_t[0].cuda()
        im_data_t = im_data_t.reshape(-1, im_data_t.shape[2],
                                      im_data_t.shape[3], im_data_t.shape[4])
        im_data_t_strong = data_t[1].cuda()
        gt_labels_t = data_t[2].cuda()
        gt_labels_t = torch.transpose(gt_labels_t, 1, 2)
        gt_labels_t = gt_labels_t.reshape(
            gt_labels_t.shape[0] * gt_labels_t.shape[1], gt_labels_t.shape[2])
        im_data_tu = data_t_unl[0].cuda()
        im_data_tu = im_data_tu.reshape(-1, im_data_tu.shape[2],
                                        im_data_tu.shape[3],
                                        im_data_tu.shape[4])
        im_data_tu_strong = data_t_unl[1].cuda()
        gt_labels_tu = data_t_unl[2].cuda()
        gt_labels_tu = torch.transpose(gt_labels_tu, 1, 2)
        gt_labels_tu = gt_labels_tu.reshape(
            gt_labels_tu.shape[0] * gt_labels_tu.shape[1],
            gt_labels_tu.shape[2])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t, im_data_tu), 0)
        zero_grad_all()

        output = G(data)
        output_s = output[:len(im_data_s)]
        output_t = output[len(im_data_s):len(im_data_s) + len(im_data_t)]
        output_tu = output[len(im_data_s) + len(im_data_t):]

        #### Supervised Loss for unrotated images
        output_s_no_rot = output_s.index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        output_t_no_rot = output_t.index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        gt_labels_s_cls = gt_labels_s[:, 0].index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        gt_labels_t_cls = gt_labels_t[:, 0].index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        logits_l_cls = F1(torch.cat((output_s_no_rot, output_t_no_rot), 0))
        target_l_cls = torch.cat((gt_labels_s_cls, gt_labels_t_cls), 0)

        loss_x = criterion(logits_l_cls, target_l_cls)

        ## Unsupervised Loss
        output_tu_no_rot = output_tu.index_select(
            0,
            torch.arange(0, len(output_tu), 4).cuda())
        logits_tu_weak = F1(output_tu_no_rot)
        pseudo_label_tu = torch.softmax(logits_tu_weak.detach_(), dim=-1)
        max_probs, targets_tu = torch.max(pseudo_label_tu, dim=-1)

        mask = max_probs.ge(args.threshold).float().repeat(3)

        x_tu = torch.cat((im_data_tu_strong, im_data_tu_strong), 0)
        x_st = torch.cat((im_data_s_strong, im_data_t_strong), 0)
        y_tu = torch.cat((targets_tu, targets_tu), 0)
        y_st = torch.cat((gt_labels_s_cls, gt_labels_t_cls), 0)

        mixed_x_strong, y_a, y_b, lam = mixup_data(x_tu,
                                                   x_st,
                                                   y_tu,
                                                   y_st,
                                                   alpha=1.0)
        logits_mix_strong = F1(G(mixed_x_strong))
        loss_u = lam * (F.cross_entropy(logits_mix_strong, y_a, reduction='none') * mask).mean() + \
                 (1 - lam) * (F.cross_entropy(logits_mix_strong, y_b, reduction='none') * mask).mean()

        ### Rotation Self-supervised Loss
        logits_ul_rot = F2(torch.cat((output_s, output_t, output_tu), 0))
        target_ul_rot = torch.cat(
            (gt_labels_s[:, 1], gt_labels_t[:, 1], gt_labels_tu[:, 1]), 0)
        loss_rot = criterion(logits_ul_rot, target_ul_rot.cuda())

        ### Overall Loss
        loss = loss_x + loss_u + 0.7 * loss_rot

        loss.backward()
        optimizer_g.step()
        optimizer_f1.step()
        optimizer_f2.step()
        zero_grad_all()

        G.zero_grad()
        F1.zero_grad()
        F2.zero_grad()
        zero_grad_all()

        if step % args.log_interval == 0:
            time_elapse = time.time() - end
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'loss: {:.6f} loss_x: {:.6f} ' \
                        'loss_u {:.6f} loss_rot: {:.6f} time_elapse: {:.3f}s ' \
                        'Method {}\n'. \
                format(args.source, args.target,
                       step, lr, loss.data, loss_x.data,
                       loss_u.data, loss_rot.data, time_elapse,
                       args.method)
            end = time.time()
            print(log_train)
            with open(record_file, 'a') as f:
                f.write(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            F2.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, best_acc))
            G.train()
            F1.train()
            F2.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F1.state_dict(),
                    os.path.join(
                        args.checkpath, "F1_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F2.state_dict(),
                    os.path.join(
                        args.checkpath, "F2_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
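Note: Example #4 relies on a mixup_data(x_tu, x_st, y_tu, y_st, alpha=1.0) helper that is not included. The sketch below shows the usual two-batch mixup it presumably performs (lam drawn from Beta(alpha, alpha)); the real helper may differ in details such as shuffling the second batch, so treat this as an assumption.

import numpy as np

def mixup_data(x1, x2, y1, y2, alpha=1.0):
    """Convexly combine two equally sized batches and return both label sets.

    The caller computes lam * CE(logits, y_a) + (1 - lam) * CE(logits, y_b),
    exactly as Example #4 does with logits_mix_strong.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    mixed_x = lam * x1 + (1.0 - lam) * x2
    return mixed_x, y1, y2, lam
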
Example #5
    def train(self):
        self.G.train()
        self.F1.train()
        optimizer_g = optim.SGD(self.params,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)
        optimizer_f = optim.SGD(list(self.F1.parameters()),
                                lr=1.0,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)

        def zero_grad_all():
            optimizer_g.zero_grad()
            optimizer_f.zero_grad()

        param_lr_g = []
        for param_group in optimizer_g.param_groups:
            param_lr_g.append(param_group["lr"])
        param_lr_f = []
        for param_group in optimizer_f.param_groups:
            param_lr_f.append(param_group["lr"])

        criterion = nn.CrossEntropyLoss().cuda()
        all_step = self.args.steps
        data_iter_s = iter(self.source_loader)
        data_iter_t = iter(self.target_loader)
        data_iter_t_unl = iter(self.target_loader_unl)
        len_train_source = len(self.source_loader)
        len_train_target = len(self.target_loader)
        len_train_target_semi = len(self.target_loader_unl)
        for step in range(all_step):
            optimizer_g = inv_lr_scheduler(param_lr_g,
                                           optimizer_g,
                                           step,
                                           init_lr=self.args.lr)
            optimizer_f = inv_lr_scheduler(param_lr_f,
                                           optimizer_f,
                                           step,
                                           init_lr=self.args.lr)

            lr = optimizer_f.param_groups[0]['lr']
            if step % len_train_target == 0:
                data_iter_t = iter(self.target_loader)
            if step % len_train_target_semi == 0:
                data_iter_t_unl = iter(self.target_loader_unl)
            if step % len_train_source == 0:
                data_iter_s = iter(self.source_loader)
            data_t = next(data_iter_t)
            data_t_unl = next(data_iter_t_unl)
            data_s = next(data_iter_s)
            self.im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
            self.gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
            self.im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
            self.gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
            self.im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
            zero_grad_all()
            data = torch.cat((self.im_data_s, self.im_data_t), 0)
            target = torch.cat((self.gt_labels_s, self.gt_labels_t), 0)
            output = self.G(data)
            out1 = self.F1(output)
            loss = criterion(out1, target)
            loss.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_f.step()
            zero_grad_all()

            output = self.G(self.im_data_tu)
            loss_t = adentropy(self.F1, output, self.args.lamda)
            loss_t.backward()
            optimizer_f.step()
            optimizer_g.step()

            log_train = 'S {} T {} Train Ep: {} lr{} \t Loss Classification: {:.6f} Method {}\n'.format(
                self.args.source, self.args.target, step, lr, loss.data,
                self.args.method)
            self.G.zero_grad()
            self.F1.zero_grad()

            if step % self.args.log_interval == 0:
                print(log_train)
            if step % self.args.save_interval == 0 and step > 0:
                self.test(self.target_loader_unl)
                self.G.train()
                self.F1.train()
                if self.args.save_check:
                    print('saving model')
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(
                            self.args.checkpath,
                            "G_iter_model_{}_{}_to_{}_step_{}.pth.tar".format(
                                self.args.method, self.args.source,
                                self.args.target, step)))
                    torch.save(
                        self.F1.state_dict(),
                        os.path.join(
                            self.args.checkpath,
                            "F1_iter_model_{}_{}_to_{}_step_{}.pth.tar".format(
                                self.args.method, self.args.source,
                                self.args.target, step)))
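Note: Examples #1 and #5 call entropy/adentropy from the MME-style utilities without showing them. The sketch below gives the adversarial entropy term in its usual MME form; it assumes the classifier's forward accepts a reverse flag that routes features through a gradient-reversal layer, which is not visible in these snippets (Example #1 additionally passes a class_dist_threshold_list to a modified variant).

import torch
import torch.nn.functional as F

def adentropy(classifier, feat, lamda, eta=1.0):
    """Minimax entropy term on unlabeled target features.

    Descending this loss pushes the classifier toward higher prediction
    entropy, while the gradient-reversal layer inside classifier(feat,
    reverse=True) makes the feature extractor move toward lower entropy.
    """
    out = classifier(feat, reverse=True, eta=eta)
    prob = F.softmax(out, dim=1)
    return lamda * torch.mean(torch.sum(prob * torch.log(prob + 1e-5), dim=1))
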
Example #6
def train(args, weights=None):
    if not os.path.exists(args.checkpath):
        os.mkdir(args.checkpath)

    # 1. get dataset
    train_loader, val_loader, test_loader, class_list = return_dataset(args)

    # 2. generator
    if args.net == 'resnet50':
        G = ResBase50()
        inc = 2048
    elif args.net == 'resnet101':
        G = ResBase101()
        inc = 2048        
    elif args.net == "alexnet":
        G = AlexNetBase()
        inc = 4096
    elif args.net == "vgg":
        G = VGGBase()
        inc = 4096
    elif args.net == "inception_v3":
        G = models.inception_v3(pretrained=True) 
        inc = 1000
    elif args.net == "googlenet":
        G = models.googlenet(pretrained = True)
        inc = 1000
    elif args.net == "densenet":
        G = models.densenet161(pretrained = True)
        inc = 1000
    elif args.net == "resnext":
        G = models.resnext50_32x4d(pretrained = True)
        inc = 1000
    elif args.net == "squeezenet":
        G = models.squeezenet1_0(pretrained = True)
        inc = 1000
    else:
        raise ValueError('Model cannot be recognized.')

    params = []
    for key, value in dict(G.named_parameters()).items():
        if value.requires_grad:
            if 'classifier' not in key:
                params += [{'params': [value], 'lr': args.multi,
                            'weight_decay': 0.0005}]
            else:
                params += [{'params': [value], 'lr': args.multi * 10,
                            'weight_decay': 0.0005}]
    G.cuda()
    G.train()

    # 3. classifier
    F = Predictor(num_class=len(class_list), inc=inc, temp=args.T)
    weights_init(F)
    F.cuda()
    F.train()  

    # 4. optimizer
    optimizer_g = optim.SGD(params, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)
    optimizer_f = optim.SGD(list(F.parameters()), lr=1.0, momentum=0.9,
                            weight_decay=0.0005, nesterov=True) 
    optimizer_g.zero_grad()
    optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    # 5. training
    data_iter_train = iter(train_loader)
    len_train = len(train_loader)
    best_acc = 0
    for step in range(args.steps):
        # update optimizer and lr
        optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train == 0:
            data_iter_train = iter(train_loader)

        # forwarding
        data = next(data_iter_train)        
        im_data = data[0].cuda()
        gt_label = data[1].cuda()

        feature = G(im_data)
        if args.net == 'inception_v3':
            # torchvision's inception_v3 returns an InceptionOutputs namedtuple in train mode, not a plain tensor
            feature = feature.logits  # take the main logits tensor
        if args.loss == 'CrossEntropy':
            # pass this batch's sample weights when provided (weights defaults to None)
            loss = crossentropy(F, feature, gt_label, None if weights is None else weights[step % len_train])
        elif args.loss == 'FocalLoss':
            loss = focal_loss(F, feature, gt_label, None if weights is None else weights[step % len_train])
        elif args.loss == 'ASoftmaxLoss':
            loss = asoftmax_loss(F, feature, gt_label, None if weights is None else weights[step % len_train])
        elif args.loss == 'SmoothCrossEntropy':
            loss = smooth_crossentropy(F, feature, gt_label, None if weights is None else weights[step % len_train])
        else:
            raise ValueError('Loss cannot be recognized.')
        loss.backward()

        # backpropagation
        optimizer_g.step()
        optimizer_f.step()       
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        G.zero_grad()
        F.zero_grad()

        if step%args.log_interval==0:
            log_train = 'Train iter: {} lr{} Loss Classification: {:.6f}\n'.format(step, lr, loss.data)
            print(log_train)
        if step and step%args.save_interval==0:
            # evaluate and save
            acc_val = eval(val_loader, G, F, class_list)     
            G.train()  
            F.train()  
            if args.save_check and acc_val >= best_acc:
                best_acc = acc_val
                print('saving model')
                print('best_acc: '+str(best_acc) + '  acc_val: '+str(acc_val))
                torch.save(G.state_dict(), os.path.join(args.checkpath,
                                        "G_net_{}_loss_{}.pth".format(args.net, args.loss)))
                torch.save(F.state_dict(), os.path.join(args.checkpath,
                                        "F_net_{}_loss_{}.pth".format(args.net, args.loss)))
    if (weights is not None):
      print("computing error rate")
      error_rate = eval_adaboost_error_rate(train_loader, G, F, class_list, weights)
      model_importance = torch.log((1-error_rate)/error_rate)/2
      #now update the weights
      print("updating weights")
      update_weights_adaboost(train_loader, G, F, class_list, weights, model_importance)
      return error_rate, model_importance
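Note: Example #6 threads optional per-batch sample weights into every loss call and closes with an AdaBoost-style update (eval_adaboost_error_rate / update_weights_adaboost), none of which are shown. The sketch below illustrates the weighted cross-entropy such a wrapper would typically compute; the crossentropy(F, feature, gt_label, weights) signature is only assumed to behave this way.

import torch.nn.functional as functional  # aliased to avoid clashing with the classifier named F in this example

def weighted_crossentropy(classifier, feature, target, weights=None):
    """Mean cross-entropy over classifier(feature), optionally sample-weighted.

    With weights=None this is the ordinary mean cross-entropy; with weights it
    is a normalized weighted average, as AdaBoost-style reweighting expects.
    """
    logits = classifier(feature)
    per_sample = functional.cross_entropy(logits, target, reduction='none')
    if weights is None:
        return per_sample.mean()
    weights = weights.to(per_sample.device).float()
    return (per_sample * weights).sum() / weights.sum().clamp_min(1e-8)

The closing lines of the example then use the classic AdaBoost importance 0.5 * log((1 - err) / err), which matches the model_importance expression above.
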
Example #7
File: main_rot.py  Project: danishmsin/SSAL
def train():
    G.train()
    F1.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()),
                            lr=1.0,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps

    data_iter_t = iter(target_loader)
    len_train_target = len(target_loader)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f,
                                       optimizer_f,
                                       step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)

        data_t = next(data_iter_t)

        im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])

        zero_grad_all()
        data = im_data_t
        target = gt_labels_t
        output = G(data)
        out1 = F1(output)

        loss = criterion(out1, target)
        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()

        log_train = 'T {} Train Ep: {} lr{} \t Loss Classification: {:.6f} \n'.format(
            args.target, step, lr, loss.data)

        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()

        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_train, acc_train = test(target_loader)

            G.train()
            F1.train()
            if acc_train >= best_acc:
                best_acc = acc_train
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc  %f' % (best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f  \n' % (step, best_acc))

            G.train()
            F1.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath,
                        "G_iter_model_{}_step_{}.pth.tar".format(
                            args.target, step)))

                torch.save(
                    F1.state_dict(),
                    os.path.join(
                        args.checkpath,
                        "F1_iter_model_{}_step_{}.pth.tar".format(
                            args.target, step)))
Example #8
def train():
    G.train()
    F1.train()
    F2.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f1 = optim.SGD(list(F1.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
    optimizer_f2 = optim.SGD(list(F2.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)

    # optimizer_g = optim.Adam(params)
    # optimizer_f1 = optim.Adam(list(F1.parameters()))
    # optimizer_f2 = optim.Adam(list(F2.parameters()))

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f1 = []
    for param_group in optimizer_f1.param_groups:
        param_lr_f1.append(param_group["lr"])
    param_lr_f2 = []
    for param_group in optimizer_f2.param_groups:
        param_lr_f2.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps

    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    data_iter_s = iter(source_loader)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f1 = inv_lr_scheduler(param_lr_f1,
                                        optimizer_f1,
                                        step,
                                        init_lr=args.lr)
        optimizer_f2 = inv_lr_scheduler(param_lr_f2,
                                        optimizer_f2,
                                        step,
                                        init_lr=args.lr)
        lr = optimizer_f1.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)

        data_s = next(data_iter_s)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)

        # im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
        # gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
        # gt_labels_s.transpose_(1, 2)
        # gt_labels_s.resize_(gt_labels_s.shape[0] * gt_labels_s.shape[1], gt_labels_s.shape[2])
        # im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
        # gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
        # gt_labels_t.transpose_(1, 2)
        # gt_labels_t.resize_(gt_labels_t.shape[0] * gt_labels_t.shape[1], gt_labels_t.shape[2])
        # im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        # gt_labels_tu.resize_(data_t_unl[1].size()).copy_(data_t_unl[1])
        # gt_labels_tu.transpose_(1, 2)
        # gt_labels_tu.resize_(gt_labels_tu.shape[0] * gt_labels_tu.shape[1], gt_labels_tu.shape[2])
        im_data_s = data_s[0].cuda()
        im_data_s = im_data_s.reshape(-1, im_data_s.shape[2],
                                      im_data_s.shape[3], im_data_s.shape[4])
        gt_labels_s = data_s[1].cuda()
        gt_labels_s = torch.transpose(gt_labels_s, 1, 2)
        gt_labels_s = gt_labels_s.reshape(
            gt_labels_s.shape[0] * gt_labels_s.shape[1], gt_labels_s.shape[2])
        im_data_t = data_t[0].cuda()
        im_data_t = im_data_t.reshape(-1, im_data_t.shape[2],
                                      im_data_t.shape[3], im_data_t.shape[4])
        gt_labels_t = data_t[1].cuda()
        gt_labels_t = torch.transpose(gt_labels_t, 1, 2)
        gt_labels_t = gt_labels_t.reshape(
            gt_labels_t.shape[0] * gt_labels_t.shape[1], gt_labels_t.shape[2])
        im_data_tu = data_t_unl[0].cuda()
        im_data_tu = im_data_tu.reshape(-1, im_data_tu.shape[2],
                                        im_data_tu.shape[3],
                                        im_data_tu.shape[4])
        gt_labels_tu = data_t_unl[1].cuda()
        gt_labels_tu = torch.transpose(gt_labels_tu, 1, 2)
        gt_labels_tu = gt_labels_tu.reshape(
            gt_labels_tu.shape[0] * gt_labels_tu.shape[1],
            gt_labels_tu.shape[2])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t, im_data_tu), 0)
        # target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        output_s = output[:len(im_data_s)]
        output_t = output[len(im_data_s):len(im_data_s) + len(im_data_t)]
        output_tu = output[len(im_data_s) + len(im_data_t):]

        #### Supervised Loss for unrotated images
        output_s_no_rot = output_s.index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        output_t_no_rot = output_t.index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        gt_labels_s_cls = gt_labels_s[:, 0].index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        gt_labels_t_cls = gt_labels_t[:, 0].index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        out_l = F1(torch.cat((output_s_no_rot, output_t_no_rot), 0))
        target_l = torch.cat((gt_labels_s_cls, gt_labels_t_cls), 0)

        loss_x = criterion(out_l, target_l)

        ### Rotation self-supervised loss
        out_ul = F2(torch.cat((output_s, output_t, output_tu), 0))
        target_ul = torch.cat(
            (gt_labels_s[:, 1], gt_labels_t[:, 1], gt_labels_tu[:, 1]), 0)

        loss_rot = criterion(out_ul, target_ul.cuda())
        loss = loss_x + 1.0 * loss_rot

        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f1.step()
        optimizer_f2.step()
        zero_grad_all()


        log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                    'loss: {:.6f} loss_x: {:.6f} ' \
                    'loss_rot {:.6f}\n'.format(args.source, args.target,
                                         step, lr, loss.data, loss_x.data,
                                         loss_rot.data)
        G.zero_grad()
        F1.zero_grad()
        F2.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            F2.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, best_acc))
            G.train()
            F1.train()
            F2.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F2.state_dict(),
                    os.path.join(
                        args.checkpath, "F2_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
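Note: Examples #4 and #8 assume every image arrives as four rotated copies, with the class label in gt_labels[:, 0] and the rotation index in gt_labels[:, 1]; that is why batches are reshaped with a factor of 4 and unrotated images are selected with torch.arange(0, len(output), 4). The dataset side is not shown, so the sketch below is an assumption about how one such rotated sample is typically built (the exact collation order in the real loaders may differ).

import torch

def make_rotation_views(img, cls_label):
    """Stack the 0/90/180/270-degree rotations of one CHW image.

    Returns a [4, C, H, W] tensor of views and a [4, 2] label tensor whose
    first column repeats the class label and whose second column is the
    rotation index, matching the gt_labels[:, 0] / gt_labels[:, 1] layout.
    """
    views = torch.stack([torch.rot90(img, k, dims=(1, 2)) for k in range(4)], dim=0)
    labels = torch.tensor([[int(cls_label), k] for k in range(4)])
    return views, labels
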
Example #9
File: main.py  Project: yuntaodu/APE
def train():
    G.train()
    F1.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()),
                            lr=1.0,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    ################################################################################################################
    ################################################# train model ##################################################
    ################################################################################################################

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    class AbstractConsistencyLoss(nn.Module):
        def __init__(self, reduction='mean'):
            super().__init__()
            self.reduction = reduction

        def forward(self, logits1, logits2):
            raise NotImplementedError

    class KLDivLossWithLogits(AbstractConsistencyLoss):
        def __init__(self, reduction='mean'):
            super().__init__(reduction)
            self.kl_div_loss = nn.KLDivLoss(reduction=reduction)

        def forward(self, logits1, logits2):
            return self.kl_div_loss(F.log_softmax(logits1, dim=1),
                                    F.softmax(logits2, dim=1))

    class EntropyLoss(nn.Module):
        def __init__(self, reduction='mean'):
            super().__init__()
            self.reduction = reduction

        def forward(self, logits):
            p = F.softmax(logits, dim=1)
            elementwise_entropy = -p * F.log_softmax(logits, dim=1)
            if self.reduction == 'none':
                return elementwise_entropy

            sum_entropy = torch.sum(elementwise_entropy, dim=1)
            if self.reduction == 'sum':
                return sum_entropy

            return torch.mean(sum_entropy)

    P = PerturbationGenerator(G, F1, xi=1, eps=25, ip=1)
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_reduce = nn.CrossEntropyLoss(reduction='none').cuda()
    target_consistency_criterion = KLDivLossWithLogits(reduction='mean').cuda()
    criterion_entropy = EntropyLoss()

    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    if args.net == 'resnet34':
        thr = 0.5
    else:
        thr = 0.3

    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f,
                                       optimizer_f,
                                       step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']

        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)

        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)

        im_data_s = Variable(data_s[0].cuda())
        gt_labels_s = Variable(data_s[1].cuda())
        im_data_t = Variable(data_t[0].cuda())
        gt_labels_t = Variable(data_t[1].cuda())
        im_data_tu = Variable(data_t_unl[0].cuda())
        gt_labels_tu = Variable(data_t_unl[1].cuda())
        gt_labels = torch.cat((gt_labels_s, gt_labels_t), 0)
        gt_dom_s = Variable(torch.zeros(im_data_s.size(0)).cuda().long())
        gt_dom_t = Variable(torch.ones(im_data_t.size(0)).cuda().long())
        gt_dom = torch.cat((gt_dom_s, gt_dom_t))
        zero_grad_all()

        ################################################################################################################
        ################################################# train model ##################################################
        ################################################################################################################
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        sigma = [1, 2, 5, 10]

        output = G(data)
        output_tu = G(im_data_tu)
        latent_F1, out1 = F1(output)
        latent_F1_tu, out_F1_tu = F1(output_tu)

        # supervision loss
        loss = criterion(out1, target)

        # attraction scheme
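        # MMD (mixture of RBF kernels, bandwidths in `sigma`) between features of the
        # labeled batch (source + labeled target) and the unlabeled target features,
        # weighted by 10, pulls the two feature distributions together.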
        loss_msda = 10 * mmd.mix_rbf_mmd2(latent_F1, latent_F1_tu, sigma)

        # exploration scheme
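        # Pseudo-label the unlabeled target batch with F1's argmax, compute the
        # per-sample prediction entropy, and keep only samples whose entropy is
        # below `thr` (the reliable ones) for the self-training cross-entropy term.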
        pred = out_F1_tu.data.max(1)[1].detach()
        ent = -torch.sum(
            F.softmax(out_F1_tu, 1) *
            (torch.log(F.softmax(out_F1_tu, 1) + 1e-5)), 1)
        mask_reliable = (ent < thr).float().detach()
        loss_cls_F1 = (mask_reliable * criterion_reduce(
            out_F1_tu, pred)).sum(0) / (1e-5 + mask_reliable.sum())

        (loss + loss_cls_F1 + loss_msda).backward(retain_graph=False)
        group_step([optimizer_g, optimizer_f])
        zero_grad_all()
        if step % 20 == 0:
            print('step %d' % step, 'loss_cls: {:.4f}'.format(loss.cpu().data), ' | ', 'loss_Attract: {:.4f}'.format(loss_msda.cpu().data), ' | ', \
                  'loss_Explore: {:.4f}'.format(loss_cls_F1.cpu().item()), end=' | ')

        # perturbation scheme
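        # Virtual-adversarial-style consistency: P returns a perturbation and the
        # clean logits; a KL consistency loss ties the predictions on the perturbed
        # inputs to the clean predictions.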
        bs = gt_labels_s.size(0)
        target_data = torch.cat((im_data_t, im_data_tu), 0)
        perturb, clean_vat_logits = P(target_data)
        perturb_inputs = target_data + perturb
        perturb_inputs = torch.cat(perturb_inputs.split(bs), 0)
        perturb_features = G(perturb_inputs)
        perturb_logits = F1(perturb_features)[0]
        target_vat_loss2 = 10 * target_consistency_criterion(
            perturb_logits, clean_vat_logits)

        target_vat_loss2.backward()
        group_step([optimizer_g, optimizer_f])
        zero_grad_all()

        if step % 20 == 0:
            print('loss_Perturb: {:.4f}'.format(target_vat_loss2.cpu().data))
        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()

        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(G, F1, target_loader_test)
            loss_val, acc_val = test(G, F1, target_loader_val)
            G.train()
            F1.train()

            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.dataset,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F1.state_dict(),
                    os.path.join(
                        args.checkpath, "F1_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.dataset,
                                                       args.source,
                                                       args.target, step)))
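The attraction term in the example above calls mmd.mix_rbf_mmd2 with a list of bandwidths. The helper below is a minimal sketch of a biased MMD^2 estimate under a mixture of RBF kernels, written to match how the example calls it; the actual mmd module used by the example may differ in details (for instance, an unbiased estimator).

import torch

def mix_rbf_mmd2(X, Y, sigmas):
    # Biased MMD^2 estimate between feature batches X (n, d) and Y (m, d),
    # using a kernel that is the sum of RBF kernels over the given bandwidths.
    def k(A, B):
        d2 = torch.cdist(A, B) ** 2
        return sum(torch.exp(-d2 / (2.0 * s ** 2)) for s in sigmas)
    return k(X, X).mean() + k(Y, Y).mean() - 2.0 * k(X, Y).mean()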
Example #10
def train():
    criterion = nn.CrossEntropyLoss().cuda()
    print('train start!')
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_l = iter(target_labeled_loader)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_l = len(target_labeled_loader)
    for step in range(conf.train.min_step + 1):
        G.train()
        C1.train()
        C2.train()
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_l == 0:
            data_iter_t_l = iter(target_labeled_loader)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_l = next(data_iter_t_l)
        data_s = next(data_iter_s)
        inv_lr_scheduler(param_lr_g, opt_g, step,
                         init_lr=conf.train.lr,
                         max_iter=conf.train.min_step)
        inv_lr_scheduler(param_lr_f, opt_c1, step,
                         init_lr=conf.train.lr,
                         max_iter=conf.train.min_step)
        img_s = data_s[0]
        label_s = data_s[1]
        img_t = data_t[0]
        index_t = data_t[2]
        img_s, label_s = Variable(img_s.cuda()), \
                         Variable(label_s.cuda())
        img_t = Variable(img_t.cuda())
        index_t = Variable(index_t.cuda())
        img_t_l = data_t_l[0].cuda()
        label_t_l = data_t_l[1].cuda()

        if len(img_t) < batch_size:
            break
        if len(img_s) < batch_size:
            break
        opt_g.zero_grad()
        opt_c1.zero_grad()
        ## Weight normalization
        C1.module.weight_norm()
        ## Source loss calculation
        feat = G(img_s)
        out_s = C1(feat)
        loss_s = criterion(out_s, label_s)
        #loss_s += criterion(C2(feat.detach()), label_s)

        feat_t = G(img_t)
        out_t = C1(feat_t)
        feat_t = F.normalize(feat_t)
        ## Train a linear classifier on top of feature extractor.
        ## We should not update feature extractor.
        G.eval()
        feat_t_l = G(img_t_l)
        G.train()
        out_t_l = C2(feat_t_l.detach())
        loss_t_l = criterion(out_t_l, label_t_l)

        ### Calculate mini-batch x memory similarity
        feat_mat = lemniscate(feat_t, index_t)
        ### We do not use memory features present in mini-batch
        feat_mat[:, index_t] = -1 / conf.model.temp
        ### Calculate mini-batch x mini-batch similarity
        feat_mat2 = torch.matmul(feat_t,
                                 feat_t.t()) / conf.model.temp
        mask = torch.eye(feat_mat2.size(0),
                         feat_mat2.size(0)).bool().cuda()
        feat_mat2.masked_fill_(mask, -1 / conf.model.temp)
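        # Entropy over the concatenated (memory-bank + in-batch) similarities:
        # the neighborhood-clustering term, scaled by eta.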
        loss_nc = conf.train.eta * entropy(torch.cat([feat_mat,
                                                      feat_mat2], 1))
        loss_ent = conf.train.eta * entropy_margin(out_t, conf.train.thr,
                                                   conf.train.margin)
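        # Note: loss_ent is computed above but not added to the total loss below.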
        loss_all = loss_nc + loss_s + loss_t_l  # avoid shadowing the builtin `all`
        with amp.scale_loss(loss_all, [opt_g, opt_c1]) as scaled_loss:
            scaled_loss.backward()
        opt_g.step()
        opt_c1.step()
        opt_g.zero_grad()
        opt_c1.zero_grad()
        lemniscate.update_weight(feat_t, index_t)
        if step % conf.train.log_interval == 0:
            print('Train [{}/{} ({:.2f}%)]\tLoss Source: {:.6f} '
                  'Loss NC: {:.6f} Loss LT: {:.6f}\t'.format(
                step, conf.train.min_step,
                100 * float(step / conf.train.min_step),
                loss_s.item(), loss_nc.item(), loss_t_l.item()))
        if step > 0 and step % conf.test.test_interval == 0:
            test(step, dataset_test, filename, n_share, num_class, G, C1,
                 conf.train.thr)
            test_class_inc(step, dataset_test, filename, n_target, G, C2,
                           n_share)
            G.train()
            C1.train()
            C2.train()
Example #11
def train():
    G.train()
    F1.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()),
                            lr=1.0,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    #criterion = nn.CrossEntropyLoss().cuda()
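    # Class-balanced weights from the effective number of samples: weight each
    # class by (1 - beta) / (1 - beta**n_c), normalized so the weights sum to
    # the number of classes.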
    beta = 0.99
    effective_num = 1.0 - np.power(beta, class_num_list)
    per_cls_weights = (1.0 - beta) / np.array(effective_num)
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
        class_num_list)
    per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
    criterion = FocalLoss(weight=per_cls_weights, gamma=0.5).cuda()
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc_test = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f,
                                       optimizer_f,
                                       step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s.data.resize_(data_s[0].size()).copy_(data_s[0])
        gt_labels_s.data.resize_(data_s[1].size()).copy_(data_s[1])
        im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
        im_data_tu.data.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        out1 = F1(output)
        loss = criterion(out1, target)
        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()
        if not args.method == 'S+T':
            output = G(im_data_tu)
            if args.method == 'ENT':
                loss_t = entropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            elif args.method == 'MME':
                loss_t = adentropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            else:
                raise ValueError('Method cannot be recognized.')
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Loss T {:.6f} ' \
                        'Method {}\n'.format(args.source, args.target,
                                             step, lr, loss.data,
                                             -loss_t.data, args.method)
        else:
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Method {}\n'.\
                format(args.source, args.target,
                       step, lr, loss.data,
                       args.method)
        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            if acc_test > best_acc_test:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, acc_val))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, acc_val))
            G.train()
            F1.train()
            if args.save_check:
                print('saving model...')
                is_best = True if counter == 0 else False
                save_mymodel(
                    args, {
                        'step': step,
                        'arch': args.net,
                        'G_state_dict': G.state_dict(),
                        'F1_state_dict': F1.state_dict(),
                        'best_acc_test': best_acc_test,
                        'optimizer_g': optimizer_g.state_dict(),
                        'optimizer_f': optimizer_f.state_dict(),
                    }, is_best)
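Example #11 builds its per-class weights inline from the class counts. The small helper below captures the same computation (the effective-number-of-samples weighting used by class-balanced losses); the function name is ours.

import numpy as np
import torch

def class_balanced_weights(class_num_list, beta=0.99):
    # Effective number of samples per class: E_c = (1 - beta**n_c) / (1 - beta).
    # Weight each class by (1 - beta) / (1 - beta**n_c), normalized so the
    # weights sum to the number of classes.
    effective_num = 1.0 - np.power(beta, class_num_list)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * len(class_num_list)
    return torch.FloatTensor(weights)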
Example #12
def train():
    print()
    print("entered train function")
    print()
    G.train()
    F1.train()
    print("optimizers")
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()),
                            lr=1.0,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)

    def zero_grad_all():
        print("setting gradients to 0")
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])
    #criterion = nn.CrossEntropyLoss().cuda()
    print("optimizing for min CrossEntropyLoss")
    criterion = nn.CrossEntropyLoss()
    all_step = args.steps
    data_iter_s = iter(source_loader1)
    data_iter_t = iter(target_loader)
    #data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader1)
    len_train_target = len(target_loader)
    #len_train_target_semi = len(target_loader_unl)

    print("source_loader1 ", len(source_loader1.dataset))
    print()
    print("target_loader ", len(target_loader.dataset))
    print()

    for step in range(all_step):
        print("optimization step: ", step)
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f,
                                       optimizer_f,
                                       step,
                                       init_lr=args.lr)

        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        #if step % len_train_target_semi == 0:
        #data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader1)
        data_t = next(data_iter_t)
        #data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        with torch.no_grad():

            #comment out gpu requirements & commands for this execution
            #im_data_s.data.resize_(data_s[0].size()).copy_(data_s[0])
            #gt_labels_s.data.resize_(data_s[1].size()).copy_(data_s[1])
            #im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
            #gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
            #im_data_tu.data.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])

            im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
            gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
            im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
            gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
            #im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        out1 = F1(output)
        loss = criterion(out1, target)
        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()
        if not args.method == 'S+T':
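            # Note: the unlabeled target loader is commented out above, so
            # im_data_tu is not refreshed in this example when a
            # semi-supervised method is selected.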
            output = G(im_data_tu)
            if args.method == 'ENT':
                loss_t = entropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            elif args.method == 'MME':
                print("MME")
                loss_t = adentropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            else:
                raise ValueError('Method cannot be recognized.')
            log_train = 'S {} T {} Train Ep: {} lr{} \t Loss Classification: {:.6f} Loss T {:.6f} Method {}\n'.format(
                args.source, args.target, step, lr, loss.data, -loss_t.data,
                args.method)
        else:
            log_train = 'S {} T {} Train Ep: {} lr{} \t Loss Classification: {:.6f} Method {}\n'.format(
                args.source1, args.target, step, lr, loss.data, args.method)
        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print("log_train")
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            print("testing")
            test(target_loader)
            #test(target_loader_unl)
            G.train()
            F1.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath,
                        "G_iter_model_{}_{}_to_{}_step_{}.pth.tar".format(
                            args.method, args.source1, args.target, step)))
                torch.save(
                    F1.state_dict(),
                    os.path.join(
                        args.checkpath,
                        "F1_iter_model_{}_{}_to_{}_step_{}.pth.tar".format(
                            args.method, args.source1, args.target, step)))
Example #13
def train():
    G.train()
    F1.train()
    F2.train()
    optimizer_g = optim.SGD(params,
                            lr=args.multi,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f1 = optim.SGD(list(F1.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
    optimizer_f2 = optim.SGD(list(F2.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f1 = []
    for param_group in optimizer_f1.param_groups:
        param_lr_f1.append(param_group["lr"])
    param_lr_f2 = []
    for param_group in optimizer_f2.param_groups:
        param_lr_f2.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f1 = inv_lr_scheduler(param_lr_f1,
                                        optimizer_f1,
                                        step,
                                        init_lr=args.lr)
        optimizer_f2 = inv_lr_scheduler(param_lr_f2,
                                        optimizer_f2,
                                        step,
                                        init_lr=args.lr)
        lr = optimizer_f1.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
        gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
        im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
        im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t, im_data_tu), 0)
        # target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        output_s = output[:len(im_data_s)]
        output_t = output[len(im_data_s):len(im_data_s) + len(im_data_t)]
        output_tu = output[len(im_data_s) + len(im_data_t):]
        out_1t = F1(output_t)
        out_1s = F1(output_s)
        out_2t = F2(output_t)
        out_2s = F2(output_s)

        out_1tu = F1(output_tu)
        out_2tu = F2(output_tu)
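        # Pseudo-labels from both classifiers; unlabeled samples only contribute
        # to the losses below where the two classifiers agree (the `mask`).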
        # use detach() (not in-place) so out_1tu / out_2tu keep their graph
        # for the masked cross-entropy terms below
        pseudo_label_1 = torch.softmax(out_1tu.detach(), dim=-1)
        pseudo_label_2 = torch.softmax(out_2tu.detach(), dim=-1)
        max_probs_1, targets_u_1 = torch.max(pseudo_label_1, dim=-1)
        max_probs_2, targets_u_2 = torch.max(pseudo_label_2, dim=-1)
        mask = (targets_u_1 == targets_u_2).float()

        ## Source-based Classifier loss: L1
        loss_1t = criterion(out_1t, gt_labels_t) + (F.cross_entropy(
            out_1tu, targets_u_1, reduction='none') * mask).mean()

        # mask = torch.cat((torch.ones_like(gt_labels_t).float(), mask), 0)
        # loss_1t = (F.cross_entropy(torch.cat((out_1t, out_1tu), 0),
        #                     torch.cat((gt_labels_t, targets_u_1), 0), reduction='none') * mask).mean()
        loss_1s = criterion(out_1s, gt_labels_s)

        loss_1 = args.alpha * loss_1s + (1 - args.alpha) * loss_1t

        ## Target-based Classifier loss
        loss_2t = criterion(out_2t, gt_labels_t) + (F.cross_entropy(
            out_2tu, targets_u_2, reduction='none') * mask).mean()
        # loss_2t = (F.cross_entropy(torch.cat((out_2t, out_2tu), 0),
        #                  torch.cat((gt_labels_t, targets_u_2), 0), reduction='none') * mask).mean()
        loss_2s = criterion(out_2s, gt_labels_s)

        loss_2 = args.alpha * loss_2t + (1 - args.alpha) * loss_2s

        loss_1.backward(retain_graph=True)
        loss_2.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f1.step()
        optimizer_f2.step()
        zero_grad_all()

        output = G(torch.cat((im_data_s, im_data_tu), 0))
        output_s = output[:len(im_data_s)]
        output_tu = output[len(im_data_s):]

        flag = 8
        if flag == 0:  #### original loss ===========
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            entropy_tu.backward(retain_graph=True)
            optimizer_f1.step()
            optimizer_f2.step()
            optimizer_g.step()

        if flag == 2:  #### ===== remove source entropy from Eq. 7 ===========
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            optimizer_g.step()
            zero_grad_all()
            entropy_tu.backward(retain_graph=True)
            optimizer_f2.step()
            optimizer_g.step()

        if flag == 3:  #### ===== remove target entropy from Eq. 8 ===========
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            optimizer_f1.step()
            optimizer_g.step()
            zero_grad_all()
            entropy_tu.backward(retain_graph=True)
            optimizer_g.step()

        if flag == 4:  #### ===== remove source entropy from Eq. 9 ===========
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            optimizer_f1.step()
            zero_grad_all()
            entropy_tu.backward(retain_graph=True)
            optimizer_f2.step()
            optimizer_g.step()

        if flag == 5:  #### ===== remove target entropy from Eq. 9 ===========
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            optimizer_f1.step()
            optimizer_g.step()
            zero_grad_all()
            entropy_tu.backward(retain_graph=True)
            optimizer_f2.step()
        elif flag == 6:
            # ===== No Gradient Reversal, minimize all entropies ===========
            # ===== Change Eq. 8 to: (1-alpha) * L_src + alpha * L_tar + lambda * H_tar
            # ===== Change Eq. 9 to: L_src + L_tar + beta * H_src + lambda * H_tar
            entropy_s = entropy(F1, output_s, args.beta)
            entropy_tu = entropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            entropy_tu.backward(retain_graph=True)
            optimizer_f1.step()
            optimizer_f2.step()
            optimizer_g.step()

        elif flag == 7:
            # ===== Change Eq. 8 to: (1-alpha) * L_src + alpha * L_tar
            # ===== Change Eq. 9 to: L_src + L_tar + lambda * H_tar
            entropy_s = entropy(F1, output_s, args.beta)
            entropy_tu = entropy(F2, output_tu, args.lamda)
            entropy_s.backward(retain_graph=True)
            optimizer_f1.step()
            zero_grad_all()
            entropy_tu.backward(retain_graph=True)
            optimizer_g.step()

        elif flag == 8:
            # ===== Change Eq. 7 to: alpha * L_src + (1-alpha) * L_tar
            # ===== Change Eq. 9 to: L_src + L_tar + lambda * H_tar
            entropy_s = adentropy(F1, output_s, args.beta)
            entropy_tu = -adentropy(F2, output_tu, args.lamda)
            entropy_tu.backward(retain_graph=True)
            optimizer_f2.step()
            optimizer_g.step()


        log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                    'Loss_1 Classification: {:.6f} Loss_2 Classification: {:.6f} ' \
                    'Entropy_S {:.6f} Entropy_TU {:.6f}\n'.format(args.source, args.target,
                                         step, lr, loss_1.data, loss_2.data,
                                         entropy_s.data, entropy_tu.data)
        G.zero_grad()
        F1.zero_grad()
        F2.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            F2.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, best_acc))
            G.train()
            F1.train()
            F2.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F2.state_dict(),
                    os.path.join(
                        args.checkpath, "F2_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
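Several of these examples (the ENT and MME methods) rely on entropy and adentropy helpers that are not shown here. Below is a rough sketch of the entropy regularizer only; note that the real MME adentropy additionally routes the features through a gradient-reversal layer so that the classifier maximizes and the feature extractor minimizes the target entropy, which a plain loss term like this does not capture.

import torch
import torch.nn.functional as F

def entropy(F1, feat, lamda):
    # Entropy regularizer on unlabeled target predictions (sketch): the mean
    # per-sample entropy of the classifier's softmax output, scaled by lamda.
    p = F.softmax(F1(feat), dim=1)
    return lamda * torch.mean(torch.sum(-p * torch.log(p + 1e-5), dim=1))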
Example #14
def train():
    G.train()
    F1.train()
    F2.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f1 = optim.SGD(list(F1.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
    optimizer_f2 = optim.SGD(list(F2.parameters()),
                             lr=1.0,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)

    # optimizer_g = optim.Adam(params)
    # optimizer_f1 = optim.Adam(list(F1.parameters()))
    # optimizer_f2 = optim.Adam(list(F2.parameters()))

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f1.zero_grad()
        optimizer_f2.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f1 = []
    for param_group in optimizer_f1.param_groups:
        param_lr_f1.append(param_group["lr"])
    param_lr_f2 = []
    for param_group in optimizer_f2.param_groups:
        param_lr_f2.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps

    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    data_iter_s = iter(source_loader)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f1 = inv_lr_scheduler(param_lr_f1,
                                        optimizer_f1,
                                        step,
                                        init_lr=args.lr)
        optimizer_f2 = inv_lr_scheduler(param_lr_f2,
                                        optimizer_f2,
                                        step,
                                        init_lr=args.lr)
        lr = optimizer_f1.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)

        data_s = next(data_iter_s)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)

        # im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
        # gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
        # gt_labels_s.transpose_(1, 2)
        # gt_labels_s.resize_(gt_labels_s.shape[0] * gt_labels_s.shape[1], gt_labels_s.shape[2])
        # im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
        # gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
        # gt_labels_t.transpose_(1, 2)
        # gt_labels_t.resize_(gt_labels_t.shape[0] * gt_labels_t.shape[1], gt_labels_t.shape[2])
        # im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        # gt_labels_tu.resize_(data_t_unl[1].size()).copy_(data_t_unl[1])
        # gt_labels_tu.transpose_(1, 2)
        # gt_labels_tu.resize_(gt_labels_tu.shape[0] * gt_labels_tu.shape[1], gt_labels_tu.shape[2])
        im_data_s = data_s[0].cuda()
        im_data_s = im_data_s.reshape(-1, im_data_s.shape[2],
                                      im_data_s.shape[3], im_data_s.shape[4])
        im_data_s_strong = data_s[1].cuda()
        gt_labels_s = data_s[2].cuda()
        gt_labels_s = torch.transpose(gt_labels_s, 1, 2)
        gt_labels_s = gt_labels_s.reshape(
            gt_labels_s.shape[0] * gt_labels_s.shape[1], gt_labels_s.shape[2])
        im_data_t = data_t[0].cuda()
        im_data_t = im_data_t.reshape(-1, im_data_t.shape[2],
                                      im_data_t.shape[3], im_data_t.shape[4])
        im_data_t_strong = data_t[1].cuda()
        gt_labels_t = data_t[2].cuda()
        gt_labels_t = torch.transpose(gt_labels_t, 1, 2)
        gt_labels_t = gt_labels_t.reshape(
            gt_labels_t.shape[0] * gt_labels_t.shape[1], gt_labels_t.shape[2])
        im_data_tu = data_t_unl[0].cuda()
        im_data_tu = im_data_tu.reshape(-1, im_data_tu.shape[2],
                                        im_data_tu.shape[3],
                                        im_data_tu.shape[4])
        im_data_tu_strong = data_t_unl[1].cuda()
        gt_labels_tu = data_t_unl[2].cuda()
        gt_labels_tu = torch.transpose(gt_labels_tu, 1, 2)
        gt_labels_tu = gt_labels_tu.reshape(
            gt_labels_tu.shape[0] * gt_labels_tu.shape[1],
            gt_labels_tu.shape[2])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t, im_data_tu), 0)
        # target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        output_s = output[:len(im_data_s)]
        output_t = output[len(im_data_s):len(im_data_s) + len(im_data_t)]
        output_tu = output[len(im_data_s) + len(im_data_t):]

        #### Supervised Loss for unrotated images
        output_s_no_rot = output_s.index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        output_t_no_rot = output_t.index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        gt_labels_s_cls = gt_labels_s[:, 0].index_select(
            0,
            torch.arange(0, len(output_s), 4).cuda())
        gt_labels_t_cls = gt_labels_t[:, 0].index_select(
            0,
            torch.arange(0, len(output_t), 4).cuda())
        logits_l_cls = F1(torch.cat((output_s_no_rot, output_t_no_rot), 0))
        target_l_cls = torch.cat((gt_labels_s_cls, gt_labels_t_cls), 0)

        loss_x = criterion(logits_l_cls, target_l_cls)

        ## flooding parameter
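        # Flooding keeps the training loss near the level b via |loss - b| + b;
        # with b = 0.0 the (non-negative) cross-entropy is left unchanged.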
        b = 0.0
        loss_x = (loss_x - b).abs() + b

        ## Unsupervised Loss
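        # FixMatch-style term: pseudo-labels from weakly augmented unlabeled
        # images, thresholded at args.threshold, supervise the predictions on
        # their strongly augmented counterparts (optionally with mixup).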
        mixup = False
        if mixup:

            def mixup_data(x_tu, x_st, y_tu, y_st, alpha=1.0):
                '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
                if alpha > 0.:
                    lam = np.random.beta(alpha, alpha)
                else:
                    lam = 1.
                x_a = torch.cat((x_tu, x_st), 0)
                y_a = torch.cat((y_tu, y_st), 0)
                batch_size = x_st.size()[0]
                index = np.random.permutation(batch_size) + x_tu.size()[0]
                index = np.concatenate((np.arange(x_tu.size()[0]), index), 0)
                mixed_x = lam * x_a + (1 - lam) * x_a[index, :]
                # y_a, y_b = torch.Tensor(y).type(torch.LongTensor), torch.Tensor(y[index]).type(torch.LongTensor)
                y_a, y_b = y_a, y_a[index]
                mixed_x = mixed_x[x_tu.size()[0] // 2:]
                y_a = y_a[x_tu.size()[0] // 2:]
                y_b = y_b[x_tu.size()[0] // 2:]
                return mixed_x, y_a, y_b, lam

            output_tu_no_rot = output_tu.index_select(
                0,
                torch.arange(0, len(output_tu), 4).cuda())
            logits_tu_weak = F1(output_tu_no_rot)
            pseudo_label_tu = torch.softmax(logits_tu_weak.detach_(), dim=-1)
            max_probs, targets_tu = torch.max(pseudo_label_tu, dim=-1)

            mask = max_probs.ge(args.threshold).float().repeat(3)

            x_tu = torch.cat((im_data_tu_strong, im_data_tu_strong), 0)
            x_st = torch.cat((im_data_s_strong, im_data_t_strong), 0)
            y_tu = torch.cat((targets_tu, targets_tu), 0)
            y_st = torch.cat((gt_labels_s_cls, gt_labels_t_cls), 0)

            mixed_x_strong, y_a, y_b, lam = mixup_data(x_tu,
                                                       x_st,
                                                       y_tu,
                                                       y_st,
                                                       alpha=1.0)
            logits_mix_strong = F1(G(mixed_x_strong))
            loss_u = lam * (F.cross_entropy(logits_mix_strong, y_a, reduction='none') * mask).mean() + \
                     (1-lam) * (F.cross_entropy(logits_mix_strong, y_b, reduction='none') * mask).mean()

        else:
            output_tu_strong = G(im_data_tu_strong)
            output_tu_no_rot = output_tu.index_select(
                0,
                torch.arange(0, len(output_tu), 4).cuda())
            logits_tu_weak = F1(output_tu_no_rot)
            logits_tu_strong = F1(output_tu_strong)
            pseudo_label_tu = torch.softmax(logits_tu_weak.detach_(), dim=-1)
            max_probs, targets_tu = torch.max(pseudo_label_tu, dim=-1)
            mask = max_probs.ge(args.threshold).float()
            loss_u = (F.cross_entropy(
                logits_tu_strong, targets_tu, reduction='none') * mask).mean()

        ### Rotation Self-supervised Loss
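        # F2 predicts the rotation class stored in the second label column for
        # all source, labeled-target and unlabeled-target features.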
        logits_ul_rot = F2(torch.cat((output_s, output_t, output_tu), 0))
        target_ul_rot = torch.cat(
            (gt_labels_s[:, 1], gt_labels_t[:, 1], gt_labels_tu[:, 1]), 0)
        loss_rot = criterion(logits_ul_rot, target_ul_rot.cuda())

        ### Overall Loss
        loss = loss_x + loss_u + 0.7 * loss_rot

        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f1.step()
        optimizer_f2.step()
        zero_grad_all()


        log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                    'loss: {:.6f} loss_x: {:.6f} ' \
                    'loss_u {:.6f} loss_rot: {:.6f}\n'.format(args.source, args.target,
                                         step, lr, loss.data, loss_x.data,
                                         loss_u.data, loss_rot.data)
        G.zero_grad()
        F1.zero_grad()
        F2.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            F2.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, best_acc))
            G.train()
            F1.train()
            F2.train()
            if args.save_check:
                print('saving model')
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.checkpath, "G_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
                torch.save(
                    F2.state_dict(),
                    os.path.join(
                        args.checkpath, "F2_iter_model_{}_{}_"
                        "to_{}_step_{}.pth.tar".format(args.method,
                                                       args.source,
                                                       args.target, step)))
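Example #14 assumes a dataloader that returns, for every image, four rotated copies plus a label tensor whose second column is the rotation class (hence the reshape by a factor of 4 and the use of gt_labels[:, 1]). The sketch below shows that kind of preprocessing for a single image; the helper name and exact label layout are assumptions.

import torch

def with_rotations(img, cls_label):
    # img: a (C, H, W) image tensor; cls_label: its integer class label.
    # Returns the 4 copies rotated by 0/90/180/270 degrees and a (4, 2) label
    # tensor whose column 0 is the class label and column 1 the rotation id.
    rots = torch.stack([torch.rot90(img, k, dims=(1, 2)) for k in range(4)], dim=0)
    labels = torch.stack([torch.full((4,), cls_label, dtype=torch.long),
                          torch.arange(4)], dim=1)
    return rots, labels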
Example #15
    def train():
        G.train()
        F1.train()
        optimizer_g = optim.SGD(params,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)
        optimizer_f = optim.SGD(list(F1.parameters()),
                                lr=1.0,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)

        # Loading the states of the two optimizers
        optimizer_g.load_state_dict(main_dict['optimizer_g'])
        optimizer_f.load_state_dict(main_dict['optimizer_f'])
        print("Loaded optimizer states")

        def zero_grad_all():
            optimizer_g.zero_grad()
            optimizer_f.zero_grad()

        param_lr_g = []
        for param_group in optimizer_g.param_groups:
            param_lr_g.append(param_group["lr"])
        param_lr_f = []
        for param_group in optimizer_f.param_groups:
            param_lr_f.append(param_group["lr"])

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        if args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=1).to(device)
        if args.loss == 'CBFL':
            # Compute per-class weights from the number of examples per class; used by the CB focal loss
            beta = 0.99
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=0.5).to(device)

        all_step = args.steps
        data_iter_s = iter(source_loader)
        data_iter_t = iter(target_loader)
        data_iter_t_unl = iter(target_loader_unl)
        len_train_source = len(source_loader)
        len_train_target = len(target_loader)
        len_train_target_semi = len(target_loader_unl)
        best_acc = 0
        counter = 0
        for step in range(all_step):
            optimizer_g = inv_lr_scheduler(param_lr_g,
                                           optimizer_g,
                                           step,
                                           init_lr=args.lr)
            optimizer_f = inv_lr_scheduler(param_lr_f,
                                           optimizer_f,
                                           step,
                                           init_lr=args.lr)
            lr = optimizer_f.param_groups[0]['lr']
            # condition for restarting the iteration for each of the data loaders
            if step % len_train_target == 0:
                data_iter_t = iter(target_loader)
            if step % len_train_target_semi == 0:
                data_iter_t_unl = iter(target_loader_unl)
            if step % len_train_source == 0:
                data_iter_s = iter(source_loader)
            data_t = next(data_iter_t)
            data_t_unl = next(data_iter_t_unl)
            data_s = next(data_iter_s)
            with torch.no_grad():
                im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
                gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
                im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
                im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])

            zero_grad_all()
            data = torch.cat((im_data_s, im_data_t), 0)
            target = torch.cat((gt_labels_s, gt_labels_t), 0)
            output = G(data)
            out1 = F1(output)
            loss = criterion(out1, target)
            loss.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_f.step()
            zero_grad_all()
            # list of the weights and image paths in this batch
            img_paths = list(data_t_unl[2])
            df1 = df.loc[df['img'].isin(img_paths)]
            df1 = df1['weight']
            weight_list = list(df1)

            if not args.method == 'S+T':
                output = G(im_data_tu)
                if args.method == 'ENT':
                    loss_t = entropy(F1, output, args.lamda)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                elif args.method == 'MME':
                    loss_t = adentropy(F1, output, args.lamda, weight_list)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                else:
                    raise ValueError('Method cannot be recognized.')
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Loss T {:.6f} ' \
                            'Method {}\n'.format(args.source, args.target,
                                                step, lr, loss.data,
                                                -loss_t.data, args.method)
            else:
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Method {}\n'.\
                    format(args.source, args.target,
                        step, lr, loss.data,
                        args.method)
            G.zero_grad()
            F1.zero_grad()
            zero_grad_all()
            if step % args.log_interval == 0:
                print(log_train)
            if step % args.save_interval == 0 and step > 0:
                loss_val, acc_val = test(target_loader_val)
                loss_test, acc_test = test(target_loader_test)
                G.train()
                F1.train()
                if acc_test >= best_acc:
                    best_acc = acc_test
                    best_acc_test = acc_test
                    counter = 0
                else:
                    counter += 1
                if args.early:
                    if counter > args.patience:
                        break
                print('best acc test %f best acc val %f' %
                      (best_acc_test, acc_val))
                print('record %s' % record_file)
                with open(record_file, 'a') as f:
                    f.write('step %d best %f final %f \n' %
                            (step, best_acc_test, acc_val))
                G.train()
                F1.train()
                # saving the model as a checkpoint dict (step, arch, model/optimizer states, best accuracy)
                if args.save_check:
                    print('saving model')
                    is_best = True if counter == 0 else False
                    save_mymodel(
                        args, {
                            'step': step,
                            'arch': args.net,
                            'G_state_dict': G.state_dict(),
                            'F1_state_dict': F1.state_dict(),
                            'best_acc_test': best_acc_test,
                            'optimizer_g': optimizer_g.state_dict(),
                            'optimizer_f': optimizer_f.state_dict(),
                        }, is_best)
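Examples #11 and #15 delegate checkpointing to save_mymodel, which is not shown. A minimal sketch consistent with how it is called; the body is an assumption, not the original implementation.

import os
import shutil
import torch

def save_mymodel(args, state, is_best, filename='checkpoint.pth.tar'):
    # Save the checkpoint dict and keep a copy of the best model so far.
    path = os.path.join(args.checkpath, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, path.replace('.pth.tar', '.best.pth.tar'))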
Example #16
def train():
    G.train()
    F1.train()
    optimizer_g = optim.SGD(params, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()), lr=1.0, momentum=0.9,
                            weight_decay=0.0005, nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])
    criterion = nn.CrossEntropyLoss().cuda()
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    best_acc = 0
    counter = 0
    for step in range(all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g, optimizer_g, step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f, optimizer_f, step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s.data.resize_(data_s[0].size()).copy_(data_s[0])
        gt_labels_s.data.resize_(data_s[1].size()).copy_(data_s[1])
        im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
        im_data_tu.data.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        out1 = F1(output)
        loss = criterion(out1, target)
        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()
        if not args.method == 'S+T':
            output = G(im_data_tu)
            if args.method == 'ENT':
                loss_t = entropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            elif args.method == 'MME':
                loss_t = adentropy(F1, output, args.lamda)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            else:
                raise ValueError('Method cannot be recognized.')
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Loss T {:.6f} ' \
                        'Method {}\n'.format(args.source, args.target,
                                             step, lr, loss.data,
                                             -loss_t.data, args.method)
        else:
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Method {}\n'.\
                format(args.source, args.target,
                       step, lr, loss.data,
                       args.method)
        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            G.train()
            F1.train()
            if acc_val >= best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' % (best_acc_test,
                                                        acc_val))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' % (step,
                                                         best_acc_test,
                                                         acc_val))
            G.train()
            F1.train()
            if args.save_check:
                print('saving model')
                torch.save(G.state_dict(),
                           os.path.join(args.checkpath,
                                        "G_iter_model_{}_{}_"
                                        "to_{}_step_{}.pth.tar".
                                        format(args.method, args.source,
                                               args.target, step)))
                torch.save(F1.state_dict(),
                           os.path.join(args.checkpath,
                                        "F1_iter_model_{}_{}_"
                                        "to_{}_step_{}.pth.tar".
                                        format(args.method, args.source,
                                               args.target, step)))
Example #17
def train():
    G.train()
    F1.train()
    optimizer_g = optim.SGD(params,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)
    optimizer_f = optim.SGD(list(F1.parameters()),
                            lr=1.0,
                            momentum=0.9,
                            weight_decay=0.0005,
                            nesterov=True)

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

    param_lr_g = []
    for param_group in optimizer_g.param_groups:
        param_lr_g.append(param_group["lr"])
    param_lr_f = []
    for param_group in optimizer_f.param_groups:
        param_lr_f.append(param_group["lr"])

    criterion = cb_focal_loss(class_num_list, beta=args.beta, gamma=args.gamma)
    all_step = args.steps
    data_iter_s = iter(source_loader)
    data_iter_t = iter(target_loader)
    data_iter_t_unl = iter(target_loader_unl)
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    len_train_target_semi = len(target_loader_unl)
    counter = 0
    print("=> loading checkpoint...")
    #filename = 'freezed_models/%s_%s_%s_%s_%s.ckpt.best.pth.tar' % (args.net,args.method, args.source, args.target,args.num)
    filename = "freezed_models/ent1p2r.ckpt.best.pth.tar"
    main_dict = torch.load(filename)
    best_acc_test = main_dict['best_acc_test']
    best_acc = 0
    G.load_state_dict(main_dict['G_state_dict'])
    F1.load_state_dict(main_dict['F1_state_dict'])
    optimizer_g.load_state_dict(main_dict['optimizer_g'])
    optimizer_f.load_state_dict(main_dict['optimizer_f'])
    print("=> loaded checkpoint...")
    print("=> inferencing from checkpoint...")
    _, _, paths_to_weights = eval_inference(G, F1, class_list, class_num_list,
                                            args, 0)
    print("=> loaded weight file...")
    for step in range(main_dict['step'], all_step):
        optimizer_g = inv_lr_scheduler(param_lr_g,
                                       optimizer_g,
                                       step,
                                       init_lr=args.lr)
        optimizer_f = inv_lr_scheduler(param_lr_f,
                                       optimizer_f,
                                       step,
                                       init_lr=args.lr)
        lr = optimizer_f.param_groups[0]['lr']
        if step % len_train_target == 0:
            data_iter_t = iter(target_loader)
        if step % len_train_target_semi == 0:
            data_iter_t_unl = iter(target_loader_unl)
        if step % len_train_source == 0:
            data_iter_s = iter(source_loader)
        data_t = next(data_iter_t)
        data_t_unl = next(data_iter_t_unl)
        data_s = next(data_iter_s)
        im_data_s.data.resize_(data_s[0].size()).copy_(data_s[0])
        gt_labels_s.data.resize_(data_s[1].size()).copy_(data_s[1])
        im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
        gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
        im_data_tu.data.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])
        paths = data_t_unl[2]
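        # data_t_unl[2] carries the image file paths of the unlabeled batch;
        # they index into paths_to_weights when the entropy term is computed.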
        zero_grad_all()
        data = torch.cat((im_data_s, im_data_t), 0)
        target = torch.cat((gt_labels_s, gt_labels_t), 0)
        output = G(data)
        out1 = F1(output)

        if args.net == 'resnet34':
            reg_loss = regularizer(F1.fc3.weight, att)
            #reg_loss = 0
        else:
            reg_loss = regularizer(F1.fc2.weight, att)
            #reg_loss = 0

        # The attribute regularizer is computed above but effectively disabled
        # here (multiplied by zero), so only the classification loss drives
        # this update.
        loss = criterion(out1, target) + 0 * reg_loss
        loss.backward(retain_graph=True)
        optimizer_g.step()
        optimizer_f.step()
        zero_grad_all()
        if not args.method == 'S+T':
            output = G(im_data_tu)
            if args.method == 'ENT':
                #loss_t = entropy(F1, output, args.lamda)
                loss_t = weighted_entropy(F1, output, args.lamda, paths,
                                          paths_to_weights)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            elif args.method == 'MME':
                loss_t = weighted_adentropy(F1, output, args.lamda, paths,
                                            paths_to_weights)
                loss_t.backward()
                optimizer_f.step()
                optimizer_g.step()
            else:
                raise ValueError('Method cannot be recognized.')
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Loss T {:.6f} Reg {:.6f} ' \
                        'Method {}\n'.format(args.source, args.target,
                                             step, lr, loss.data, -loss_t.data,
                                             reg_loss.data, args.method)
        else:
            log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                        'Loss Classification: {:.6f} Method {}\n'.\
                format(args.source, args.target,
                       step, lr, loss.data,
                       args.method)
        G.zero_grad()
        F1.zero_grad()
        zero_grad_all()
        if step % args.log_interval == 0:
            print(log_train)
        if step % args.save_interval == 0 and step > 0:
            print("Re-weighting for entropy...")
            loss_test, acc_test, paths_to_weights = eval_inference(
                G, F1, class_list, class_num_list, args, step)
            #loss_test, acc_test = test(target_loader_test)
            loss_val, acc_val = test(target_loader_val)
            #loss_unseen, acc_unseen  = test(target_loader_unseen)
            G.train()
            F1.train()

            if acc_val > best_acc:
                best_acc = acc_val
                best_acc_test = acc_test
                counter = 0
            else:
                counter += 1
            if args.early:
                if counter > args.patience:
                    break
            print('best acc test %f best acc val %f' %
                  (best_acc_test, best_acc))
            print('record %s' % record_file)
            with open(record_file, 'a') as f:
                f.write('step %d best %f final %f \n' %
                        (step, best_acc_test, acc_val))
            G.train()
            F1.train()
            if args.save_check:
                print('saving model...')
                #is_best = True
                #save_mymodel(args, {
                #     'step': step,
                #     'arch': args.net,
                #     'G_state_dict': G.state_dict(),
                #     'F1_state_dict': F1.state_dict(),
                #     'best_acc_test': best_acc_test,
                #     'optimizer_g' : optimizer_g.state_dict(),
                #     'optimizer_f' : optimizer_f.state_dict(),
                #     }, is_best, None)

                torch.save(
                    {
                        'step': step,
                        'arch': args.net,
                        'G_state_dict': G.state_dict(),
                        'F1_state_dict': F1.state_dict(),
                        'best_acc_test': best_acc_test,
                        'optimizer_g': optimizer_g.state_dict(),
                        'optimizer_f': optimizer_f.state_dict(),
                    }, '%s/%s_%s_%s_%s_%s_%s.ckpt.pth.tar' %
                    (args.checkpath, args.net, args.method, args.source,
                     args.target, args.num, str(step)))
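
The weighted_entropy and weighted_adentropy helpers called in the ENT/MME branches are defined elsewhere in the repo and are not shown here. The following is only a rough sketch of the idea, assuming paths_to_weights maps each unlabeled image path to a scalar weight produced by eval_inference; the actual helpers (and MME's adversarial, gradient-reversal variant) may differ.

import torch
import torch.nn.functional as F


def weighted_entropy_sketch(F1, feat, lamda, paths, paths_to_weights):
    """Hypothetical stand-in for the weighted_entropy helper used above."""
    # Classifier probabilities for the unlabeled target features.
    prob = F.softmax(F1(feat), dim=1)
    # Per-sample Shannon entropy.
    ent = -(prob * torch.log(prob + 1e-5)).sum(dim=1)
    # Per-sample weights looked up by image path (paths_to_weights is assumed
    # to map each unlabeled file path to a scalar weight).
    w = ent.new_tensor([paths_to_weights[p] for p in paths])
    # lamda-scaled, weighted mean entropy; calling backward() on this and
    # stepping the optimizers pushes the model toward confident predictions
    # on the unlabeled target batch.
    return lamda * (w * ent).mean()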