def main():
    args = parse_args()
    args.dsbn = 'dsbn' in args.model_name  # set dsbn
    args.source_dataset = '|'.join(args.source_datasets)
    args.target_dataset = '|'.join(args.target_datasets)

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes not set; defaulting to 10.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes not set; defaulting to 31.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes not set; defaulting to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes not set; defaulting to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Cannot infer num_classes for exp_setting: {}'.format(args.exp_setting))

    if args.weight_irm > 0.0:
        args.weight_source_irm = args.weight_irm
        args.weight_target_irm = args.weight_irm

    if args.iters_active_irm > 0:
        # disable the IRM terms until iters_active_irm iterations have passed
        weight_irm_backup = [args.weight_source_irm, args.weight_target_irm]
        args.weight_source_irm = 0
        args.weight_target_irm = 0

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
        torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    num_domains = len(args.source_datasets) + len(args.target_datasets)
    if args.merge_sources:
        num_source_domains = 1
    else:
        num_source_domains = len(args.source_datasets)
    num_target_domains = len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)
        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                                  Data Loading                                                  #
    ###################################################################################################################
    source_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_name, 'train', args.jitter))
                             for source_name in args.source_datasets]
    target_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'train', args.jitter))
                             for target_name in args.target_datasets]

    if args.merge_sources:
        for i in range(len(source_train_datasets)):
            if i == 0:
                merged_source_train_datasets = source_train_datasets[i]
            else:
                # concatenate dataset
                merged_source_train_datasets = merged_source_train_datasets + source_train_datasets[i]
        source_train_datasets = [merged_source_train_datasets]

    # dataloader
    source_train_dataloaders = [data.DataLoader(source_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for source_train_dataset in source_train_datasets]
    target_train_dataloaders = [data.DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for target_train_dataset in target_train_datasets]
    source_train_dataloader_iters = [enumerate(source_train_dataloader)
                                     for source_train_dataloader in source_train_dataloaders]
    target_train_dataloader_iters = [enumerate(target_train_dataloader)
                                     for target_train_dataloader in target_train_dataloaders]

    # validation dataloader
    target_val_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'val', args.jitter))
                           for target_name in args.target_datasets]
    target_val_dataloaders = [data.DataLoader(target_val_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.num_workers, pin_memory=True)
                              for target_val_dataset in target_val_datasets]

    ###################################################################################################################
    #                                                  Model Loading                                                  #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, num_domains=num_domains, pretrained=True)
    model.train(True)
    if args.resume:
        model.load_state_dict(checkpoints[0]['model'])
    elif args.init_model_path:
        init_checkpoint = torch.load(args.init_model_path)
        model.load_state_dict(init_checkpoint['model'])
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    if args.adv_loss:
        discriminators = [get_discriminator(args.exp_setting,
                                            in_features=args.in_features if args.in_features != 0 else args.num_classes)
                          for _ in range(num_target_domains) for _ in range(num_source_domains)]
        discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
        D_params = get_optimizer_params(discriminators, args.learning_rate, weight_decay=args.weight_decay,
                                        double_bias_lr=args.double_bias_lr, base_weight_factor=None)
        if args.resume:
            if checkpoints[1]:
                for d_idx, discriminator in enumerate(discriminators):
                    discriminator.load_state_dict(checkpoints[1]['discriminators'][d_idx])

    if args.sm_loss:
        srcs_centroids = [Centroids(in_features, num_classes) for _ in range(num_source_domains)]
        trgs_centroids = [Centroids(in_features, num_classes) for _ in range(num_target_domains)]

        if args.resume:
            if checkpoints[2]:
                for src_idx, src_centroids in enumerate(srcs_centroids):
                    src_centroids.load_state_dict(checkpoints[2]['srcs_centroids'][src_idx])
                for trg_idx, trg_centroids in enumerate(trgs_centroids):
                    trg_centroids.load_state_dict(checkpoints[2]['trgs_centroids'][trg_idx])

        srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
        trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    ###################################################################################################################
    #                                              Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()
    # mse_loss = nn.MSELoss()
    lr_scheduler = LRScheduler(args.learning_rate, args.warmup_learning_rate, args.warmup_step,
                               num_steps=args.max_step, alpha=10, beta=0.75,
                               double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(params, momentum=0.9, nesterov=True)
    else:
        optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))
    if args.resume:
        if checkpoints[1]:
            optimizer.load_state_dict(checkpoints[1]['optimizer'])

    if args.adv_loss:
        if args.optimizer.lower() == 'sgd':
            optimizer_D = optim.SGD(D_params, momentum=0.9, nesterov=True)
        else:
            optimizer_D = optim.Adam(D_params, betas=(args.beta1, args.beta2))
        if args.resume:
            if checkpoints[1]:
                optimizer_D.load_state_dict(checkpoints[1]['optimizer_D'])

    # Train Starts
    logger.info('Train Starts')
    domain_loss_adjust_factor = args.domain_loss_adjust_factor

    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i_iter in range(start_iter, args.early_stop_step + 1):
        if args.iters_active_irm > 0:
            if i_iter > args.iters_active_irm:
                args.weight_source_irm = weight_irm_backup[0]
                args.weight_target_irm = weight_irm_backup[1]

        src_inputs = []
        for src_dataloader_idx in range(len(source_train_dataloader_iters)):
            try:
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))
            except StopIteration:
                source_train_dataloader_iters[src_dataloader_idx] = enumerate(
                    source_train_dataloaders[src_dataloader_idx])
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))

        trg_inputs = []
        for trg_dataloader_idx in range(len(target_train_dataloader_iters)):
            try:
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))
            except StopIteration:
                target_train_dataloader_iters[trg_dataloader_idx] = enumerate(
                    target_train_dataloaders[trg_dataloader_idx])
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))

        current_lr = lr_scheduler.current_lr(i_iter)
        adaptation_lambda = adaptation_factor((i_iter - args.warmup_step) / float(args.max_step),
                                              gamma=args.adaptation_gamma)

        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)
        if args.adv_loss:
            optimizer_D.zero_grad()
            lr_scheduler(optimizer_D, i_iter)
        ########################################################################################################
        #                                               Train G                                                #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False

        # ship to cuda
        src_inputs = [(x_s.cuda(args.gpu), y_s.cuda(args.gpu)) for (x_s, y_s) in src_inputs]
        trg_inputs = [(x_t.cuda(args.gpu), None) for (x_t, _) in trg_inputs]

        if args.dsbn:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, src_idx * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                src_preds.append((pred_s, f_s))
            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                trg_preds.append((pred_t, f_t))
        else:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, with_ft=True)
                src_preds.append((pred_s, f_s))
            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, with_ft=True)
                trg_preds.append((pred_t, f_t))

        Closs_src = 0
        Closs_src_irm = 0
        for (_, y_s), (pred_s, f_s) in zip(src_inputs, src_preds):
            Closs_src = Closs_src + ce_loss(pred_s, y_s) / float(num_source_domains)
            if args.weight_source_irm > 0:
                Closs_src_irm += feature_penalty(f_s, model.fc, ce_loss, y_s)
        monitor.update({"Loss/Closs_src": float(Closs_src)})
        Floss = Closs_src
        if args.weight_source_irm > 0:
            Floss += Closs_src_irm * args.weight_source_irm
            monitor.update({"Loss/Closs_src_irm": float(Closs_src_irm)})

        if args.adv_loss:
            # adversarial loss
            Gloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s)
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2
                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t)
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Gloss = Gloss - (loss_adv_src + loss_adv_trg)
            Gloss = Gloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Gloss': float(Gloss)})

            Floss = Floss + adaptation_lambda * Gloss

        # pseudo label generation
        pred_t_pseudos = []
        if args.dsbn:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                          with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)
        else:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)

        if args.weight_target_irm > 0:
            Closs_trg_irm = 0
            # Closs_trg = 0
            for pred_t_pseudo, (pred_t, f_t) in zip(pred_t_pseudos, trg_preds):
                y_t_pseudo = torch.argmax(pred_t_pseudo, 1).detach()
                # Closs_trg = Closs_trg + ce_loss(pred_t, y_t_pseudo)
                Closs_trg_irm += feature_penalty(f_t, model.fc, ce_loss, y_t_pseudo)
            # Floss += Closs_trg
            Floss += Closs_trg_irm * args.weight_target_irm
            monitor.update({"Loss/Closs_trg_irm": float(Closs_trg_irm)})
            # monitor.update({"Loss/Closs_trg": float(Closs_trg)})

        # moving semantic loss
        if args.sm_loss:
            current_srcs_centroids = [src_centroids(f_s, y_s) for src_centroids, (x_s, y_s), (_, f_s)
                                      in zip(srcs_centroids, src_inputs, src_preds)]
            current_trgs_centroids = [trg_centroids(f_t, torch.argmax(pred_t_pseudo, 1).detach())
                                      for trg_centroids, pred_t_pseudo, (_, f_t)
                                      in zip(trgs_centroids, pred_t_pseudos, trg_preds)]

            semantic_loss = 0
            for current_trg_centroids in current_trgs_centroids:
                for current_src_centroids in current_srcs_centroids:
                    semantic_loss = semantic_loss + args.sm_etha * semantic_loss_calc(current_src_centroids,
                                                                                      current_trg_centroids)
            semantic_loss = semantic_loss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/SMloss': float(semantic_loss)})

            Floss = Floss + adaptation_lambda * semantic_loss

        # Floss backward
        Floss.backward()
        optimizer.step()

        ########################################################################################################
        #                                               Train D                                                #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = True

        if args.adv_loss:
            # adversarial loss
            Dloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s.detach())
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2
                    # target
                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t.detach())
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Dloss = Dloss + loss_adv_src + loss_adv_trg
            Dloss = Dloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Dloss': float(Dloss)})

            Dloss = adaptation_lambda * Dloss
            Dloss.backward()
            optimizer_D.step()

        if args.sm_loss:
            for src_centroids, current_src_centroids in zip(srcs_centroids, current_srcs_centroids):
                src_centroids.centroids.data = current_src_centroids.data
            for trg_centroids, current_trg_centroids in zip(trgs_centroids, current_trgs_centroids):
                trg_centroids.centroids.data = current_trg_centroids.data

        if i_iter % args.disp_interval == 0 and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, args.early_stop_step)
            disp_msg += str(monitor)
            if args.adv_loss or args.sm_loss:
                disp_msg += ', lambda={:.6f}'.format(adaptation_lambda)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            logger.info(disp_msg)

            if args.use_tfboard:
                if args.save_model_hist:
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.cpu().data.numpy(), i_iter, bins='auto')
                for k, v in monitor.losses.items():
                    writer.add_scalar(k, v, i_iter)
                if args.adv_loss or args.sm_loss:
                    writer.add_scalar('adaptation_lambda', adaptation_lambda, i_iter)
                writer.add_scalar('learning rate', current_lr, i_iter)

            monitor.reset()

        if i_iter % args.save_interval == 0 and i_iter != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i_iter))
            target_val_dataloader_iters = [enumerate(target_val_dataloader)
                                           for target_val_dataloader in target_val_dataloaders]
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            model.eval()  # evaluation mode
            for trg_idx, target_val_dataloader_iter in enumerate(target_val_dataloader_iters, num_source_domains):
                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None
                pred_val = None
                with torch.no_grad():
                    for i, (x_val, y_val) in target_val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)
                        if args.dsbn:
                            pred_val = model(x_val, trg_idx * torch.ones_like(y_val), with_ft=False)
                        else:
                            pred_val = model(x_val, with_ft=False)
                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
                val_accuracy_each_c = [(c_name,
                                        float(eval_utils.accuracy_of_c(pred_vals, y_vals, class_idx=c, topk=(1,))[0]))
                                       for c, c_name
                                       in enumerate(target_val_datasets[trg_idx - num_source_domains].classes)]
                logger.info('\n{} Accuracy of Each class\n'.format(args.target_datasets[trg_idx - num_source_domains]) +
                            ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                     for c_name, c_val_acc in val_accuracy_each_c]))
                mean_val_accuracy = float(
                    torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
                logger.info('{} mean Accuracy: {:.2f}%'.format(
                    args.target_datasets[trg_idx - num_source_domains], 100 * mean_val_accuracy))
                logger.info('{} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx - num_source_domains],
                                                          total_val_accuracy * 100))

                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                if args.use_tfboard:
                    writer.add_scalar('Val_acc', total_val_accuracy, i_iter)
                    for c_name, c_val_acc in val_accuracy_each_c:
                        writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc)
            model.train(True)  # train mode

            if args.exp_setting.lower() == 'visda':
                val_accuracy = float(torch.mean(torch.FloatTensor(mean_val_accuracies)))
            else:
                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            # for memory
            del x_val, y_val, pred_val, pred_vals, y_vals
            for target_val_dataloader_iter in target_val_dataloader_iters:
                del target_val_dataloader_iter
            del target_val_dataloader_iters

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators],
                                           'source_datasets': args.source_datasets,
                                           'target_datasets': args.target_datasets})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}

                # save best checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict,
                                          centroids_dict, logger, best=True)

                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(best_accuracy)]
                for d_idx in range(num_target_domains):
                    best_accuracy_each_c = best_accuracies_each_c[d_idx]
                    best_mean_val_accuracy = best_mean_val_accuracies[d_idx]
                    best_total_val_accuracy = best_total_val_accuracies[d_idx]
                    contents.extend(["{}2{}\n".format(args.source_dataset, args.target_datasets[d_idx]),
                                     "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                     "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                     'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                              for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

            # logging best model results
            for trg_idx in range(num_target_domains):
                best_accuracy_each_c = best_accuracies_each_c[trg_idx]
                best_total_val_accuracy = best_total_val_accuracies[trg_idx]
                best_mean_val_accuracy = best_mean_val_accuracies[trg_idx]

                logger.info('\nBest {} Accuracy of Each class\n'.format(args.target_datasets[trg_idx]) +
                            ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                     for c_name, c_val_acc in best_accuracy_each_c]))
                logger.info('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                     for _, c_val_acc in best_accuracy_each_c]))
                logger.info('Best {} mean Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                                    100 * best_mean_val_accuracy))
                logger.info('Best {} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                               100 * best_total_val_accuracy))
            logger.info("Best model's Average Accuracy of targets: {:.2f}".format(100 * best_accuracy))

            if args.save_ckpts:
                # get options
                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators]})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}

                # save checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict,
                                          centroids_dict, logger, best=False)

                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    if args.use_tfboard:
        writer.close()

    logger.info('Total Time: {}'.format((datetime.datetime.now() - start_time)))
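
# --- Illustrative sketch (not part of the original script) -------------------------------------
# The training loop above calls `feature_penalty(features, model.fc, ce_loss, labels)` for the
# IRM terms, but that helper is imported from elsewhere in the repo. The stand-in below is a
# minimal sketch of an IRMv1-style penalty matching the call signature used above: the squared
# gradient of the classification risk with respect to a dummy scale on the logits. The name
# `feature_penalty_sketch` is hypothetical and only documents the assumed behavior.
import torch


def feature_penalty_sketch(features, classifier, criterion, labels):
    """IRMv1-style penalty: squared gradient norm of the risk w.r.t. a dummy logit scale."""
    scale = torch.ones(1, device=features.device, requires_grad=True)
    logits = classifier(features) * scale
    risk = criterion(logits, labels)
    grad = torch.autograd.grad(risk, [scale], create_graph=True)[0]
    return (grad ** 2).sum()
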

def main():
    print('start finetune')
    args = parse_args()
    args.dsbn = 'dsbn' in args.model_name  # set dsbn
    args.cpua = 'cpua' in args.model_name

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # check whether teacher model exists
    if not os.path.isfile(args.teacher_model_path):
        raise AttributeError('Missing teacher model path: {}'.format(args.teacher_model_path))

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting in ['office-home']:
            logger.warning('num_classes not set; defaulting to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['digits']:
            logger.warning('num_classes not set; defaulting to 10.')
            args.num_classes = 10
        else:
            raise AttributeError('Cannot infer num_classes for exp_setting: {}'.format(args.exp_setting))

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
        torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    # num_domains = len(args.source_datasets) + len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)
        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                                  Data Loading                                                  #
    ###################################################################################################################
    # train_dataset = MNIST('/data/jihun/MNIST', train=True, transform=mnist_transform, download=True)
    # val_dataset = MNIST('/data/jihun/MNIST', train=False, transform=mnist_transform, download=True)
    train_dataset = SVHN(root='/data/jihun/SVHN', transform=svhn_transform, download=True)
    val_dataset = SVHN(root='/data/jihun/SVHN', split='test', transform=svhn_transform, download=True)

    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers, drop_last=True, pin_memory=True)
    train_dataloader_iters = enumerate(train_dataloader)
    val_dataloader_iter = enumerate(val_dataloader)

    ###################################################################################################################
    #                                                  Model Loading                                                  #
    ###################################################################################################################
    model = get_model(args.model_name, 10, 0, 2, pretrained=True)  # num_classes=10, in_features=0, num_domains=2
    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    # teacher model
    # print(teacher_model)
    print('------------------------model load------------------------')
    model.load_state_dict(torch.load(args.teacher_model_path)['model'])

    # freeze everything except the classifier layers and the second BN branch ('bns.1');
    # iterate named_parameters() so requires_grad actually sticks to the model parameters
    for name, p in model.named_parameters():
        if ('fc' in name) or 'bns.1' in name:
            continue
        else:
            p.requires_grad = False
    torch.nn.init.xavier_uniform_(model.fc1.weight)
    torch.nn.init.xavier_uniform_(model.fc2.weight)

    model.train(True)
    model = model.cuda(args.gpu)

    ###################################################################################################################
    #                                              Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    # Train Starts
    logger.info('Train Starts')
    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(start_iter, args.early_stop_step + 1):
        try:
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        except StopIteration:
            train_dataloader_iters = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iters.__next__()

        # init optimizer
        optimizer.zero_grad()

        # ship to cuda
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        pred_s, f_s = model(x_s, 1 * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu), with_ft=True)
        loss = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(loss)})

        loss.backward()
        optimizer.step()

        if i % args.save_interval == 0 and i != 0:
            # print('------------------------%d val start------------------------' % (i))
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i))
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            model.eval()
            val_dataloader_iter = enumerate(val_dataloader)

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            # print('------------------------dataload------------------------')
            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)
                    pred_val = model(x_val, 1 * torch.ones_like(y_val), with_ft=False)
                    pred_vals.append(pred_val.cpu())

            # print('------------------------acc compute------------------------')
            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name,
                                    float(eval_utils.accuracy_of_c(pred_vals, y_vals, class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]
            logger.info('\n{} Accuracy of Each class\n'.format(args.finetune_dataset) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            # print('------------------------mean acc------------------------')
            logger.info('{} mean Accuracy: {:.2f}%'.format(args.finetune_dataset, 100 * mean_val_accuracy))
            logger.info('{} Accuracy: {:.2f}%'.format(args.finetune_dataset, total_val_accuracy * 100))

            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            # print('------------------------tf board------------------------')
            if args.use_tfboard:
                writer.add_scalar('Val_acc', total_val_accuracy, i)
                for c_name, c_val_acc in val_accuracy_each_c:
                    writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc)
            model.train(True)  # train mode

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter

            print("%d th iter accuracy: %.3f" % (i, val_accuracy))
            # print('------------------------save model------------------------')
            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies

                options = io_utils.get_model_options_from_args(args, i)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint (no centroids in this script, so pass an empty dict)
                io_utils.save_checkpoints(args.save_dir, options, i, model_dict, optimizer_dict, {}, logger,
                                          best=True)

                # ship to cuda
                model = model.cuda(args.gpu)

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(best_accuracy)]
                best_accuracy_each_c = best_accuracies_each_c[0]
                best_mean_val_accuracy = best_mean_val_accuracies[0]
                best_total_val_accuracy = best_total_val_accuracies[0]
                contents.extend(["{}\n".format(args.finetune_dataset),
                                 "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                 "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                 'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                          for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

    # final evaluation with the other BN branch (domain index 0)
    val_dataloader_iter = enumerate(val_dataloader)
    pred_vals = []
    y_vals = []
    x_val = None
    y_val = None
    # print('------------------------dataload------------------------')
    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)
            pred_val = model(x_val, 0 * torch.ones_like(y_val), with_ft=False)
            pred_vals.append(pred_val.cpu())

    pred_vals = torch.cat(pred_vals, 0)
    y_vals = torch.cat(y_vals, 0)
    total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
    val_accuracy_each_c = [(c_name,
                            float(eval_utils.accuracy_of_c(pred_vals, y_vals, class_idx=c, topk=(1,))[0]))
                           for c, c_name in enumerate(val_dataset.classes)]
    for cls in val_accuracy_each_c:
        print(cls)
    print(total_val_accuracy)
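
# --- Illustrative sketch (not part of the original scripts) ------------------------------------
# Both scripts forward the model as `model(x, domain_idx, with_ft=...)` whenever 'dsbn' appears
# in the model name, with source domains numbered first and targets after them (the finetune
# script trains with index 1 and checks index 0 at the end). Domain-Specific Batch Normalization
# keeps one BN branch per domain and selects it by that index. The module below is a minimal,
# hypothetical illustration of the idea, not the repo's actual `get_model` internals.
import torch
import torch.nn as nn


class DomainSpecificBatchNorm2dSketch(nn.Module):
    """One BatchNorm2d per domain; the forward pass picks the branch by domain index."""

    def __init__(self, num_features, num_domains):
        super().__init__()
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features) for _ in range(num_domains)])

    def forward(self, x, domain_idx):
        # The scripts pass a per-batch tensor of identical domain labels, so a single
        # integer is enough to select the branch.
        idx = int(domain_idx if isinstance(domain_idx, int) else domain_idx.view(-1)[0])
        return self.bns[idx](x)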