Example #1

# Imports assumed by this example. `network` and `data_list` are
# repository-local modules; the train/test helpers called at the bottom are
# defined elsewhere in the same script.
import argparse
import os
import os.path as osp
import shutil

import torch
import torch.optim as optim
from torchvision import transforms

import network
from data_list import ImageList

def main():
    # Training settings
    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Unsupported value encountered.')

    parser = argparse.ArgumentParser(description='ALDA USPS2MNIST')
    # 'method' is positional; nargs='?' lets the declared default take effect.
    parser.add_argument('method',
                        type=str,
                        nargs='?',
                        default='ALDA',
                        choices=['DANN', 'ALDA'])
    parser.add_argument('--task', default='MNIST2USPS', help='task to perform')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        metavar='LR',
                        help='learning rate (default: 2e-4)')
    parser.add_argument('--gpu_id', type=str, help='cuda device id')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=500,
        help='how many batches to wait before logging training status')
    parser.add_argument('--trade_off',
                        type=float,
                        default=1.0,
                        help="trade-off weight for the transfer loss")
    parser.add_argument('--start_epoch',
                        type=int,
                        default=0,
                        help="begin adaptation after start_epoch")
    parser.add_argument('--threshold',
                        default=0.9,
                        type=float,
                        help="threshold of pseudo labels")
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help="output directory of our model (in ../snapshot directory)")
    parser.add_argument('--loss_type',
                        type=str,
                        default='all',
                        help="whether to add reg_loss or correct_loss")
    parser.add_argument('--cos_dist',
                        type=str2bool,
                        default=False,
                        help="whether the classifier uses cosine similarity")
    parser.add_argument('--num_worker',
                        type=int,
                        default=4,
                        help='number of data-loading workers')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    network.set_device(args.gpu_id)

    # start_epoch is taken from args below; only the file lists and
    # decay_epoch differ per task.
    if args.task == 'USPS2MNIST':
        source_list = './data/usps2mnist/usps_train.txt'
        target_list = './data/usps2mnist/mnist_train.txt'
        test_list = './data/usps2mnist/mnist_test.txt'
        decay_epoch = 6
    elif args.task == 'MNIST2USPS':
        source_list = './data/usps2mnist/mnist_train.txt'
        target_list = './data/usps2mnist/usps_train.txt'
        test_list = './data/usps2mnist/usps_test.txt'
        decay_epoch = 5
    else:
        raise ValueError('unrecognized task: {}'.format(args.task))

    with open(source_list) as f:
        source_list = f.readlines()
    with open(target_list) as f:
        target_list = f.readlines()
    with open(test_list) as f:
        test_list = f.readlines()

    # All three splits share the same grayscale 28x28 preprocessing.
    mnist_transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(source_list, transform=mnist_transform, mode='L'),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_worker,
        drop_last=True,
        pin_memory=True)
    train_loader1 = torch.utils.data.DataLoader(
        ImageList(target_list, transform=mnist_transform, mode='L'),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_worker,
        drop_last=True,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        ImageList(test_list, transform=mnist_transform, mode='L'),
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_worker,
        pin_memory=True)

    model = network.USPS_EnsembNet()
    model = model.to(network.dev)
    class_num = 10

    if args.method == "ALDA":
        ad_net = network.Multi_AdversarialNetwork(model.output_num(), 500,
                                                  class_num)
    elif args.method == "DANN":
        ad_net = network.AdversarialNetwork(model.output_num(), 500)
    ad_net = ad_net.to(network.dev)
    # The learning rate is fixed per task, overriding the --lr flag.
    if args.task == 'USPS2MNIST':
        args.lr = 2e-4
    else:
        args.lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_ad = optim.Adam(ad_net.parameters(),
                              lr=args.lr,
                              weight_decay=0.0005)

    start_epoch = args.start_epoch
    if args.output_dir is None:
        args.output_dir = args.task.lower() + '_' + args.method
    output_path = "snapshot/" + args.output_dir
    if os.path.exists(output_path):
        print("checkpoint dir exists, which will be removed")
        shutil.rmtree(output_path, ignore_errors=True)
    os.makedirs(output_path)

    for epoch in range(1, args.epochs + 1):
        # Halve the learning rate every decay_epoch epochs.
        if epoch % decay_epoch == 0:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] * 0.5
        train(args, model, ad_net, train_loader, train_loader1, optimizer,
              optimizer_ad, epoch, start_epoch, args.method)
        test(args, model, test_loader)
        # Snapshot the model every 5 epochs (at epochs 1, 6, 11, ...).
        if epoch % 5 == 1:
            torch.save(model.state_dict(),
                       osp.join(output_path, "epoch_{}.pth".format(epoch)))
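
Example #1 calls `train` and `test` helpers defined elsewhere in the same script. Below is a minimal sketch of what the evaluation routine might look like, assuming the model's forward pass returns class logits and that `network.dev` holds the active device (both names taken from the code above); the loop body is illustrative, not the repository's actual helper.

def test(args, model, test_loader):
    # Hypothetical evaluation loop (sketch; the real helper lives elsewhere).
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(network.dev)
            target = target.to(network.dev)
            output = model(data)  # assumed: forward returns class logits
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
    acc = 100.0 * correct / len(test_loader.dataset)
    print('Test accuracy: {}/{} ({:.2f}%)'.format(
        correct, len(test_loader.dataset), acc))

Assuming the script is saved as, say, train_uspsmnist.py (the file name is a guess), it would be invoked as `python train_uspsmnist.py ALDA --task USPS2MNIST`.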
Example #2

# Imports assumed by this example. pre_process, network, loss, lr_schedule,
# and data_list are repository-local modules; image_classification_test and
# image_label are helpers defined elsewhere in the same script.
import os.path as osp

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter  # or: torch.utils.tensorboard

import pre_process as prep
import network
import loss
import lr_schedule
from data_list import ImageList, ImageList_label

def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]

    source_list = [
        '.' + line
        for line in open(data_config["source"]["list_path"]).readlines()
    ]
    target_list = [
        '.' + line
        for line in open(data_config["target"]["list_path"]).readlines()
    ]

    dsets["source"] = ImageList(source_list, \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
            shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
    dsets["target"] = ImageList(target_list, \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
            shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
    print("source dataset len:", len(dsets["source"]))
    print("target dataset len:", len(dsets["target"]))

    if prep_config["test_10crop"]:
        for i in range(10):
            test_list = [
                '.' + i
                for i in open(data_config["test"]["list_path"]).readlines()
            ]
            dsets["test"] = [ImageList(test_list, \
                                transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                shuffle=False, num_workers=config['args'].num_worker) for dset in dsets['test']]
    else:
        test_list = [
            '.' + line
            for line in open(data_config["test"]["list_path"]).readlines()
        ]
        dsets["test"] = ImageList(test_list, \
                                transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                shuffle=False, num_workers=config['args'].num_worker)

    dsets["target_label"] = ImageList_label(target_list, \
                            transform=prep_dict["target"])
    dset_loaders["target_label"] = DataLoader(dsets["target_label"], batch_size=test_bs, \
            shuffle=False, num_workers=config['args'].num_worker, drop_last=False)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.to(network.dev)
    if config["restore_path"]:
        checkpoint = torch.load(
            osp.join(config["restore_path"], "best_model.pth"))["base_network"]
        ckp = {}
        for k, v in checkpoint.items():
            # strip any "module." prefix left over from nn.DataParallel
            if "module" in k:
                ckp[k.split("module.")[-1]] = v
            else:
                ckp[k] = v
        base_network.load_state_dict(ckp)
        log_str = "successfully restore from {}".format(
            osp.join(config["restore_path"], "best_model.pth"))
        config["out_file"].write(log_str + "\n")
        config["out_file"].flush()
        print(log_str)

    ## add additional network for some methods
    if "ALDA" in args.method:
        ad_net = network.Multi_AdversarialNetwork(base_network.output_num(),
                                                  1024, class_num)
    else:
        ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    ad_net = ad_net.to(network.dev)
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        device_ids = list(range(len(gpus)))
        ad_net = nn.DataParallel(ad_net, device_ids=device_ids)
        base_network = nn.DataParallel(base_network, device_ids=device_ids)

    loss_params = config["loss"]
    high = loss_params["trade_off"]
    begin_label = False
    writer = SummaryWriter(config["output_path"])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    loss_value = 0
    loss_adv_value = 0
    loss_correct_value = 0
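    # Main loop: evaluate every test_interval iterations, optionally refresh
    # pseudo labels every label_interval iterations, then train on one source
    # batch and one target batch; training stops after stop_step iterations.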
    for i in tqdm(range(config["num_iterations"]),
                  total=config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                base_network, test_10crop=prep_config["test_10crop"])
            temp_model = base_network
            if temp_acc > best_acc:
                best_step = i
                best_acc = temp_acc
                best_model = temp_model
                checkpoint = {
                    "base_network": best_model.state_dict(),
                    "ad_net": ad_net.state_dict()
                }
                torch.save(checkpoint,
                           osp.join(config["output_path"], "best_model.pth"))
                print(
                    "\n##########     save the best model.    #############\n")
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            writer.add_scalar('precision', temp_acc, i)
            print(log_str)

            print("adv_loss: {:.3f} correct_loss: {:.3f} class_loss: {:.3f}".
                  format(loss_adv_value, loss_correct_value, loss_value))
            loss_value = 0
            loss_adv_value = 0
            loss_correct_value = 0

            # show the (inverse-normalized) source images from the previous
            # training iteration on TensorBoard
            images_inv = prep.inv_preprocess(inputs_source.clone().cpu(), 3)
            for index, img in enumerate(images_inv):
                writer.add_image(str(index) + '/Images', img, i)

        # save the pseudo_label
        if 'PseudoLabel' in config['method'] and (
                i % config["label_interval"] == config["label_interval"] - 1):
            base_network.train(False)
            pseudo_label_list = image_label(dset_loaders, base_network, threshold=config['threshold'], \
                                out_dir=config["output_path"])
            dsets["target"] = ImageList(open(pseudo_label_list).readlines(), \
                                transform=prep_dict["target"])
            dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                    shuffle=True, num_workers=config['args'].num_worker, drop_last=True)
            # replace the target dataloader with the pseudo-label dataloader
            iter_target = iter(dset_loaders["target"])
            begin_label = True

        if i > config["stop_step"]:
            log_str = "method {}, iter: {:05d}, precision: {:.5f}".format(
                config["output_path"], best_step, best_acc)
            config["final_log"].write(log_str + "\n")
            config["final_log"].flush()
            break

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source = inputs_source.to(network.dev)
        inputs_target = inputs_target.to(network.dev)
        labels_source = labels_source.to(network.dev)
        features_source, outputs_source = base_network(inputs_source)
        if config['args'].source_detach:
            features_source = features_source.detach()
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        loss_params["trade_off"] = network.calc_coeff(
            i, high=high)  #if i > 500 else 0.0
        transfer_loss = 0.0
        if 'DANN' in config['method']:
            transfer_loss = loss.DANN(features, ad_net)
        elif "ALDA" in config['method']:
            ad_out = ad_net(features)
            adv_loss, reg_loss, correct_loss = loss.ALDA_loss(
                ad_out,
                labels_source,
                softmax_out,
                weight_type=config['args'].weight_type,
                threshold=config['threshold'])
            # whether add the corrected self-training loss
            if "nocorrect" in config['args'].loss_type:
                transfer_loss = adv_loss
            else:
                transfer_loss = config['args'].adv_weight * adv_loss + config[
                    'args'].adv_weight * loss_params["trade_off"] * correct_loss
            # reg_loss is only backward to the discriminator
            if "noreg" not in config['args'].loss_type:
                for param in base_network.parameters():
                    param.requires_grad = False
                reg_loss.backward(retain_graph=True)
                for param in base_network.parameters():
                    param.requires_grad = True
        # on-line self-training
        elif 'SelfTraining' in config['method']:
            transfer_loss += loss_params["trade_off"] * loss.SelfTraining_loss(
                outputs, softmax_out, config['threshold'])
        # off-line self-training
        elif 'PseudoLabel' in config['method']:
            labels_target = labels_target.to(network.dev)
            if begin_label:
                transfer_loss += loss_params["trade_off"] * nn.CrossEntropyLoss(
                    ignore_index=-1)(outputs_target, labels_target)
            else:
                transfer_loss += 0.0 * nn.CrossEntropyLoss(ignore_index=-1)(
                    outputs_target, labels_target)

        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        loss_value += classifier_loss.item() / config["test_interval"]
        # adv_loss and correct_loss are only defined on the ALDA branch above
        if "ALDA" in config['method']:
            loss_adv_value += adv_loss.item() / config["test_interval"]
            loss_correct_value += correct_loss.item() / config["test_interval"]
        total_loss = classifier_loss + transfer_loss
        total_loss.backward()
        optimizer.step()
    # save the model from the most recent evaluation as the final checkpoint
    checkpoint = {
        "base_network": temp_model.state_dict(),
        "ad_net": ad_net.state_dict()
    }
    torch.save(checkpoint, osp.join(config["output_path"], "final_model.pth"))
    return best_acc
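
Example #2 relies on `network.calc_coeff` to ramp the adversarial trade-off from 0 up to `high`. The exact definition lives in the repository's network.py; the sketch below is the sigmoid warm-up schedule commonly used in DANN-style codebases, with `alpha` and `max_iter` as assumed defaults rather than the repository's actual values.

import numpy as np

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    # Sigmoid ramp: returns `low` at iter_num=0 and approaches `high`
    # as iter_num grows; alpha controls how fast the ramp saturates.
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter))
                 - (high - low) + low)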