Code Example #1
File: demo_ssda_mixmatch.py  Project: tim-learn/ATDOC
def train(args):
    ## set up data loading and pre-processing
    dset_loaders = data_load(args)

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ## set base network
    if args.net == 'resnet34':
        netG = utils.ResBase34().cuda()
    elif args.net == 'vgg16':
        netG = utils.VGG16Base().cuda()

    netF = utils.ResClassifier(class_num=args.class_num,
                               feature_dim=netG.in_features,
                               bottleneck_dim=args.bottleneck_dim).cuda()

    if len(args.gpu_id.split(',')) > 1:
        netG = nn.DataParallel(netG)

    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])
    ltarget_loader_iter = iter(dset_loaders["ltarget"])

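    # ATDOC-NA: memory bank of L2-normalized features and soft predictions,
    # one row per target sample (unlabeled samples first, then labeled ones),
    # initialized to random directions and uniform class probabilities.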
    if args.pl.startswith('atdoc_na'):
        mem_fea = torch.rand(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset), args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset),
            args.class_num).cuda() / args.class_num

    if args.pl == 'atdoc_nc':
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    list_acc = []
    best_val_acc = 0

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=args.max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        try:
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)
        except StopIteration:
            ltarget_loader_iter = iter(dset_loaders["ltarget"])
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)

        inputs_lt = inputs_ltarget[0].cuda()
        inputs_lt2 = inputs_ltarget[1].cuda()
        targets_lt = torch.zeros(args.batch_size // 3,
                                 args.class_num).scatter_(
                                     1, labels_ltarget.view(-1, 1), 1)
        targets_lt = targets_lt.cuda()

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

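        # Pseudo-label the two augmented target views. ATDOC-NA averages the
        # one-hot votes of the K nearest memory-bank neighbors (masking each
        # sample's own slot); ATDOC-NC uses the nearest class centroid; 'npl'
        # uses the network's own argmax; the default branch is plain MixMatch
        # temperature sharpening.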
        if args.pl.startswith('atdoc_na'):

            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)

                dis = -torch.mm(features_target.detach(), mem_fea.t())
                for di in range(dis.size(0)):
                    dis[di, target_idx[di]] = torch.max(dis)
                    # dis[di, target_idx[di]+len(dset_loaders["target"].dataset)] = torch.max(dis)

                _, p1 = torch.sort(dis, dim=1)
                w = torch.zeros(features_target.size(0),
                                mem_fea.size(0)).cuda()
                for wi in range(w.size(0)):
                    for wj in range(args.K):
                        w[wi][p1[wi, wj]] = 1 / args.K

                _, pred = torch.max(w.mm(mem_cls), 1)

                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'atdoc_nc':

            targets_u = 0
            mem_fea_norm = mem_fea / torch.norm(
                mem_fea, p=2, dim=1, keepdim=True)
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)
                dis = torch.mm(features_target.detach(), mem_fea_norm.t())
                _, pred = torch.max(dis, dim=1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'npl':

            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    _, outputs_u = base_network(inp)
                _, pred = torch.max(outputs_u.detach(), 1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        else:
            with torch.no_grad():
                # compute guessed labels of unlabeled samples
                _, outputs_u = base_network(inputs_t)
                _, outputs_u2 = base_network(inputs_t2)
                p = (torch.softmax(outputs_u, dim=1) +
                     torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1 / args.T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

        ####################################################################
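        # MixMatch mixup: blend the concatenated batch with a shuffled copy of
        # itself, biasing the Beta-sampled coefficient toward the original
        # sample via max(l, 1 - l).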
        all_inputs = torch.cat(
            [inputs_s, inputs_lt, inputs_t, inputs_lt2, inputs_t2], dim=0)
        all_targets = torch.cat(
            [targets_s, targets_lt, targets_u, targets_lt, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        _, logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:
            _, temp = base_network(inp)
            logits.append(temp)

        # put interleaved samples back
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, args.max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

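        # EMA-update the memory bank with evaluation-mode features and
        # sharpened (squared-and-renormalized) predictions for this batch.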
        if args.pl.startswith('atdoc_na'):
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[target_idx] = (
                1.0 -
                args.momentum) * mem_fea[target_idx] + args.momentum * feat
            mem_cls[target_idx] = (1.0 - args.momentum) * mem_cls[
                target_idx] + args.momentum * softmax_out

            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_fea[lidx + len(dset_loaders["target"].dataset)] + args.momentum*feat
            mem_cls[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_cls[lidx + len(dset_loaders["target"].dataset)] + args.momentum*softmax_out

        if args.pl == 'atdoc_nc':
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat_u = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tu = torch.eye(args.class_num)[pred_t].cuda()

            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat_l = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tl = torch.eye(args.class_num)[pred_t].cuda()
                # onehot_tl = torch.eye(args.class_num)[labels_ltarget].cuda()

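            # Class centroids: prediction-weighted mean of the current batch
            # features per class, EMA-merged into the centroid memory.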
            center_t = ((torch.mm(feat_u.t(), onehot_tu) + torch.mm(
                feat_l.t(), onehot_tl))) / (onehot_tu.sum(dim=0) +
                                            onehot_tl.sum(dim=0) + 1e-8)
            mem_fea = (1.0 - args.momentum
                       ) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % int(args.eval_epoch * max_len) == 0:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = utils.cal_acc_visda(
                    dset_loaders["test"], base_network)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()
            else:
                acc, py, score, y = utils.cal_acc(dset_loaders["test"],
                                                  base_network)
                val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"],
                                                 base_network)

            list_acc.append(acc * 100)
            if best_val_acc <= val_acc:
                best_val_acc = val_acc
                best_acc = acc
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(
                args.name, iter_num, args.max_iter, acc * 100, val_acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    val_acc = best_acc * 100
    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return base_network, py
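
Note: lr_scheduler and utils.SemiLoss are used above but not shown. In ATDOC/SHOT-style codebases these are typically an inverse-power learning-rate decay and the standard MixMatch objective; the sketches below are minimal reconstructions under those assumptions (names, the gamma=10 / power=0.75 defaults, and the linear ramp-up are assumptions, not taken from this file):

import torch
import torch.nn.functional as F

def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    # Inverse-power decay: lr = init_lr * (1 + gamma * t) ** (-power), t = iter/max.
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
    return optimizer

class SemiLoss:
    # MixMatch loss: soft cross-entropy on the mixed labeled part, L2
    # consistency on the mixed unlabeled part, and a linearly ramped weight.
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u,
                 iter_num, max_iter, lambda_u):
        probs_u = torch.softmax(outputs_u, dim=1)
        Lx = -torch.mean(
            torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u) ** 2)
        w = lambda_u * min(iter_num / max_iter, 1.0)
        return Lx, Lu, w
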
Code Example #2
File: digit_mixmatch.py  Project: tim-learn/SHOT-plus
def train(args, txt_src, txt_tgt):
    ## set up data loading and pre-processing
    dset_loaders = data_load(args, txt_src, txt_tgt)
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    if args.dset in ('u2m', 'm2u'):
        netG = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netG = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) +
                 torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:
            temp = base_network(inp)
            logits.append(temp)

        # put interleaved samples back
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()

            acc, py, score, y = cal_acc(dset_loaders["train"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            acc, py, score, y = cal_acc(dset_loaders["test"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)

            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return base_network, py
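
Note: utils.interleave, referenced in both examples, redistributes samples across the split batches so that each chunk passed through the network contains samples from every batch, which is what keeps the batch-norm statistics correct (the s' / t1' / t2' comments above illustrate the effect). The sketch below follows the standard MixMatch-PyTorch implementation, which these repositories appear to reuse (an assumption; check utils.py in each project). Entropy, used for model selection in Example #2, is sketched under the assumption that score holds per-sample softmax probabilities:

import torch

def interleave_offsets(batch, nu):
    # Split a batch of size `batch` into nu+1 near-equal contiguous groups.
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    # Swap group i of the first batch with group i of batch i, so each
    # resulting batch mixes labeled and unlabeled samples; applying the same
    # function to the logits afterwards restores the original order.
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

def Entropy(probs, eps=1e-5):
    # Per-sample Shannon entropy of an (N, C) probability matrix.
    return -torch.sum(probs * torch.log(probs + eps), dim=1)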