def train(args):
    ## set pre-process
    dset_loaders = data_load(args)
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ## set base network
    if args.net == 'resnet34':
        netG = utils.ResBase34().cuda()
    elif args.net == 'vgg16':
        netG = utils.VGG16Base().cuda()

    netF = utils.ResClassifier(class_num=args.class_num,
                               feature_dim=netG.in_features,
                               bottleneck_dim=args.bottleneck_dim).cuda()

    if len(args.gpu_id.split(',')) > 1:
        netG = nn.DataParallel(netG)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    base_network = nn.Sequential(netG, netF)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])
    ltarget_loader_iter = iter(dset_loaders["ltarget"])

    if args.pl.startswith('atdoc_na'):
        # memory bank over all target samples (unlabeled first, then labeled)
        mem_fea = torch.rand(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset), args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset),
            args.class_num).cuda() / args.class_num

    if args.pl == 'atdoc_nc':
        # one memory slot (feature centroid) per class
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    list_acc = []
    best_val_acc = 0

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num,
                     max_iter=args.max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num,
                     max_iter=args.max_iter)

        # cycle the finite loaders: restart an iterator once it is exhausted
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)
        try:
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)
        except StopIteration:
            ltarget_loader_iter = iter(dset_loaders["ltarget"])
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)

        inputs_lt = inputs_ltarget[0].cuda()
        inputs_lt2 = inputs_ltarget[1].cuda()
        targets_lt = torch.zeros(args.batch_size // 3, args.class_num).scatter_(
            1, labels_ltarget.view(-1, 1), 1)
        targets_lt = targets_lt.cuda()

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        if args.pl.startswith('atdoc_na'):
            # neighborhood aggregation: pseudo labels from the K nearest
            # memory-bank entries, averaged over the two augmented views
            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)
                dis = -torch.mm(features_target.detach(), mem_fea.t())
                for di in range(dis.size(0)):
                    # exclude each sample's own memory slot from its neighbors
                    dis[di, target_idx[di]] = torch.max(dis)
                    # dis[di, target_idx[di]+len(dset_loaders["target"].dataset)] = torch.max(dis)
                _, p1 = torch.sort(dis, dim=1)

                w = torch.zeros(features_target.size(0), mem_fea.size(0)).cuda()
                for wi in range(w.size(0)):
                    for wj in range(args.K):
                        w[wi][p1[wi, wj]] = 1 / args.K

                _, pred = torch.max(w.mm(mem_cls), 1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'atdoc_nc':
            # nearest-centroid pseudo labels
            targets_u = 0
            mem_fea_norm = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)
                dis = torch.mm(features_target.detach(), mem_fea_norm.t())
                _, pred = torch.max(dis, dim=1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'npl':
            # naive pseudo labels from the network's own hard predictions
            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    _, outputs_u = base_network(inp)
                _, pred = torch.max(outputs_u.detach(), 1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        else:
            with torch.no_grad():
                # compute guessed labels of unlabeled samples: average the two
                # views' predictions, then sharpen with temperature T (MixMatch)
                _, outputs_u = base_network(inputs_t)
                _, outputs_u2 = base_network(inputs_t2)
                p = (torch.softmax(outputs_u, dim=1) +
                     torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1 / args.T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat(
            [inputs_s, inputs_lt, inputs_t, inputs_lt2, inputs_t2], dim=0)
        all_targets = torch.cat(
            [targets_s, targets_lt, targets_u, targets_lt, targets_u], dim=0)

        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches
        # to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s  = [sa,  sb,  sc ]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        _, logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:
            _, temp = base_network(inp)
            logits.append(temp)

        # put interleaved samples back
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()
        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, args.max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if args.pl.startswith('atdoc_na'):
            # momentum-update the memory bank with the normalized mean features
            # and (batch-sharpened) predictions of the two views
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))
                mem_fea[target_idx] = (1.0 - args.momentum) * mem_fea[target_idx] + \
                    args.momentum * feat
                mem_cls[target_idx] = (1.0 - args.momentum) * mem_cls[target_idx] + \
                    args.momentum * softmax_out

            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))
                # labeled-target slots live after the unlabeled ones in the bank
                mem_fea[lidx + len(dset_loaders["target"].dataset)] = \
                    (1.0 - args.momentum) * mem_fea[lidx + len(dset_loaders["target"].dataset)] + \
                    args.momentum * feat
                mem_cls[lidx + len(dset_loaders["target"].dataset)] = \
                    (1.0 - args.momentum) * mem_cls[lidx + len(dset_loaders["target"].dataset)] + \
                    args.momentum * softmax_out

        if args.pl == 'atdoc_nc':
            # momentum-update the class centroids from this batch's
            # predicted-class feature means (unlabeled + labeled target)
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat_u = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tu = torch.eye(args.class_num)[pred_t].cuda()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat_l = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tl = torch.eye(args.class_num)[pred_t].cuda()
                # onehot_tl = torch.eye(args.class_num)[labels_ltarget].cuda()

            center_t = (torch.mm(feat_u.t(), onehot_tu) +
                        torch.mm(feat_l.t(), onehot_tl)) / (
                            onehot_tu.sum(dim=0) + onehot_tl.sum(dim=0) + 1e-8)
            mem_fea = (1.0 - args.momentum) * mem_fea + \
                args.momentum * center_t.t().clone()

        if iter_num % int(args.eval_epoch * max_len) == 0:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = utils.cal_acc_visda(
                    dset_loaders["test"], base_network)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()
            else:
                acc, py, score, y = utils.cal_acc(dset_loaders["test"],
                                                  base_network)
            val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"], base_network)

            list_acc.append(acc * 100)
            # model selection by best validation accuracy
            if best_val_acc <= val_acc:
                best_val_acc = val_acc
                best_acc = acc
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(
                args.name, iter_num, args.max_iter, acc * 100, val_acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    val_acc = best_acc * 100
    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #             'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})
    return base_network, py
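
# `utils.interleave` is called above but not defined in this section. A minimal
# sketch, assuming it follows the reference MixMatch implementation that the
# surrounding comments describe: swap equal-sized chunks between the labeled
# batch and each unlabeled batch, so every forward pass through BatchNorm sees
# a mixture of both. `interleave_offsets` is a helper of this assumed sketch.
import torch

def interleave_offsets(batch, nu):
    # split `batch` into nu+1 near-equal groups and return cumulative offsets
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    # xy: list of tensors [labeled, unlabeled_1, ..., unlabeled_nu];
    # applying the same swap twice restores the original order, which is why
    # the training loop calls it once before and once after the forward passes
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        # swap the i-th chunk of the labeled batch with the i-th chunk of batch i
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]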
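
# `utils.SemiLoss` is also referenced but not defined here. A sketch matching
# its call signature, assuming the standard MixMatch criterion: soft
# cross-entropy on the labeled (mixed) batch, mean-squared error on the
# unlabeled one, and lambda_u ramped up linearly over training. The
# `linear_rampup` helper is part of this assumed sketch.
import numpy as np
import torch
import torch.nn.functional as F

def linear_rampup(current, rampup_length):
    # weight grows linearly from 0 to 1 over `rampup_length` iterations
    if rampup_length == 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0))

class SemiLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u,
                 iter_num, max_iter, lambda_u):
        # soft cross-entropy against the mixed (soft) labels
        Lx = -torch.mean(
            torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        # consistency (squared-error) loss on the unlabeled part
        probs_u = torch.softmax(outputs_u, dim=1)
        Lu = torch.mean((probs_u - targets_u)**2)
        return Lx, Lu, lambda_u * linear_rampup(iter_num, max_iter)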
def train(args, txt_src, txt_tgt):
    ## set pre-process
    dset_loaders = data_load(args, txt_src, txt_tgt)
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    if args.dset in ('u2m', 'm2u'):
        netG = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netG = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    # warm-start from the pretrained source (or adapted target) checkpoints
    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    base_network = nn.Sequential(netG, netF)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num,
                     max_iter=max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num,
                     max_iter=max_iter)

        # cycle the finite loaders: restart an iterator once it is exhausted
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples: average the two
            # views' predictions, then sharpen with temperature T (MixMatch)
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) +
                 torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)

        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches
        # to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s  = [sa,  sb,  sc ]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:
            temp = base_network(inp)
            logits.append(temp)

        # put interleaved samples back
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()
        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()
            acc, py, score, y = cal_acc(dset_loaders["train"], base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            acc, py, score, y = cal_acc(dset_loaders["test"], base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)
            # model selection by lowest mean prediction entropy on the test set
            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #             'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})
    return base_network, py
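
# `lr_scheduler` is shared by both training loops but defined elsewhere. A
# plausible sketch, assuming the inverse-power decay schedule that SHOT-style
# domain adaptation code commonly uses; `gamma` and `power` are assumed
# defaults, not values confirmed by this file:
def lr_scheduler(optimizer, init_lr, iter_num, max_iter, gamma=10, power=0.75):
    # decay the learning rate from init_lr following (1 + gamma*t/T)^(-power)
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
    return optimizer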
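
# `Entropy` drives the mean-entropy model selection above; a sketch of the
# usual helper in this family of codebases, assuming `input_` is a matrix of
# softmax scores with one row per sample:
import torch

def Entropy(input_):
    epsilon = 1e-5  # avoid log(0)
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)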