def test(args, model, val_dataset, domain_num):
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers, drop_last=True, pin_memory=True)
    val_dataloader_iter = enumerate(val_dataloader)

    val_accs_each_c = []
    pred_ys = []
    y_vals = []
    x_val = None
    y_val = None

    model.eval()
    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)
            pred_y = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
            pred_ys.append(pred_y.cpu())

    pred_ys = torch.cat(pred_ys, 0)
    y_vals = torch.cat(y_vals, 0)
    val_acc = float(eval_utils.accuracy(pred_ys, y_vals, topk=(1,))[0])
    val_acc_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_ys, y_vals, class_idx=c, topk=(1,))[0]))
                      for c, c_name in enumerate(val_dataset.classes)]

    model.train(True)
    return model, val_acc
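# NOTE: a minimal sketch of the top-1 accuracy that eval_utils.accuracy is assumed to return
# (one fraction per requested k, given (N, C) logits and (N,) integer labels); the repo's real
# helper may differ in signature and scaling, so treat this purely as an illustration.
def _topk_accuracy_sketch(logits, targets, topk=(1,)):
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)    # (N, maxk) predicted classes
    correct = pred.eq(targets.view(-1, 1).expand_as(pred))           # (N, maxk) hit mask
    return [correct[:, :k].any(dim=1).float().mean() for k in topk]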
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)
    stage = args.stage

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies

    svhn_train = SVHN(root='/data/jihun/SVHN', transform=svhn_transform, download=True)
    svhn_val = SVHN(root='/data/jihun/SVHN', split='test', transform=svhn_transform, download=True)
    mnist_train = MNIST('/data/jihun/MNIST', train=True, transform=mnist_transform, download=True)
    mnist_val = MNIST('/data/jihun/MNIST', train=False, transform=mnist_transform, download=True)

    if stage == 1:
        save_dir = join(save_root, args.save_dir, 'stage1')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        if args.trg_domain == 'mnist':
            train_dataset = mnist_train
            val_dataset = mnist_val
        else:
            train_dataset = svhn_train
            val_dataset = svhn_val

        train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=5, drop_last=True, pin_memory=True)
        train_dataloader_iters = enumerate(train_dataloader)

        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        model.train(True)
        model = model.cuda(args.gpu)

        optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999))
        ce_loss = nn.CrossEntropyLoss()
        domain_num = 0

        best_accuracy = 0.0
        best_accuracies_each_c = []
        best_mean_val_accuracies = []
        best_total_val_accuracies = []
        writer = SummaryWriter()

        for i in range(args.iters[0]):
            try:
                _, (x_s, y_s) = train_dataloader_iters.__next__()
            except StopIteration:
                train_dataloader_iters = enumerate(train_dataloader)
                _, (x_s, y_s) = train_dataloader_iters.__next__()

            optimizer.zero_grad()
            # lr_scheduler(optimizer, i)
            x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
            domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
            pred, f = model(x_s, domain_num * domain_idx, with_ft=True)

            loss = ce_loss(pred, y_s)
            writer.add_scalar("Train Loss", loss, i)
            loss.backward()
            optimizer.step()

            if i % 500 == 0 and i != 0:
                model.eval()
                total_val_accuracies = []
                mean_val_accuracies = []
                val_accuracies_each_c = []

                val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                      num_workers=args.num_workers, drop_last=True,
                                                      pin_memory=True)
                val_dataloader_iter = enumerate(val_dataloader)
                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None

                with torch.no_grad():
                    for j, (x_val, y_val) in val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)
                        pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
                val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                               class_idx=c, topk=(1,))[0]))
                                       for c, c_name in enumerate(val_dataset.classes)]
                mean_val_accuracy = float(
                    torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
                print('%d th iteration accuracy: %f ' % (i, val_accuracy))

                del x_val, y_val, pred_val, pred_vals, y_vals
                del val_dataloader_iter

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                # save checkpoint
                io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)
                model.train(True)  # train mode

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    best_accuracies_each_c = val_accuracies_each_c
                    best_mean_val_accuracies = mean_val_accuracies
                    best_total_val_accuracies = total_val_accuracies

                    model_dict = {'model': model.cpu().state_dict()}
                    optimizer_dict = {'optimizer': optimizer.state_dict()}
                    # save best checkpoint
                    io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=True)

                model = model.cuda(args.gpu)

        if args.proceed:
            stage += 1

    if stage == 2:
        save_dir = join(save_root, args.save_dir, 'stage2')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        if args.src_domain == 'mnist':
            train_dataset = mnist_train
            val_dataset = mnist_val
        else:
            train_dataset = svhn_train
            val_dataset = svhn_val

        train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
        train_dataloader_iters = enumerate(train_dataloader)

        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        if args.proceed:
            model.load_state_dict(
                torch.load(join(save_root, args.save_dir, 'stage1', 'best_model.ckpt'))['model'])
        else:
            # join the checkpoint path before loading (torch.load takes a single path argument)
            model.load_state_dict(torch.load(join(save_root, args.model_path))['model'])

        # train only the classifier and the second (target) BN branch
        for name, p in model.named_parameters():
            if ('fc' in name) or 'bns.1' in name:
                p.requires_grad = True
                continue
            else:
                p.requires_grad = False

        torch.nn.init.xavier_uniform_(model.fc1.weight)
        torch.nn.init.xavier_uniform_(model.fc2.weight)
        model.train(True)
        model = model.cuda(args.gpu)

        params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                      double_bias_lr=True, base_weight_factor=0.1)
        optimizer = optim.Adam(params, betas=(0.9, 0.999))
        ce_loss = nn.CrossEntropyLoss()
        writer = SummaryWriter()

        domain_num = stage - 1
        print('domain_num, stage: ', domain_num, stage)

        best_accuracy = 0.0
        best_accuracies_each_c = []
        best_mean_val_accuracies = []
        best_total_val_accuracies = []

        for i in range(args.iters[stage - 1]):
            try:
                _, (x_s, y_s) = train_dataloader_iters.__next__()
            except StopIteration:
                train_dataloader_iters = enumerate(train_dataloader)
                _, (x_s, y_s) = train_dataloader_iters.__next__()

            optimizer.zero_grad()
            # lr_scheduler(optimizer, i)
            x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
            domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
            pred, f = model(x_s, domain_num * domain_idx, with_ft=True)

            loss = ce_loss(pred, y_s)
            writer.add_scalar("Train Loss", loss, i)
            loss.backward()
            optimizer.step()

            if i % 500 == 0 and i != 0:
                model.eval()
                total_val_accuracies = []
                mean_val_accuracies = []
                val_accuracies_each_c = []

                val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                      num_workers=args.num_workers, drop_last=True,
                                                      pin_memory=True)
                val_dataloader_iter = enumerate(val_dataloader)
                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None

                with torch.no_grad():
                    for j, (x_val, y_val) in val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)
                        pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
                val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                               class_idx=c, topk=(1,))[0]))
                                       for c, c_name in enumerate(val_dataset.classes)]
                mean_val_accuracy = float(
                    torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
                print('%d th iteration accuracy: %f ' % (i, val_accuracy))

                del x_val, y_val, pred_val, pred_vals, y_vals
                del val_dataloader_iter

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                # save checkpoint
                io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)
                model.train(True)  # train mode

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    best_accuracies_each_c = val_accuracies_each_c
                    best_mean_val_accuracies = mean_val_accuracies
                    best_total_val_accuracies = total_val_accuracies

                    model_dict = {'model': model.cpu().state_dict()}
                    optimizer_dict = {'optimizer': optimizer.state_dict()}
                    # save best checkpoint
                    io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=True)

                model = model.cuda(args.gpu)

        if args.proceed:
            stage += 1

    if stage == 3:
        if args.trg_domain == 'mnist':
            val_dataset = mnist_val
        else:
            val_dataset = svhn_val

        val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                              num_workers=5, drop_last=True, pin_memory=True)
        val_dataloader_iter = enumerate(val_dataloader)

        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        if args.proceed:
            model.load_state_dict(
                torch.load(join(save_root, args.save_dir, 'stage2', 'best_model.ckpt'))['model'])
        else:
            model.load_state_dict(torch.load(join(save_root, args.model_path))['model'])
        model = model.cuda(args.gpu)
        model.eval()  # evaluation mode so BN uses its running statistics

        pred_vals = []
        y_vals = []
        domain_num = 0
        with torch.no_grad():
            for j, (x_val, y_val) in val_dataloader_iter:
                y_vals.append(y_val.cpu())
                x_val = x_val.cuda(args.gpu)
                y_val = y_val.cuda(args.gpu)
                pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
                pred_vals.append(pred_val.cpu())

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
        val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                       class_idx=c, topk=(1,))[0]))
                               for c, c_name in enumerate(val_dataset.classes)]
        print(total_val_accuracy)
        print(val_accuracy_each_c)
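# NOTE: a minimal sketch (not the repo's DSBNLeNet) of the domain-specific batch normalization
# idea the scripts above rely on: the second argument passed to the model on every forward call
# (e.g. domain_num * torch.ones_like(y)) is a per-sample domain label, and each domain index is
# routed to its own BatchNorm layer so source and target keep separate statistics and affine
# parameters. The ModuleList name 'bns' mirrors the 'bns.1' parameter names used when freezing.
import torch
import torch.nn as nn

class DomainSpecificBN2dSketch(nn.Module):
    def __init__(self, num_features, num_domains=2):
        super().__init__()
        self.bns = nn.ModuleList(nn.BatchNorm2d(num_features) for _ in range(num_domains))

    def forward(self, x, domain_label):
        # the training loops above build domain_label so that every sample in a batch shares
        # the same domain index, so routing by the first entry is enough for this sketch
        return self.bns[int(domain_label[0])](x)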
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu

    if not os.path.isfile(args.model_path):
        raise IOError("ERROR model_path: {}".format(args.model_path))

    # load checkpoints
    checkpoint = torch.load(args.model_path)
    global_step = checkpoint['iteration']
    model_state_dict = checkpoint['model']

    # set logger
    model_dir = os.path.dirname(args.model_path)
    log_filename = 'evaluation_step{}.log'.format(global_step)
    log_path = os.path.join(model_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None)

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes is not set! Setting it to 10.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes is not set! Setting it to 31.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes is not set! Setting it to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes is not set! Setting it to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    # update model args from filename
    model_args = io_utils.get_model_args_dict_from_filename(os.path.basename(args.model_path))
    model_args['source_datasets'] = model_args['source_dataset'].split('|')
    model_args['target_datasets'] = model_args['target_dataset'].split('|')
    args.__dict__.update(model_args)

    # load args if the file exists
    args_path = os.path.join(model_dir, 'args_dict.pth')
    if os.path.isfile(args_path):
        logger.info('Argument file exists. Load arguments from {}'.format(args_path))
        args_dict = torch.load(args_path)
        update_dict = {'args_path': args_path,
                       'source_dataset': args_dict['source_dataset'],
                       'source_datasets': args_dict['source_datasets'],
                       'target_dataset': args_dict['target_dataset'],
                       'target_datasets': args_dict['target_datasets'],
                       'model_name': args_dict['model_name'],
                       'in_features': args_dict['in_features']}
        args.__dict__.update(update_dict)

    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))

    model_options = io_utils.get_model_options_from_args(args, global_step)
    batch_size = args.batch_size
    num_classes = args.num_classes
    num_source_domains = len(args.source_datasets)
    num_target_domains = len(args.target_datasets)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        base_dir = os.path.dirname(args.model_path)
        tfboard_dir = os.path.join(base_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################
    source_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_dataset, 'test', args.jitter))
                            for source_dataset in args.source_datasets]
    target_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_dataset, 'test', args.jitter))
                            for target_dataset in args.target_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, pretrained=False)
    logger.info('Load trained parameters...')
    model.load_state_dict(model_state_dict)
    model.train(False)
    model.eval()
    model = model.cuda(args.gpu)

    # tfboard: write centroids
    if args.use_tfboard:
        centroids_filename = io_utils.get_centroids_filename(model_options)
        centroids_path = os.path.join(model_dir, centroids_filename)
        if os.path.isfile(centroids_path):
            logger.info('write centroids on tfboard: {}'.format(centroids_path))
            centroids_ckpt = torch.load(centroids_path)

            for i, centroids in enumerate(centroids_ckpt['srcs_centroids']):
                src_centroids = centroids['centroids'].cpu().data.numpy()
                writer.add_embedding(src_centroids, metadata=list(range(num_classes)),
                                     tag='src_centroids_{}'.format(args.source_datasets[i]),
                                     global_step=global_step)

            trg_centroids = centroids_ckpt['trg_centroids']['centroids'].cpu().data.numpy()
            writer.add_embedding(trg_centroids, metadata=list(range(num_classes)),
                                 tag='trg_centroids', global_step=global_step)

    logger.info('Start Evaluation')
    results = {'step': global_step}
    total_features = []
    total_labels = []
    # for d_idx, dataset in enumerate(target_test_datasets + source_test_datasets):
    for d_idx, dataset in enumerate(target_test_datasets):
        # dataloader
        dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers, drop_last=False, pin_memory=True)
        pred_vals = []
        y_vals = []
        if args.use_tfboard:
            features = []

        with torch.no_grad():
            for i, (x_val, y_val) in enumerate(dataloader):
                x_val = x_val.cuda(args.gpu)
                y_val = y_val.cuda(args.gpu)

                if args.dsbn:
                    pred_val, f_val = model(x_val, torch.zeros_like(y_val), with_ft=True)
                else:
                    pred_val, f_val = model(x_val, with_ft=True)

                pred_vals.append(pred_val.cpu())
                y_vals.append(y_val.cpu())
                if args.use_tfboard:
                    features += [f_val.cpu().data.numpy()]

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        test_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
        val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                       class_idx=c, topk=(1,))[0]))
                               for c, c_name in enumerate(dataset.classes)]

        # logging (indices below num_target_domains are targets, the rest are sources)
        if d_idx < num_target_domains:
            logger.info('{} Test Accuracy: {:.4f}%'.format(args.target_datasets[d_idx], 100 * test_accuracy))
            logger.info('\nEach class Accuracy of {}\n'.format(args.target_datasets[d_idx]) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            logger.info('Evaluation mean Accuracy: {:.2f}%'.format(
                100 * float(torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))))

            if args.save_results:
                results.update({args.target_datasets[d_idx]: test_accuracy})
                results.update({args.target_datasets[d_idx] + '_' + c_name: c_val_acc
                                for c_name, c_val_acc in val_accuracy_each_c})
        else:
            logger.info('{} Test Accuracy: {:.4f}'.format(args.source_datasets[d_idx - num_target_domains],
                                                          test_accuracy))
            logger.info('\nEach class Accuracy of {}\n'.format(args.source_datasets[d_idx - num_target_domains]) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            logger.info('Evaluation mean Accuracy: {:.2f}%'.format(
                100 * float(torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))))

            if args.save_results:
                results.update({args.source_datasets[d_idx - num_target_domains]: test_accuracy})
                results.update({args.source_datasets[d_idx - num_target_domains] + '_' + c_name: c_val_acc
                                for c_name, c_val_acc in val_accuracy_each_c})

        if args.use_tfboard:
            features = np.concatenate(features, axis=0)
            y_vals_numpy = y_vals.numpy().astype(np.int64)  # np.int was removed from recent NumPy releases
            embed_features = features
            # u, s, vt = np.linalg.svd(features)
            # embed_features = np.dot(features, vt[:3, :].transpose())
            if d_idx < num_target_domains:
                total_features += [embed_features]
                total_labels += [args.target_datasets[d_idx][0] + str(int(l)) for l in y_vals]
                writer.add_embedding(embed_features, metadata=y_vals_numpy,
                                     tag=args.target_datasets[d_idx], global_step=global_step)
            else:
                total_features += [embed_features]
                total_labels += [args.source_datasets[d_idx - num_target_domains][0] + str(int(l)) for l in y_vals]
                writer.add_embedding(embed_features, metadata=y_vals_numpy,
                                     tag=args.source_datasets[d_idx - num_target_domains],
                                     global_step=global_step)

    if args.use_tfboard:
        total_features = np.concatenate(total_features, axis=0)
        writer.add_embedding(total_features, metadata=list(total_labels),
                             tag='feat_embed_S:{}_T:{}'.format(args.source_dataset, args.target_dataset),
                             global_step=global_step)

    # save results
    if args.save_results:
        result_filename = 'evaluation_{:06d}.pth'.format(global_step)
        torch.save(results, os.path.join(model_dir, result_filename))

    if args.use_tfboard:
        writer.close()
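# NOTE: the results dict saved above is flat ({'step': ..., '<target_dataset>': acc,
# '<target_dataset>_<class_name>': acc, ...}). A small illustrative snippet for inspecting one of
# the saved 'evaluation_{step:06d}.pth' files (the path below is hypothetical):
#
#     results = torch.load('output/evaluation_012000.pth')
#     print('evaluated at step', results['step'])
#     for key, value in sorted(results.items()):
#         if key != 'step':
#             print('{:<40}: {:.2f}%'.format(key, 100 * float(value)))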
def ps_test(args, teacher, student, val_dataset, domain_num):
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers, drop_last=True, pin_memory=True)
    val_dataloader_iter = enumerate(val_dataloader)

    val_accs_each_c = []
    student_accs_each_c = []
    pseu_ys = []
    pred_ys = []
    y_vals = []
    x_val = None
    y_val = None

    teacher.eval()
    student.eval()
    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)

            # default number 1 for original dsbn implementation (0: src, 1: trg)
            pseu_y = teacher(x_val, 1 * torch.ones_like(y_val), with_ft=False).argmax(axis=1)
            pred_y = student(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
            pseu_ys.append(pseu_y.cpu())
            pred_ys.append(pred_y.cpu())

    pred_ys = torch.cat(pred_ys, 0)
    pseu_ys = torch.cat(pseu_ys, 0)
    y_vals = torch.cat(y_vals, 0)

    val_acc = float(eval_utils.accuracy(pred_ys, y_vals, topk=(1,))[0])
    val_acc_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_ys, y_vals, class_idx=c, topk=(1,))[0]))
                      for c, c_name in enumerate(val_dataset.classes)]
    student_acc = float(eval_utils.accuracy(pred_ys, pseu_ys, topk=(1,))[0])
    student_acc_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_ys, pseu_ys, class_idx=c, topk=(1,))[0]))
                          for c, c_name in enumerate(val_dataset.classes)]

    val_accs_each_c.append(val_acc_each_c)
    student_accs_each_c.append(student_acc_each_c)

    del x_val, y_val, pred_y, pred_ys, pseu_y, pseu_ys, y_vals
    del val_dataloader_iter

    return student, val_acc, student_acc
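# NOTE: ps_test returns two numbers: val_acc compares the student's predictions with the ground
# truth, while student_acc compares them with the teacher's pseudo labels, i.e. teacher/student
# agreement. Given the gathered tensors, that agreement reduces to a one-liner:
#
#     agreement = float((pred_ys.argmax(dim=1) == pseu_ys).float().mean())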
def train(args, model, train_dataset, val_dataset, save_dir, domain_num):
    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    train_dataloader_iters = enumerate(train_dataloader)

    model.train(True)
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=True, base_weight_factor=0.1)
    optimizer = optim.Adam(params, betas=(0.9, 0.999))
    ce_loss = nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir=join(save_dir, 'logs'))
    print('domain_num: ', domain_num)

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(args.iter):
        try:
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        except StopIteration:
            train_dataloader_iters = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iters.__next__()

        optimizer.zero_grad()
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
        pred, f = model(x_s, domain_num * domain_idx, with_ft=True)

        loss = ce_loss(pred, y_s)
        writer.add_scalar("Train Loss", loss, i)
        loss.backward()
        optimizer.step()

        if i % 500 == 0 and i != 0:
            model.eval()
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=args.num_workers, drop_last=True, pin_memory=True)
            val_dataloader_iter = enumerate(val_dataloader)
            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None

            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)
                    pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)
                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                           class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]
            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
            print('%d th iteration accuracy: %f ' % (i, val_accuracy))

            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter

            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                # save best checkpoint
                io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=True)

            # back to train mode and back on the GPU (state_dict() above moved the model to CPU)
            model.train(True)
            model = model.cuda(args.gpu)

        if i % 10000 == 0 and i != 0:
            print('%d iter complete' % (i))
            model_dict = {'model': model.cpu().state_dict()}
            optimizer_dict = {'optimizer': optimizer.state_dict()}
            # save periodic checkpoint
            io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)
            model = model.cuda(args.gpu)  # ship the model back to the GPU after saving

    writer.flush()
    writer.close()
    return
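# NOTE: a hypothetical call site for train(); only the attribute names the function actually reads
# (batch_size, num_workers, learning_rate, weight_decay, gpu, iter) are taken from the code above,
# while the concrete values, datasets, and save directory are illustrative:
#
#     from argparse import Namespace
#     args = Namespace(batch_size=128, num_workers=4, learning_rate=1e-3,
#                      weight_decay=5e-4, gpu=0, iter=20000)
#     model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
#     train(args, model, train_dataset, val_dataset, save_dir='./checkpoints/stage2', domain_num=1)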
def main():
    args = parse_args()
    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    args.source_dataset = '|'.join(args.source_datasets)
    args.target_dataset = '|'.join(args.target_datasets)

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes is not set! Setting it to 10.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes is not set! Setting it to 31.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes is not set! Setting it to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes is not set! Setting it to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    if args.weight_irm > 0.0:
        args.weight_source_irm = args.weight_irm
        args.weight_target_irm = args.weight_irm
    if args.iters_active_irm > 0:
        # IRM penalties stay disabled until iters_active_irm iterations have passed
        weight_irm_backup = [args.weight_source_irm, args.weight_target_irm]
        args.weight_source_irm = 0
        args.weight_target_irm = 0

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
        torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    num_domains = len(args.source_datasets) + len(args.target_datasets)
    if args.merge_sources:
        num_source_domains = 1
    else:
        num_source_domains = len(args.source_datasets)
    num_target_domains = len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)
        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################
    source_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_name, 'train', args.jitter))
                             for source_name in args.source_datasets]
    target_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'train', args.jitter))
                             for target_name in args.target_datasets]

    if args.merge_sources:
        for i in range(len(source_train_datasets)):
            if i == 0:
                merged_source_train_datasets = source_train_datasets[i]
            else:
                # concatenate dataset
                merged_source_train_datasets = merged_source_train_datasets + source_train_datasets[i]
        source_train_datasets = [merged_source_train_datasets]

    # dataloader
    source_train_dataloaders = [data.DataLoader(source_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for source_train_dataset in source_train_datasets]
    target_train_dataloaders = [data.DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for target_train_dataset in target_train_datasets]
    source_train_dataloader_iters = [enumerate(source_train_dataloader)
                                     for source_train_dataloader in source_train_dataloaders]
    target_train_dataloader_iters = [enumerate(target_train_dataloader)
                                     for target_train_dataloader in target_train_dataloaders]

    # validation dataloader
    target_val_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'val', args.jitter))
                           for target_name in args.target_datasets]
    target_val_dataloaders = [data.DataLoader(target_val_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.num_workers, pin_memory=True)
                              for target_val_dataset in target_val_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, num_domains=num_domains, pretrained=True)
    model.train(True)
    if args.resume:
        model.load_state_dict(checkpoints[0]['model'])
    elif args.init_model_path:
        init_checkpoint = torch.load(args.init_model_path)
        model.load_state_dict(init_checkpoint['model'])
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    if args.adv_loss:
        discriminators = [get_discriminator(args.exp_setting,
                                            in_features=args.in_features if args.in_features != 0
                                            else args.num_classes)
                          for _ in range(num_target_domains) for _ in range(num_source_domains)]
        discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
        D_params = get_optimizer_params(discriminators, args.learning_rate, weight_decay=args.weight_decay,
                                        double_bias_lr=args.double_bias_lr, base_weight_factor=None)
        if args.resume:
            if checkpoints[1]:
                for d_idx, discriminator in enumerate(discriminators):
                    discriminator.load_state_dict(checkpoints[1]['discriminators'][d_idx])

    if args.sm_loss:
        srcs_centroids = [Centroids(in_features, num_classes) for _ in range(num_source_domains)]
        trgs_centroids = [Centroids(in_features, num_classes) for _ in range(num_target_domains)]

        if args.resume:
            if checkpoints[2]:
                for src_idx, src_centroids in enumerate(srcs_centroids):
                    src_centroids.load_state_dict(checkpoints[2]['srcs_centroids'][src_idx])
                for trg_idx, trg_centroids in enumerate(trgs_centroids):
                    trg_centroids.load_state_dict(checkpoints[2]['trgs_centroids'][trg_idx])

        srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
        trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    ###################################################################################################################
    #                                           Train Configurations                                                  #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()
    # mse_loss = nn.MSELoss()
    lr_scheduler = LRScheduler(args.learning_rate, args.warmup_learning_rate, args.warmup_step,
                               num_steps=args.max_step, alpha=10, beta=0.75,
                               double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(params, momentum=0.9, nesterov=True)
    else:
        optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))
    if args.resume:
        if checkpoints[1]:
            optimizer.load_state_dict(checkpoints[1]['optimizer'])

    if args.adv_loss:
        if args.optimizer.lower() == 'sgd':
            optimizer_D = optim.SGD(D_params, momentum=0.9, nesterov=True)
        else:
            optimizer_D = optim.Adam(D_params, betas=(args.beta1, args.beta2))
        if args.resume:
            if checkpoints[1]:
                optimizer_D.load_state_dict(checkpoints[1]['optimizer_D'])

    # Train Starts
    logger.info('Train Starts')
    domain_loss_adjust_factor = args.domain_loss_adjust_factor
    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i_iter in range(start_iter, args.early_stop_step + 1):
        if args.iters_active_irm > 0:
            if i_iter > args.iters_active_irm:
                args.weight_source_irm = weight_irm_backup[0]
                args.weight_target_irm = weight_irm_backup[1]

        src_inputs = []
        for src_dataloader_idx in range(len(source_train_dataloader_iters)):
            try:
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))
            except StopIteration:
                source_train_dataloader_iters[src_dataloader_idx] = enumerate(
                    source_train_dataloaders[src_dataloader_idx])
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))

        trg_inputs = []
        for trg_dataloader_idx in range(len(target_train_dataloader_iters)):
            try:
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))
            except StopIteration:
                target_train_dataloader_iters[trg_dataloader_idx] = enumerate(
                    target_train_dataloaders[trg_dataloader_idx])
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))

        current_lr = lr_scheduler.current_lr(i_iter)
        adaptation_lambda = adaptation_factor((i_iter - args.warmup_step) / float(args.max_step),
                                              gamma=args.adaptation_gamma)

        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)
        if args.adv_loss:
            optimizer_D.zero_grad()
            lr_scheduler(optimizer_D, i_iter)

        ########################################################################################################
        #                                              Train G                                                 #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False

        # ship to cuda
        src_inputs = [(x_s.cuda(args.gpu), y_s.cuda(args.gpu)) for (x_s, y_s) in src_inputs]
        trg_inputs = [(x_t.cuda(args.gpu), None) for (x_t, _) in trg_inputs]

        if args.dsbn:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, src_idx * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                src_preds.append((pred_s, f_s))
            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                trg_preds.append((pred_t, f_t))
        else:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, with_ft=True)
                src_preds.append((pred_s, f_s))
            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, with_ft=True)
                trg_preds.append((pred_t, f_t))

        Closs_src = 0
        Closs_src_irm = 0
        for (_, y_s), (pred_s, f_s) in zip(src_inputs, src_preds):
            Closs_src = Closs_src + ce_loss(pred_s, y_s) / float(num_source_domains)
            if args.weight_source_irm > 0:
                Closs_src_irm += feature_penalty(f_s, model.fc, ce_loss, y_s)
        monitor.update({"Loss/Closs_src": float(Closs_src)})

        Floss = Closs_src
        if args.weight_source_irm > 0:
            Floss += Closs_src_irm * args.weight_source_irm
            monitor.update({"Loss/Closs_src_irm": float(Closs_src_irm)})

        if args.adv_loss:
            # adversarial loss
            Gloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s)
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t)
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2

                    Gloss = Gloss - (loss_adv_src + loss_adv_trg)
            Gloss = Gloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Gloss': float(Gloss)})
            Floss = Floss + adaptation_lambda * Gloss

        # pseudo label generation
        pred_t_pseudos = []
        if args.dsbn:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                          with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)
        else:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)

        if args.weight_target_irm > 0:
            Closs_trg_irm = 0
            for pred_t_pseudo, (pred_t, f_t) in zip(pred_t_pseudos, trg_preds):
                y_t_pseudo = torch.argmax(pred_t_pseudo, 1).detach()
                Closs_trg_irm += feature_penalty(f_t, model.fc, ce_loss, y_t_pseudo)
            Floss += Closs_trg_irm * args.weight_target_irm
            monitor.update({"Loss/Closs_trg_irm": float(Closs_trg_irm)})

        # moving semantic loss
        if args.sm_loss:
            current_srcs_centroids = [src_centroids(f_s, y_s) for src_centroids, (x_s, y_s), (_, f_s)
                                      in zip(srcs_centroids, src_inputs, src_preds)]
            current_trgs_centroids = [trg_centroids(f_t, torch.argmax(pred_t_pseudo, 1).detach())
                                      for trg_centroids, pred_t_pseudo, (_, f_t)
                                      in zip(trgs_centroids, pred_t_pseudos, trg_preds)]

            semantic_loss = 0
            for current_trg_centroids in current_trgs_centroids:
                for current_src_centroids in current_srcs_centroids:
                    semantic_loss = semantic_loss + args.sm_etha * semantic_loss_calc(current_src_centroids,
                                                                                      current_trg_centroids)
            semantic_loss = semantic_loss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/SMloss': float(semantic_loss)})
            Floss = Floss + adaptation_lambda * semantic_loss

        # Floss backward
        Floss.backward()
        optimizer.step()

        ########################################################################################################
        #                                              Train D                                                 #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = True

        if args.adv_loss:
            # adversarial loss
            Dloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s.detach())
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    # target
                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t.detach())
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2

                    Dloss = Dloss + loss_adv_src + loss_adv_trg
            Dloss = Dloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Dloss': float(Dloss)})
            Dloss = adaptation_lambda * Dloss
            Dloss.backward()
            optimizer_D.step()

        if args.sm_loss:
            for src_centroids, current_src_centroids in zip(srcs_centroids, current_srcs_centroids):
                src_centroids.centroids.data = current_src_centroids.data
            for trg_centroids, current_trg_centroids in zip(trgs_centroids, current_trgs_centroids):
                trg_centroids.centroids.data = current_trg_centroids.data

        if i_iter % args.disp_interval == 0 and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, args.early_stop_step)
            disp_msg += str(monitor)
            if args.adv_loss or args.sm_loss:
                disp_msg += ', lambda={:.6f}'.format(adaptation_lambda)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            logger.info(disp_msg)

            if args.use_tfboard:
                if args.save_model_hist:
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.cpu().data.numpy(), i_iter, bins='auto')
                for k, v in monitor.losses.items():
                    writer.add_scalar(k, v, i_iter)
                if args.adv_loss or args.sm_loss:
                    writer.add_scalar('adaptation_lambda', adaptation_lambda, i_iter)
                writer.add_scalar('learning rate', current_lr, i_iter)

            monitor.reset()

        if i_iter % args.save_interval == 0 and i_iter != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i_iter))

            target_val_dataloader_iters = [enumerate(target_val_dataloader)
                                           for target_val_dataloader in target_val_dataloaders]
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            model.eval()  # evaluation mode
            for trg_idx, target_val_dataloader_iter in enumerate(target_val_dataloader_iters, num_source_domains):
                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None
                pred_val = None

                with torch.no_grad():
                    for i, (x_val, y_val) in target_val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)

                        if args.dsbn:
                            pred_val = model(x_val, trg_idx * torch.ones_like(y_val), with_ft=False)
                        else:
                            pred_val = model(x_val, with_ft=False)
                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
                val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                               class_idx=c, topk=(1,))[0]))
                                       for c, c_name
                                       in enumerate(target_val_datasets[trg_idx - num_source_domains].classes)]

                logger.info('\n{} Accuracy of Each class\n'.format(
                    args.target_datasets[trg_idx - num_source_domains]) +
                    ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                             for c_name, c_val_acc in val_accuracy_each_c]))
                mean_val_accuracy = float(
                    torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
                logger.info('{} mean Accuracy: {:.2f}%'.format(
                    args.target_datasets[trg_idx - num_source_domains], 100 * mean_val_accuracy))
                logger.info('{} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx - num_source_domains],
                                                          total_val_accuracy * 100))

                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                if args.use_tfboard:
                    writer.add_scalar('Val_acc', total_val_accuracy, i_iter)
                    for c_name, c_val_acc in val_accuracy_each_c:
                        writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc)

            model.train(True)  # train mode

            if args.exp_setting.lower() == 'visda':
                val_accuracy = float(torch.mean(torch.FloatTensor(mean_val_accuracies)))
            else:
                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            # for memory
            del x_val, y_val, pred_val, pred_vals, y_vals
            for target_val_dataloader_iter in target_val_dataloader_iters:
                del target_val_dataloader_iter
            del target_val_dataloader_iters

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators],
                                           'source_datasets': args.source_datasets,
                                           'target_datasets': args.target_datasets})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}

                # save best checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict,
                                          centroids_dict, logger, best=True)

                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(best_accuracy)]
                for d_idx in range(num_target_domains):
                    best_accuracy_each_c = best_accuracies_each_c[d_idx]
                    best_mean_val_accuracy = best_mean_val_accuracies[d_idx]
                    best_total_val_accuracy = best_total_val_accuracies[d_idx]
                    contents.extend(["{}2{}\n".format(args.source_dataset, args.target_datasets[d_idx]),
                                     "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                     "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                     'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                              for _, c_val_acc in best_accuracy_each_c]) + '\n'])
                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

                # logging best model results
                for trg_idx in range(num_target_domains):
                    best_accuracy_each_c = best_accuracies_each_c[trg_idx]
                    best_total_val_accuracy = best_total_val_accuracies[trg_idx]
                    best_mean_val_accuracy = best_mean_val_accuracies[trg_idx]

                    logger.info('\nBest {} Accuracy of Each class\n'.format(args.target_datasets[trg_idx]) +
                                ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                         for c_name, c_val_acc in best_accuracy_each_c]))
                    logger.info('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                         for _, c_val_acc in best_accuracy_each_c]))
                    logger.info('Best {} mean Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                                        100 * best_mean_val_accuracy))
                    logger.info('Best {} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                                   100 * best_total_val_accuracy))
                logger.info("Best model's Average Accuracy of targets: {:.2f}".format(100 * best_accuracy))

            if args.save_ckpts:
                # get options
                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators]})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}

                # save checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict,
                                          centroids_dict, logger, best=False)

                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    if args.use_tfboard:
        writer.close()
    logger.info('Total Time: {}'.format(datetime.datetime.now() - start_time))
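# NOTE: feature_penalty() is imported from elsewhere in the repo; below is a minimal IRMv1-style
# sketch of what a call like feature_penalty(f_s, model.fc, ce_loss, y_s) is assumed to compute,
# namely the squared gradient of the classification risk with respect to a dummy scale applied to
# the classifier output. Treat it purely as an illustration of the technique, not the repo's code.
import torch

def irm_feature_penalty_sketch(features, classifier, loss_fn, labels):
    scale = torch.ones(1, device=features.device, requires_grad=True)
    loss = loss_fn(classifier(features) * scale, labels)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()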
def main():
    args = parse_args()
    stage = args.stage
    torch.cuda.set_device(args.gpu)
    writer = SummaryWriter()

    save_dir = args.save_dir
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    print('domain: ', args.trg_domain)
    num_domain = len(args.trg_domain)

    if args.ssl:
        model = get_rot_model(args.model_name, num_domains=1)
        train_dataset = rot_dataset(args.data_root, num_domain, args.trg_domain, 'train')
        val_dataset = rot_dataset(args.data_root, num_domain, args.trg_domain, 'val')
        test1_dataset = rot_dataset(args.data_root, 1, [args.trg_domain[0]], 'test')
        if len(args.trg_domain) > 1:
            test2_dataset = rot_dataset(args.data_root, 1, [args.trg_domain[1]], 'test')
    else:
        model = get_model(args.model_name, 65, 65, 1, pretrained=True)
        train_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.trg_domain, transform=train_transform)
        val_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.trg_domain, transform=val_transform)
        test1_dataset = OFFICEHOME_multi(args.data_root, 1, [args.trg_domain[0]], transform=val_transform)
        if len(args.trg_domain) > 1:
            test2_dataset = OFFICEHOME_multi(args.data_root, 1, [args.trg_domain[1]], transform=val_transform)

    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    train_dataloader_iter = enumerate(train_dataloader)

    model.train(True)
    model = model.cuda(args.gpu)
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999))

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(args.iters[0]):
        try:
            _, (x_s, y_s) = train_dataloader_iter.__next__()
        except StopIteration:
            train_dataloader_iter = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iter.__next__()

        optimizer.zero_grad()
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
        pred, f = model(x_s, 0 * domain_idx, with_ft=True)

        loss = ce_loss(pred, y_s)
        loss.backward()
        optimizer.step()

        if i % 500 == 0 and i != 0:
            model.eval()
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=args.num_workers, drop_last=True, pin_memory=True)
            val_dataloader_iter = enumerate(val_dataloader)
            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None

            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)
                    pred_val = model(x_val, 0 * torch.ones_like(y_val), with_ft=False)
                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                           class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]
            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
            print('%d th iteration accuracy: %f ' % (i, val_accuracy))

            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter

            model_dict = {'model': model.cpu().state_dict()}
            optimizer_dict = {'optimizer': optimizer.state_dict()}
            # io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)
            model.train(True)  # train mode

            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                # save best checkpoint
                io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=True)

            model = model.cuda(args.gpu)

        if i % 5000 == 0 and i != 0:
            print('%d iter complete' % (i))
            test(args, test1_dataset, model)
            if len(args.trg_domain) > 1:
                test(args, test2_dataset, model)

    writer.flush()
    writer.close()

    model.eval()
    test(args, test1_dataset, model)
    if len(args.trg_domain) > 1:
        test(args, test2_dataset, model)
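# NOTE: every training loop in these scripts restarts an exhausted DataLoader with the same
# try/except StopIteration idiom; a small generic sketch of that pattern (an illustrative helper,
# not part of the repo):
def next_batch_sketch(iterator, dataloader):
    try:
        _, batch = next(iterator)
    except StopIteration:
        iterator = enumerate(dataloader)
        _, batch = next(iterator)
    return iterator, batch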
def main():
    print('start finetune')
    args = parse_args()
    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    args.cpua = True if 'cpua' in args.model_name else False

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # check whether the teacher model exists
    if not os.path.isfile(args.teacher_model_path):
        raise AttributeError('Missing teacher model path: {}'.format(args.teacher_model_path))

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting in ['office-home']:
            logger.warning('num_classes is not set! Setting it to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['digits']:
            logger.warning('num_classes is not set! Setting it to 10.')
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
        torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    # num_domains = len(args.source_datasets) + len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)
        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################
    # train_dataset = MNIST('/data/jihun/MNIST', train=True, transform=mnist_transform, download=True)
    # val_dataset = MNIST('/data/jihun/MNIST', train=False, transform=mnist_transform, download=True)
    train_dataset = SVHN(root='/data/jihun/SVHN', transform=svhn_transform, download=True)
    val_dataset = SVHN(root='/data/jihun/SVHN', split='test', transform=svhn_transform, download=True)

    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers, drop_last=True, pin_memory=True)
    train_dataloader_iters = enumerate(train_dataloader)
    val_dataloader_iter = enumerate(val_dataloader)

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, 10, 0, 2, pretrained=True)
    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    # teacher model
    print('------------------------model load------------------------')
    model.load_state_dict(torch.load(args.teacher_model_path)['model'])

    # freeze everything except the classifier and the second (target) BN branch;
    # iterate named_parameters() (not state_dict()) so requires_grad is set on the actual Parameters
    for name, p in model.named_parameters():
        if ('fc' in name) or 'bns.1' in name:
            continue
        else:
            p.requires_grad = False

    torch.nn.init.xavier_uniform_(model.fc1.weight)
    torch.nn.init.xavier_uniform_(model.fc2.weight)
    model.train(True)
    model = model.cuda(args.gpu)

    ###################################################################################################################
    #                                           Train Configurations                                                  #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    # Train Starts
    logger.info('Train Starts')
    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(start_iter, args.early_stop_step + 1):
        try:
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        except StopIteration:
            train_dataloader_iters = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iters.__next__()

        # init optimizer
        optimizer.zero_grad()

        # ship to cuda
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        pred_s, f_s = model(x_s, 1 * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu), with_ft=True)

        loss = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(loss)})
        loss.backward()
        optimizer.step()

        if i % args.save_interval == 0 and i != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i))

            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            model.eval()
            val_dataloader_iter = enumerate(val_dataloader)
            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None

            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)
                    pred_val = model(x_val, 1 * torch.ones_like(y_val), with_ft=False)
                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                           class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]
            logger.info('\n{} Accuracy of Each class\n'.format(args.finetune_dataset) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))

            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            logger.info('{} mean Accuracy: {:.2f}%'.format(args.finetune_dataset, 100 * mean_val_accuracy))
            logger.info('{} Accuracy: {:.2f}%'.format(args.finetune_dataset, total_val_accuracy * 100))

            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            if args.use_tfboard:
                writer.add_scalar('Val_acc', total_val_accuracy, i)
                for c_name, c_val_acc in val_accuracy_each_c:
                    writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc)

            model.train(True)  # train mode
            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter
            print("%d th iter accuracy: %.3f" % (i, val_accuracy))

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                options = io_utils.get_model_options_from_args(args, i)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                # save best checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i, model_dict, optimizer_dict, logger, best=True)
                # ship to cuda
                model = model.cuda(args.gpu)

    # save best result into textfile
    contents = [' '.join(sys.argv) + '\n',
                "best accuracy: {:.2f}%\n".format(best_accuracy)]
    best_accuracy_each_c = best_accuracies_each_c[0]
    best_mean_val_accuracy = best_mean_val_accuracies[0]
    best_total_val_accuracy = best_total_val_accuracies[0]
    contents.extend(["{}\n".format(args.finetune_dataset),
                     "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                     "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                     'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                              for _, c_val_acc in best_accuracy_each_c]) + '\n'])
    best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
        os.path.splitext(os.path.basename(__file__))[0]))
    with open(best_result_path, 'a+') as f:
        f.writelines(contents)

    # final evaluation with the source-domain BN branch (domain index 0)
    model.eval()
    val_dataloader_iter = enumerate(val_dataloader)
    pred_vals = []
    y_vals = []
    x_val = None
    y_val = None

    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)
            pred_val = model(x_val, 0 * torch.ones_like(y_val), with_ft=False)
            pred_vals.append(pred_val.cpu())

    pred_vals = torch.cat(pred_vals, 0)
    y_vals = torch.cat(y_vals, 0)
    total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
    val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                   class_idx=c, topk=(1,))[0]))
                           for c, c_name in enumerate(val_dataset.classes)]
    for cls in val_accuracy_each_c:
        print(cls)
    print(total_val_accuracy)
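# NOTE: a small sanity check for the freezing logic above; after loading the teacher weights only
# the classifier ('fc*') and the target-domain BN branch ('bns.1') should remain trainable:
#
#     trainable = [name for name, p in model.named_parameters() if p.requires_grad]
#     print('\n'.join(trainable))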