def main():
    """Parse CLI options and run USPS <-> MNIST domain-adaptation training.

    Builds source/target/test loaders from the ``./data/usps2mnist`` list files
    selected by ``--task``. Creates ``network.USPS_EnsembNet`` plus a domain
    discriminator: ``Multi_AdversarialNetwork`` for ALDA, ``AdversarialNetwork``
    for DANN. Then, for each epoch, it trains, evaluates on the test split, and
    saves a snapshot under ``snapshot/<output_dir>`` on epochs 1, 6, 11, ...

    Raises:
        Exception: if ``--task`` is neither ``USPS2MNIST`` nor ``MNIST2USPS``.
    """
    import shutil

    # Training settings
    def str2bool(v):
        """Interpret common yes/no spellings as a bool for argparse."""
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')

    def read_lines(path):
        """Return all lines of a list file, closing the handle afterwards."""
        with open(path) as f:
            return f.readlines()

    def digit_transform():
        """Build the grayscale 28x28 preprocessing pipeline (ToTensor + Normalize(0.5, 0.5))."""
        return transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])

    def make_loader(lines, batch_size, is_train):
        """Wrap list-file lines in an ImageList DataLoader.

        Training loaders shuffle and drop the last partial batch; the test
        loader keeps every sample in file order.
        """
        return torch.utils.data.DataLoader(
            ImageList(lines, transform=digit_transform(), mode='L'),
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=args.num_worker,
            drop_last=is_train,
            pin_memory=True)

    parser = argparse.ArgumentParser(description='ALDA USPS2MNIST')
    # nargs='?' makes the positional optional so default='ALDA' is actually used;
    # without it argparse ignores the default and the argument is mandatory.
    parser.add_argument('method', type=str, nargs='?', default='ALDA',
                        choices=['DANN', "ALDA"])
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size', type=int, default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    # default=None so an explicit --lr is honoured; the task-specific default
    # is filled in below only when the flag is absent.
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: 2e-4 for USPS2MNIST, 1e-3 otherwise)')
    parser.add_argument('--gpu_id', type=str, help='cuda device id')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log_interval', type=int, default=500,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--trade_off', type=float, default=1.0, help="trade_off")
    parser.add_argument('--start_epoch', type=int, default=0,
                        help="begin adaptation after start_epoch")
    parser.add_argument('--threshold', default=0.9, type=float,
                        help="threshold of pseudo labels")
    parser.add_argument('--output_dir', type=str, default=None,
                        help="output directory of our model (in ../snapshot directory)")
    parser.add_argument('--loss_type', type=str, default='all',
                        help="whether add reg_loss or correct_loss.")
    parser.add_argument('--cos_dist', type=str2bool, default=False,
                        help="the classifier uses cosine similarity.")
    parser.add_argument('--num_worker', type=int, default=4)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    network.set_device(args.gpu_id)

    if args.task == 'USPS2MNIST':
        source_path = './data/usps2mnist/usps_train.txt'
        target_path = './data/usps2mnist/mnist_train.txt'
        test_path = './data/usps2mnist/mnist_test.txt'
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_path = './data/usps2mnist/mnist_train.txt'
        target_path = './data/usps2mnist/usps_train.txt'
        test_path = './data/usps2mnist/usps_test.txt'
        decay_epoch = 5
    else:
        raise Exception('task cannot be recognized!')

    train_loader = make_loader(read_lines(source_path), args.batch_size, True)
    train_loader1 = make_loader(read_lines(target_path), args.batch_size, True)
    test_loader = make_loader(read_lines(test_path), args.test_batch_size, False)

    model = network.USPS_EnsembNet()
    model = model.to(network.dev)
    class_num = 10
    if args.method == "ALDA":
        ad_net = network.Multi_AdversarialNetwork(model.output_num(), 500, class_num)
    elif args.method == "DANN":
        ad_net = network.AdversarialNetwork(model.output_num(), 500)
    ad_net = ad_net.to(network.dev)

    if args.lr is None:
        args.lr = 2e-4 if args.task == 'USPS2MNIST' else 1e-3
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_ad = optim.Adam(ad_net.parameters(), lr=args.lr, weight_decay=0.0005)

    start_epoch = args.start_epoch
    if args.output_dir is None:
        args.output_dir = args.task.lower() + '_' + args.method
    output_path = "snapshot/" + args.output_dir
    if os.path.exists(output_path):
        print("checkpoint dir exists, which will be removed")
        shutil.rmtree(output_path, ignore_errors=True)
    # makedirs also creates a missing "snapshot/" parent and tolerates leftovers
    # that rmtree(ignore_errors=True) could not delete.
    os.makedirs(output_path, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        # Halve the classifier's learning rate every `decay_epoch` epochs
        # (the discriminator optimizer is left unchanged, as before).
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        # NOTE(review): this 10-argument call does not match a `train(config)`
        # signature; presumably it targets a different `train` in this module -- verify.
        train(args, model, ad_net, train_loader, train_loader1, optimizer,
              optimizer_ad, epoch, start_epoch, args.method)
        test(args, model, test_loader)
        if epoch % 5 == 1:
            torch.save(model.state_dict(),
                       osp.join(output_path, "epoch_{}.pth".format(epoch)))
def _prefixed_lines(path, prefix='.'):
    """Return every line of list file *path* with *prefix* prepended (newlines kept)."""
    with open(path) as f:
        return [prefix + line for line in f]


def train(config):
    """Train a base network with an adversarial / self-training transfer loss.

    Args:
        config: dict of run settings. Keys read here include "prep", "data",
            "network", "restore_path", "out_file", "optimizer", "gpu", "loss",
            "output_path", "num_iterations", "test_interval", "method",
            "label_interval", "threshold", "stop_step", "final_log" and "args".
            config["args"] is the parsed CLI namespace. It must provide
            num_worker, method, source_detach, weight_type, loss_type and
            adv_weight. It is presumably the same object as the script-level
            `args` -- confirm against the caller.

    The best checkpoint (by test accuracy) is written to
    ``<output_path>/best_model.pth``. The last one goes to
    ``<output_path>/final_model.pth``.

    Returns:
        The best test accuracy observed.
    """
    run_args = config['args']
    num_workers = run_args.num_worker

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**prep_config['params'])
    prep_dict["target"] = prep.image_train(**prep_config['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**prep_config['params'])
    else:
        prep_dict["test"] = prep.image_test(**prep_config['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    source_list = _prefixed_lines(data_config["source"]["list_path"])
    target_list = _prefixed_lines(data_config["target"]["list_path"])
    dsets["source"] = ImageList(source_list, transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs,
                                        shuffle=True, num_workers=num_workers,
                                        drop_last=True)
    dsets["target"] = ImageList(target_list, transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=num_workers,
                                        drop_last=True)
    print("source dataset len:", len(dsets["source"]))
    print("target dataset len:", len(dsets["target"]))

    test_list = _prefixed_lines(data_config["test"]["list_path"])
    if prep_config["test_10crop"]:
        # One dataset/loader per crop transform. Each is built exactly once;
        # previously a redundant outer loop rebuilt all ten, ten times.
        dsets["test"] = [ImageList(test_list, transform=prep_dict["test"][i])
                         for i in range(10)]
        dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, shuffle=False,
                                           num_workers=num_workers)
                                for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(test_list, transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                          shuffle=False, num_workers=num_workers)

    dsets["target_label"] = ImageList_label(target_list, transform=prep_dict["target"])
    dset_loaders["target_label"] = DataLoader(dsets["target_label"], batch_size=test_bs,
                                              shuffle=False, num_workers=num_workers,
                                              drop_last=False)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.to(network.dev)
    if config["restore_path"]:
        restore_file = osp.join(config["restore_path"], "best_model.pth")
        checkpoint = torch.load(restore_file)["base_network"]
        # Drop any "module." prefix (as added by nn.DataParallel) so keys match the bare network.
        ckp = {k.split("module.")[-1]: v for k, v in checkpoint.items()}
        base_network.load_state_dict(ckp)
        log_str = "successfully restore from {}".format(restore_file)
        config["out_file"].write(log_str + "\n")
        config["out_file"].flush()
        print(log_str)

    ## add additional network for some methods
    if "ALDA" in run_args.method:
        ad_net = network.Multi_AdversarialNetwork(base_network.output_num(), 1024, class_num)
    else:
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    ad_net = ad_net.to(network.dev)
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        device_ids = list(range(len(gpus)))
        ad_net = nn.DataParallel(ad_net, device_ids=device_ids)
        base_network = nn.DataParallel(base_network, device_ids=device_ids)

    loss_params = config["loss"]
    high = loss_params["trade_off"]
    begin_label = False
    writer = SummaryWriter(config["output_path"])
    test_interval = config["test_interval"]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_step = 0  # defined even if we stop before the first evaluation
    loss_value = 0
    loss_adv_value = 0
    loss_correct_value = 0
    inputs_source = None  # no batch exists yet if we evaluate at i == 0
    for i in tqdm(range(config["num_iterations"]), total=config["num_iterations"]):
        if i % test_interval == test_interval - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, base_network,
                                                  test_10crop=prep_config["test_10crop"])
            if temp_acc > best_acc:
                best_step = i
                best_acc = temp_acc
                checkpoint = {
                    "base_network": base_network.state_dict(),
                    "ad_net": ad_net.state_dict()
                }
                torch.save(checkpoint, osp.join(config["output_path"], "best_model.pth"))
                print("\n########## save the best model. #############\n")
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            writer.add_scalar('precision', temp_acc, i)
            print(log_str)

            print("adv_loss: {:.3f} correct_loss: {:.3f} class_loss: {:.3f}".format(
                loss_adv_value, loss_correct_value, loss_value))
            loss_value = 0
            loss_adv_value = 0
            loss_correct_value = 0

            # show val result on tensorboard
            if inputs_source is not None:
                images_inv = prep.inv_preprocess(inputs_source.clone().cpu(), 3)
                for index, img in enumerate(images_inv):
                    writer.add_image(str(index) + '/Images', img, i)

        # save the pseudo_label
        if 'PseudoLabel' in config['method'] and (
                i % config["label_interval"] == config["label_interval"] - 1):
            base_network.train(False)
            pseudo_label_list = image_label(dset_loaders, base_network,
                                            threshold=config['threshold'],
                                            out_dir=config["output_path"])
            with open(pseudo_label_list) as f:
                pseudo_lines = f.readlines()
            dsets["target"] = ImageList(pseudo_lines, transform=prep_dict["target"])
            dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                                shuffle=True, num_workers=num_workers,
                                                drop_last=True)
            # The pseudo-labelled set can differ in size. Refresh the restart
            # period, otherwise next() can run past the end of the new loader.
            len_train_target = len(dset_loaders["target"])
            # replace the target dataloader with Pseudo_Label dataloader
            iter_target = iter(dset_loaders["target"])
            begin_label = True

        if i > config["stop_step"]:
            log_str = "method {}, iter: {:05d}, precision: {:.5f}".format(
                config["output_path"], best_step, best_acc)
            config["final_log"].write(log_str + "\n")
            config["final_log"].flush()
            break

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Builtin next(): DataLoader iterators in recent PyTorch have no .next() method.
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source = inputs_source.to(network.dev)
        inputs_target = inputs_target.to(network.dev)
        labels_source = labels_source.to(network.dev)

        features_source, outputs_source = base_network(inputs_source)
        if run_args.source_detach:
            features_source = features_source.detach()
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        loss_params["trade_off"] = network.calc_coeff(i, high=high)
        transfer_loss = 0.0
        # Only the ALDA branch produces these; they stay None for the other methods.
        adv_loss = correct_loss = None
        if 'DANN' in config['method']:
            transfer_loss = loss.DANN(features, ad_net)
        elif "ALDA" in config['method']:
            ad_out = ad_net(features)
            adv_loss, reg_loss, correct_loss = loss.ALDA_loss(
                ad_out, labels_source, softmax_out,
                weight_type=run_args.weight_type, threshold=config['threshold'])
            # whether add the corrected self-training loss
            if "nocorrect" in run_args.loss_type:
                transfer_loss = adv_loss
            else:
                transfer_loss = (run_args.adv_weight * adv_loss
                                 + run_args.adv_weight * loss_params["trade_off"] * correct_loss)
            # reg_loss is only backward to the discriminator: freeze the base
            # network temporarily, then restore each parameter's original flag
            # (so parameters that were frozen before stay frozen).
            if "noreg" not in run_args.loss_type:
                saved_flags = [p.requires_grad for p in base_network.parameters()]
                for param in base_network.parameters():
                    param.requires_grad = False
                reg_loss.backward(retain_graph=True)
                for param, flag in zip(base_network.parameters(), saved_flags):
                    param.requires_grad = flag
        # on-line self-training
        elif 'SelfTraining' in config['method']:
            transfer_loss += loss_params["trade_off"] * loss.SelfTraining_loss(
                outputs, softmax_out, config['threshold'])
        # off-line self-training
        elif 'PseudoLabel' in config['method']:
            labels_target = labels_target.to(network.dev)
            # Weighted by 0 until the first pseudo-label refresh has happened.
            pseudo_weight = loss_params["trade_off"] if begin_label else 0.0
            transfer_loss += pseudo_weight * nn.CrossEntropyLoss(ignore_index=-1)(
                outputs_target, labels_target)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        loss_value += classifier_loss.item() / test_interval
        if adv_loss is not None:
            loss_adv_value += adv_loss.item() / test_interval
        if correct_loss is not None:
            loss_correct_value += correct_loss.item() / test_interval
        total_loss = classifier_loss + transfer_loss
        total_loss.backward()
        optimizer.step()

    checkpoint = {
        "base_network": base_network.state_dict(),
        "ad_net": ad_net.state_dict()
    }
    torch.save(checkpoint, osp.join(config["output_path"], "final_model.pth"))
    writer.close()
    return best_acc