Code example #1
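# NOTE: imports below are assumed for this excerpt; project-local helpers
# such as network, get_dataset, EarlyStopping, train, and test come from the
# surrounding repository and are not shown here.
import argparse
import os
import pickle as pkl
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torch.utils.tensorboard import SummaryWriter  # tensorboardX in older setups
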
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch classification example')
    parser.add_argument('--dataset',
                        type=str,
                        help='dataset',
                        choices=[
                            'mnist',
                            'usps',
                            'svhn',
                            'syn_digits',
                            'imagenet32x32',
                            'cifar10',
                            'stl10',
                        ])
    parser.add_argument('--arch', type=str, help='network architecture')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--val_ratio',
                        type=float,
                        default=0.0,
                        help='sampling ratio of validation data')
    parser.add_argument('--train_ratio',
                        type=float,
                        default=1.0,
                        help='sampling ratio of training data')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-6,
                        help='weight_decay (default: 1e-6)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--output_path',
                        type=str,
                        help='path to save ckpt and log. ')
    parser.add_argument('--resume',
                        type=str,
                        help='resume training from ckpt path')
    parser.add_argument('--ckpt_file', type=str, help='init model from ckpt. ')
    parser.add_argument(
        '--exclude_vars',
        type=str,
        help=
        'prefix of variables not restored from ckpt, separated with commas; valid if ckpt_file is not None'
    )
    parser.add_argument('--imagenet_pretrain',
                        action='store_true',
                        help='use pretrained imagenet model')
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    if args.output_path is not None and not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    writer = SummaryWriter(args.output_path)

    use_normalize = True
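    # number of classes is dataset-specific (CIFAR10/STL10 are treated as 9-class here)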
    if args.dataset == 'imagenet32x32':
        n_classes = 1000
        args.batch_size = 256
    elif args.dataset in ["cifar10", "stl10"]:
        n_classes = 9
    elif args.dataset in ["usps", "mnist", "svhn", 'syn_digits']:
        n_classes = 10
    else:
        raise ValueError('invalid dataset option: {}'.format(args.dataset))

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
    assert (args.val_ratio >= 0. and args.val_ratio < 1.)
    assert (args.train_ratio > 0. and args.train_ratio <= 1.)
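    # build the training/test (and optional validation) data loaders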
    train_ds = get_dataset(args.dataset,
                           'train',
                           use_normalize=use_normalize,
                           test_size=args.val_ratio,
                           train_size=args.train_ratio)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(get_dataset(
        args.dataset,
        'test',
        use_normalize=use_normalize,
        test_size=args.val_ratio,
        train_size=args.train_ratio),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)
    if args.val_ratio == 0.0:
        val_loader = test_loader
    else:
        val_ds = get_dataset(args.dataset,
                             'val',
                             use_normalize=use_normalize,
                             test_size=args.val_ratio,
                             train_size=args.train_ratio)
        val_loader = torch.utils.data.DataLoader(val_ds,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 **kwargs)

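    # select the backbone: DTN or WideResNet-28-10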
    if args.arch == "DTN":
        model = network.DTN().to(device)
    elif args.arch == 'wrn':
        model = network.WideResNet(depth=28,
                                   num_classes=n_classes,
                                   widen_factor=10,
                                   dropRate=0.0).to(device)
    else:
        raise ValueError('invalid network architecture {}'.format(args.arch))

    if args.ckpt_file is not None:
        print('initialize model parameters from {}'.format(args.ckpt_file))
        model.restore_from_ckpt(torch.load(args.ckpt_file, map_location='cpu'),
                                exclude_vars=args.exclude_vars.split(',')
                                if args.exclude_vars is not None else [])
        print('accuracy on test set before fine-tuning')
        test(args, model, device, test_loader)

    if args.resume is not None:
        assert (os.path.isfile(args.resume))
        print('resume training from {}'.format(args.resume))
        model.load_state_dict(torch.load(args.resume))

    if use_cuda:
        # model = torch.nn.DataParallel(model)
        cudnn.benchmark = True

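    # dataset-specific optimizer, learning-rate schedule, and early-stopping patience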
    if args.dataset.startswith("cifar") or args.dataset in ['stl10']:
        lr_decay_step = 100
        lr_decay_rate = 0.1
        PATIENCE = 100
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=args.momentum,
                              weight_decay=args.wd)
        scheduler = MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)
    elif args.dataset in ["mnist", "usps", "svhn", "syn_digits"]:
        lr_decay_step = 50
        lr_decay_rate = 0.5
        if args.dataset == 'svhn':
            PATIENCE = 10
        else:
            PATIENCE = 50
        optimizer = optim.SGD(model.get_parameters(args.lr),
                              momentum=0.5,
                              weight_decay=args.wd)
        scheduler = StepLR(optimizer,
                           step_size=lr_decay_step,
                           gamma=lr_decay_rate)
    elif args.dataset == 'imagenet32x32':
        PATIENCE = 10
        lr_decay_step = 10
        lr_decay_rate = 0.2
        optimizer = torch.optim.SGD(model.get_parameters(args.lr),
                                    momentum=0.9,
                                    weight_decay=5e-4,
                                    nesterov=True)
        scheduler = StepLR(optimizer,
                           step_size=lr_decay_step,
                           gamma=lr_decay_rate)
    else:
        raise ValueError("invalid dataset option: {}".format(args.dataset))

    early_stop_engine = EarlyStopping(PATIENCE)

    print("args:{}".format(args))

    # start training.
    best_accuracy = 0.
    save_path = os.path.join(args.output_path, "model.pt")
    time_stats = []
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, model, device, train_loader, optimizer, epoch, writer)
        training_time = time.time() - start_time
        print('epoch: {} training time: {:.2f}'.format(epoch, training_time))
        time_stats.append(training_time)

        val_accuracy = test(args, model, device, val_loader)
        scheduler.step()

        writer.add_scalar("val_accuracy", val_accuracy, epoch)
        if val_accuracy >= best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), save_path)

        if epoch % 20 == 0:
            print('accuracy on test set at epoch {}'.format(epoch))
            test(args, model, device, test_loader)

        if early_stop_engine.is_stop_training(val_accuracy):
            print(
                "no improvement after {}, stop training at epoch {}\n".format(
                    PATIENCE, epoch))
            break

    # print('finish training {} epochs'.format(args.epochs))
    mean_training_time = np.mean(np.array(time_stats))
    print('Average training_time: {}'.format(mean_training_time))
    print('load ckpt with best validation accuracy from {}'.format(save_path))
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    test_accuracy = test(args, model, device, test_loader)

    writer.add_scalar("test_accuracy", test_accuracy, args.epochs)
    with open(os.path.join(args.output_path, 'accuracy.pkl'),
              'wb') as pkl_file:
        pkl.dump(
            {
                'train': best_accuracy,
                'test': test_accuracy,
                'training_time': mean_training_time
            }, pkl_file)
Code example #2
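# NOTE: imports below are assumed for this excerpt; project-local modules
# such as prep, loss, network, lr_schedule, ImageList, stratify_sampling,
# EarlyStopping, MyDataLoader, optim_dict, and the evaluation helpers come
# from the surrounding repository and are not shown here.
import os.path as osp

import torch
import torch.nn as nn
import torch.utils.data as util_data
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter  # tensorboardX in older setups
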
def train(config):
    ## set up summary writer
    writer = SummaryWriter(config['output_path'])

    # set up early stop
    early_stop_engine = EarlyStopping(config["early_stop_patience"])

    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    prep_dict["target"] = prep.image_train( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
    else:
        prep_dict["test"] = prep.image_test( \
                            resize_size=prep_config["resize_size"], \
                            crop_size=prep_config["crop_size"])
               
    ## set loss
    class_num = config["network"]["params"]["class_num"]
    loss_params = config["loss"]

    class_criterion = nn.CrossEntropyLoss()
    transfer_criterion = loss.PADA
    center_criterion = loss_params["loss_type"](num_classes=class_num, 
                                       feat_dim=config["network"]["params"]["bottleneck_dim"])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    dsets["source"] = ImageList(stratify_sampling(open(data_config["source"]["list_path"]).readlines(), prep_config["source_size"]), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = util_data.DataLoader(dsets["source"], \
            batch_size=data_config["source"]["batch_size"], \
            shuffle=True, num_workers=1)
    dsets["target"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), prep_config["target_size"]), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = util_data.DataLoader(dsets["target"], \
            batch_size=data_config["target"]["batch_size"], \
            shuffle=True, num_workers=1)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"+str(i)] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["test"+str(i)] = util_data.DataLoader(dsets["test"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

            dsets["target"+str(i)] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"]["val"+str(i)])
            dset_loaders["target"+str(i)] = util_data.DataLoader(dsets["target"+str(i)], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)
    else:
        dsets["test"] = ImageList(stratify_sampling(open(data_config["test"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["test"] = util_data.DataLoader(dsets["test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

        dsets["target_test"] = ImageList(stratify_sampling(open(data_config["target"]["list_path"]).readlines(), ratio=prep_config['target_size']), \
                                transform=prep_dict["test"])
        dset_loaders["target_test"] = MyDataLoader(dsets["target_test"], \
                                batch_size=data_config["test"]["batch_size"], \
                                shuffle=False, num_workers=1)

    config['out_file'].write("dataset sizes: source={}, target={}\n".format(
        len(dsets["source"]), len(dsets["target"])))

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])


    use_gpu = torch.cuda.is_available()
    if use_gpu:
        base_network = base_network.cuda()

    ## collect parameters
    if net_config["params"]["new_cls"]:
        if net_config["params"]["use_bottleneck"]:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
        else:
            parameter_list = [{"params":base_network.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \
                            {"params":base_network.fc.parameters(), "lr_mult":10, 'decay_mult':2}]
    else:
        parameter_list = [{"params":base_network.parameters(), "lr_mult":1, 'decay_mult':2}]

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num())
    gradient_reverse_layer = network.AdversarialLayer(high_value=config["high"]) #, 
                                                      #max_iter_value=config["num_iterations"])
    if use_gpu:
        ad_net = ad_net.cuda()
    parameter_list.append({"params":ad_net.parameters(), "lr_mult":10, 'decay_mult':2})
    parameter_list.append({"params":center_criterion.parameters(), "lr_mult": 10, 'decay_mult':1})
 
    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optim_dict[optimizer_config["type"]](parameter_list, \
                    **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]


    ## train   
    len_train_source = len(dset_loaders["source"]) - 1
    len_train_target = len(dset_loaders["target"]) - 1
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == 0:
            base_network.train(False)
            if config['loss']['ly_type'] == "cosine":
                temp_acc = image_classification_test(dset_loaders, \
                    base_network, test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            elif config['loss']['ly_type'] == "euclidean":
                temp_acc, _ = distance_classification_test(dset_loaders, \
                    base_network, center_criterion.centers.detach(), test_10crop=prep_config["test_10crop"], \
                    gpu=use_gpu)
            else:
                raise ValueError("no test method for cls loss: {}".format(config['loss']['ly_type']))
            
            snapshot_obj = {'step': i, 
                            "base_network": base_network.state_dict(), 
                            'precision': temp_acc, 
                            }
            if config["loss"]["loss_name"] != "laplacian" and config["loss"]["ly_type"] == "euclidean":
                snapshot_obj['center_criterion'] = center_criterion.state_dict()
            if temp_acc > best_acc:
                best_acc = temp_acc
                # save best model
                torch.save(snapshot_obj, 
                           osp.join(config["output_path"], "best_model.pth.tar"))
            log_str = "iter: {:05d}, {} precision: {:.5f}\n".format(i, config['loss']['ly_type'], temp_acc)
            config["out_file"].write(log_str)
            config["out_file"].flush()
            writer.add_scalar("precision", temp_acc, i)

            if early_stop_engine.is_stop_training(temp_acc):
                config["out_file"].write("no improvement after {}, stop training at step {}\n".format(
                    config["early_stop_patience"], i))
                # config["out_file"].write("finish training! \n")
                break

        if (i+1) % config["snapshot_interval"] == 0:
            torch.save(snapshot_obj, 
                       osp.join(config["output_path"], "iter_{:05d}_model.pth.tar".format(i)))
        

        ## train one iter
        base_network.train(True)
        optimizer = lr_scheduler(param_lr, optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        if use_gpu:
            inputs_source, inputs_target, labels_source = \
                Variable(inputs_source).cuda(), Variable(inputs_target).cuda(), \
                Variable(labels_source).cuda()
        else:
            inputs_source, inputs_target, labels_source = Variable(inputs_source), \
                Variable(inputs_target), Variable(labels_source)
           
        inputs = torch.cat((inputs_source, inputs_target), dim=0)
        source_batch_size = inputs_source.size(0)

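        # forward pass: logits come from the classifier head (cosine) or from
        # negative distances to the class centers (euclidean)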
        if config['loss']['ly_type'] == 'cosine':
            features, logits = base_network(inputs)
            source_logits = logits.narrow(0, 0, source_batch_size)
        elif config['loss']['ly_type'] == 'euclidean':
            features, _ = base_network(inputs)
            logits = -1.0 * loss.distance_to_centroids(features, center_criterion.centers.detach())
            source_logits = logits.narrow(0, 0, source_batch_size)

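        # domain-adversarial transfer loss plus domain-classifier accuracy diagnostics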
        ad_net.train(True)
        weight_ad = torch.ones(inputs.size(0))
        transfer_loss = transfer_criterion(features, ad_net, gradient_reverse_layer, \
                                           weight_ad, use_gpu)
        ad_out, _ = ad_net(features.detach())
        ad_acc, source_acc_ad, target_acc_ad = domain_cls_accuracy(ad_out)

        # source domain classification task loss
        classifier_loss = class_criterion(source_logits, labels_source)
        # fisher loss on labeled source domain
        fisher_loss, fisher_intra_loss, fisher_inter_loss, center_grad = center_criterion(features.narrow(0, 0, int(inputs.size(0)/2)), labels_source, inter_class=config["loss"]["inter_type"], 
                                                                               intra_loss_weight=loss_params["intra_loss_coef"], inter_loss_weight=loss_params["inter_loss_coef"])
        # entropy minimization loss
        em_loss = loss.EntropyLoss(nn.Softmax(dim=1)(logits))

        # final loss
        total_loss = loss_params["trade_off"] * transfer_loss \
                     + fisher_loss \
                     + loss_params["em_loss_coef"] * em_loss \
                     + classifier_loss

        total_loss.backward()
        if center_grad is not None:
            # clear mmc_loss
            center_criterion.centers.grad.zero_()
            # Manually assign centers gradients other than using autograd
            center_criterion.centers.backward(center_grad)

        optimizer.step()

        if i % config["log_iter"] == 0:
            config['out_file'].write('iter {}: total loss={:0.4f}, transfer loss={:0.4f}, cls loss={:0.4f}, '
                'em loss={:0.4f}, '
                'mmc loss={:0.4f}, intra loss={:0.4f}, inter loss={:0.4f}, '
                'ad acc={:0.4f}, source_acc={:0.4f}, target_acc={:0.4f}\n'.format(
                i, total_loss.data.cpu().float().item(), transfer_loss.data.cpu().float().item(), classifier_loss.data.cpu().float().item(), 
                em_loss.data.cpu().float().item(), 
                fisher_loss.cpu().float().item(), fisher_intra_loss.cpu().float().item(), fisher_inter_loss.cpu().float().item(),
                ad_acc, source_acc_ad, target_acc_ad, 
                ))

            config['out_file'].flush()
            writer.add_scalar("total_loss", total_loss.data.cpu().float().item(), i)
            writer.add_scalar("cls_loss", classifier_loss.data.cpu().float().item(), i)
            writer.add_scalar("transfer_loss", transfer_loss.data.cpu().float().item(), i)
            writer.add_scalar("ad_acc", ad_acc, i)
            writer.add_scalar("d_loss/total", fisher_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/intra", fisher_intra_loss.data.cpu().float().item(), i)
            writer.add_scalar("d_loss/inter", fisher_inter_loss.data.cpu().float().item(), i)
        
    return best_acc