def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch classification example')
    parser.add_argument('--dataset', type=str, help='dataset', choices=[
        'mnist', 'usps', 'svhn', 'syn_digits', 'imagenet32x32',
        'cifar10', 'stl10',
    ])
    parser.add_argument('--arch', type=str, help='network architecture')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--val_ratio', type=float, default=0.0,
                        help='sampling ratio of validation data')
    parser.add_argument('--train_ratio', type=float, default=1.0,
                        help='sampling ratio of training data')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd', type=float, default=1e-6,
                        help='weight decay (default: 1e-6)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--output_path', type=str,
                        help='path to save ckpt and log')
    parser.add_argument('--resume', type=str,
                        help='resume training from ckpt path')
    parser.add_argument('--ckpt_file', type=str,
                        help='init model from ckpt')
    parser.add_argument('--exclude_vars', type=str,
                        help='prefix of variables not restored from ckpt, '
                             'separated with commas; valid if ckpt_file is not None')
    parser.add_argument('--imagenet_pretrain', action='store_true',
                        help='use pretrained imagenet model')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.output_path is not None and not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
    writer = SummaryWriter(args.output_path)

    use_normalize = True
    if args.dataset == 'imagenet32x32':
        n_classes = 1000
        args.batch_size = 256
    elif args.dataset in ["cifar10", "stl10"]:
        n_classes = 9
    elif args.dataset in ["usps", "mnist", "svhn", "syn_digits"]:
        n_classes = 10
    else:
        raise ValueError('invalid dataset option: {}'.format(args.dataset))

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    assert args.val_ratio >= 0. and args.val_ratio < 1.
    assert args.train_ratio > 0. and args.train_ratio <= 1.
    train_ds = get_dataset(args.dataset, 'train', use_normalize=use_normalize,
                           test_size=args.val_ratio, train_size=args.train_ratio)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(get_dataset(args.dataset, 'test',
                                                          use_normalize=use_normalize,
                                                          test_size=args.val_ratio,
                                                          train_size=args.train_ratio),
                                              batch_size=args.test_batch_size,
                                              shuffle=True, **kwargs)
    if args.val_ratio == 0.0:
        val_loader = test_loader
    else:
        val_ds = get_dataset(args.dataset, 'val', use_normalize=use_normalize,
                             test_size=args.val_ratio, train_size=args.train_ratio)
        val_loader = torch.utils.data.DataLoader(val_ds,
                                                 batch_size=args.batch_size,
                                                 shuffle=True, **kwargs)

    if args.arch == "DTN":
        model = network.DTN().to(device)
    elif args.arch == 'wrn':
        model = network.WideResNet(depth=28, num_classes=n_classes,
                                   widen_factor=10, dropRate=0.0).to(device)
    else:
        raise ValueError('invalid network architecture {}'.format(args.arch))

    if args.ckpt_file is not None:
        print('initialize model parameters from {}'.format(args.ckpt_file))
        model.restore_from_ckpt(torch.load(args.ckpt_file, map_location='cpu'),
                                exclude_vars=args.exclude_vars.split(',')
                                if args.exclude_vars is not None else [])
        print('accuracy on test set before fine-tuning')
        test(args, model, device, test_loader)

    if args.resume is not None:
        assert os.path.isfile(args.resume)
        print('resume training from {}'.format(args.resume))
        model.load_state_dict(torch.load(args.resume))

    if use_cuda:
        # model = torch.nn.DataParallel(model)
        cudnn.benchmark = True

    if args.dataset.startswith("cifar") or args.dataset in ['stl10']:
        lr_decay_step = 100
        lr_decay_rate = 0.1
        PATIENCE = 100
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=args.momentum, weight_decay=args.wd)
        scheduler = MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)
    elif args.dataset in ["mnist", "usps", "svhn", "syn_digits"]:
        lr_decay_step = 50
        lr_decay_rate = 0.5
        if args.dataset == 'svhn':
            PATIENCE = 10
        else:
            PATIENCE = 50
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=0.5, weight_decay=args.wd)
        scheduler = StepLR(optimizer, step_size=lr_decay_step, gamma=lr_decay_rate)
    elif args.dataset == 'imagenet32x32':
        PATIENCE = 10
        lr_decay_step = 10
        lr_decay_rate = 0.2
        optimizer = torch.optim.SGD(model.get_parameters(args.lr),
                                    momentum=0.9, weight_decay=5e-4, nesterov=True)
        scheduler = StepLR(optimizer, step_size=lr_decay_step, gamma=lr_decay_rate)
    else:
        raise ValueError("invalid dataset option: {}".format(args.dataset))

    early_stop_engine = EarlyStopping(PATIENCE)

    print("args:{}".format(args))

    # start training.
    best_accuracy = 0.
    save_path = os.path.join(args.output_path, "model.pt")
    time_stats = []
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, model, device, train_loader, optimizer, epoch, writer)
        training_time = time.time() - start_time
        print('epoch: {} training time: {:.2f}'.format(epoch, training_time))
        time_stats.append(training_time)

        val_accuracy = test(args, model, device, val_loader)
        scheduler.step()
        writer.add_scalar("val_accuracy", val_accuracy, epoch)
        if val_accuracy >= best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), save_path)
        if epoch % 20 == 0:
            print('accuracy on test set at epoch {}'.format(epoch))
            test(args, model, device, test_loader)
        if early_stop_engine.is_stop_training(val_accuracy):
            print("no improvement after {}, stop training at epoch {}\n".format(
                PATIENCE, epoch))
            break

    # print('finish training {} epochs'.format(args.epochs))
    mean_training_time = np.mean(np.array(time_stats))
    print('Average training_time: {}'.format(mean_training_time))

    print('load ckpt with best validation accuracy from {}'.format(save_path))
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    test_accuracy = test(args, model, device, test_loader)
    writer.add_scalar("test_accuracy", test_accuracy, args.epochs)

    with open(os.path.join(args.output_path, 'accuracy.pkl'), 'wb') as pkl_file:
        pkl.dump(
            {
                'train': best_accuracy,
                'test': test_accuracy,
                'training_time': mean_training_time
            }, pkl_file)
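
# Example invocation of main() above (illustrative only: the script name and the
# output path are placeholders, not taken from this repo; the flags are the ones
# defined in the argparse setup):
#
#   python train.py --dataset mnist --arch DTN --epochs 50 --lr 0.01 \
#       --batch_size 128 --val_ratio 0.1 --output_path ./output/mnist_dtn
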
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(resize_size=prep_config["resize_size"],
                                           crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train(resize_size=prep_config["resize_size"],
                                           crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(resize_size=prep_config["resize_size"],
                                                   crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test(resize_size=prep_config["resize_size"],
                                            crop_size=prep_config["crop_size"])

    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]
    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](num_classes=class_num,
                                                feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(),
                                                  prep_config["source_size"]),
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"],
                                                  batch_size=data_config["source"]["batch_size"],
                                                  shuffle=True, num_workers=1)
    dsets["target"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(),
                                                  prep_config["target_size"]),
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"],
                                                  batch_size=data_config["target"]["batch_size"],
                                                  shuffle=True, num_workers=1)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test" + str(i)] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(),
                                                                 ratio=prep_config['target_size']),
                                               transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["test" + str(i)] = util_data.DataLoader(dsets["test" + str(i)],
                                                                 batch_size=data_config["test"]["batch_size"],
                                                                 shuffle=False, num_workers=1)
            dsets["target" + str(i)] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(),
                                                                   ratio=prep_config['target_size']),
                                                 transform=prep_dict["test"]["val" + str(i)])
            dset_loaders["target" + str(i)] = util_data.DataLoader(dsets["target" + str(i)],
                                                                   batch_size=data_config["test"]["batch_size"],
                                                                   shuffle=False, num_workers=1)
    else:
        dsets["test"] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(),
                                                    ratio=prep_config['target_size']),
                                  transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"],
                                                    batch_size=data_config["test"]["batch_size"],
                                                    shuffle=False, num_workers=1)
        dsets["target_test"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(),
                                                           ratio=prep_config['target_size']),
                                         transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"],
                                                   batch_size=data_config["test"]["batch_size"],
                                                   shuffle=False, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params": base_network.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2},
                              {"params": base_network.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2},
                              {"params": base_network.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
        else:
            parameter_list = [{"params": base_network.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2},
                              {"params": base_network.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
    else:
        parameter_list = [{"params": base_network.parameters(), "lr_mult": 1, 'decay_mult': 2}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(high_value=config["high"])
    # max_iter_value=config["num_iterations"]
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params": ad_net.parameters(), "lr_mult": 10, 'decay_mult': 2})
    parameter_list.append({"params": center_criterion.parameters(), "lr_mult": 10, 'decay_mult': 1})

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list,
                                                     **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc = image_classification_test(dset_loaders, base_network,
                                                     test_10crop=prep_config["test_10crop"],
                                                     gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, base_network,
                                                           center_criterion.centers.detach(),
                                                           test_10crop=prep_config["test_10crop"],
                                                           gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))

            snapshot_obj = {'step': i,
                            "base_network": base_network.state_dict(),
                            'precision': temp_acc,
                            }
            if config["loss"]["loss_name"] != "laplacian" and config["loss"]["ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj,
                           osp.join(config["output_path"], "best_model.pth.tar"))

            log_str = "iter: {:05d}, {} precision: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("precision", temp_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                # config["out_file"].write("finish training! \n")
                break
        if (i + 1) % config["snapshot_interval"] == 0:
            torch.save(snapshot_obj,
                       osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # restart the source/target iterators when they are exhausted
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source), Variable(inputs_target), Variable(labels_source)

        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer,
                                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(
            features.narrow(0, 0, int(inputs.size(0) / 2)), labels_source,
            inter_class=config["loss"]["inter_type"],
            intra_loss_weight=loss_params["intra_loss_coef"],
            inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        # final loss
        total_loss = loss_params["trade_off"] * transfer_loss \
            + fisher_loss \
            + loss_params["em_loss_coef"] * em_loss \
            + classifier_loss
        total_loss.backward()
        if center_grad is not None:
            # clear mmc_loss
            center_criterion.centers.grad.zero_()
            # Manually assign centers gradients other than using autograd
            center_criterion.centers.backward(center_grad)
        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write(
                'iter {}: total loss={:0.4f}, transfer loss={:0.4f}, cls loss={:0.4f}, '
                'em loss={:0.4f}, '
                'mmc loss={:0.4f}, intra loss={:0.4f}, inter loss={:0.4f}, '
                'ad acc={:0.4f}, source_acc={:0.4f}, target_acc={:0.4f}\n'.format(
                    i, total_loss.data.cpu().float().item(),
                    transfer_loss.data.cpu().float().item(),
                    classifier_loss.data.cpu().float().item(),
                    em_loss.data.cpu().float().item(),
                    fisher_loss.cpu().float().item(),
                    fisher_intra_loss.cpu().float().item(),
                    fisher_inter_loss.cpu().float().item(),
                    ad_acc, source_acc_ad, target_acc_ad,
                ))
            config['out_file'].flush()
            writer.add_scalar("total_loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("cls_loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("transfer_loss", transfer_loss.data.cpu().float().item(), i)
            writer.add_scalar("ad_acc", ad_acc, i)
            writer.add_scalar("d_loss/total", fisher_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/intra", fisher_intra_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/inter", fisher_inter_loss.data.cpu().float().item(), i)
    return best_acc
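
# Minimal sketch of the EarlyStopping helper used in both training loops above,
# inferred only from its call sites (EarlyStopping(patience) and
# is_stop_training(metric) -> bool); the actual implementation in this repo may
# track improvement differently.
class EarlyStopping:
    def __init__(self, patience):
        self.patience = patience      # consecutive non-improving checks tolerated
        self.best_metric = None       # best validation metric observed so far
        self.num_bad_checks = 0       # consecutive checks without improvement

    def is_stop_training(self, metric):
        # Update the running best and report whether training should stop.
        if self.best_metric is None or metric > self.best_metric:
            self.best_metric = metric
            self.num_bad_checks = 0
        else:
            self.num_bad_checks += 1
        return self.num_bad_checks >= self.patience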