Example #1
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # TODO change data_range to include all train/evaluation/test data.
    # TODO adjust batch_size.
    train_data = FacadeDataset(flag='train', data_range=(0, 20), onehot=False)
    train_loader = DataLoader(train_data, batch_size=1)
    test_data = FacadeDataset(flag='test_dev',
                              data_range=(0, 114),
                              onehot=False)
    test_loader = DataLoader(test_data, batch_size=1)
    ap_data = FacadeDataset(flag='test_dev', data_range=(0, 114), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)

    name = 'starter_net'
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()  #TODO decide loss
    optimizer = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)

    print('\nStart training')
    for epoch in range(10):  #TODO decide epochs
        print('-----------------Epoch = %d-----------------' % (epoch + 1))
        train(train_loader, net, criterion, optimizer, device, epoch + 1)
        # TODO create your evaluation set, load the evaluation set and test on evaluation set
        evaluation_loader = train_loader
        test(evaluation_loader, net, criterion, device)

    print('\nFinished Training, Testing on test set')
    test(test_loader, net, criterion, device)
    print('\nGenerating Unlabeled Result')
    result = get_result(test_loader, net, device, folder='output_test')

    torch.save(net.state_dict(), './models/model_{}.pth'.format(name))

    cal_AP(ap_loader, net, criterion, device)
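# A minimal sketch of the held-out evaluation set the TODO above asks for,
# reusing the FacadeDataset signature from this example; the 725/906 split
# bounds are borrowed from Example #3 below, not from this snippet:
eval_data = FacadeDataset(flag='train', data_range=(725, 906), onehot=False)
evaluation_loader = DataLoader(eval_data, batch_size=1)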
Example #2
def train(num_epoch, batch_size, learning_rate, l1_weight):

    train_data = FacadeDataset(flag='train', data_range=(0, 49000))
    train_loader = DataLoader(train_data, batch_size=batch_size)
    val_data = FacadeDataset(flag='train', data_range=(49000, 50000))
    val_loader = DataLoader(val_data, batch_size=batch_size)
    # visual_data = next(iter(DataLoader(FacadeDataset(flag='train', data_range=(4900, 4909)), batch_size=9)))
    visual_data = next(iter(DataLoader(val_data, batch_size=9)))

    GAN_cifar = Cifar10_GAN(learning_rate=learning_rate, l1_weight=l1_weight)
    Unet = U_Net32(learning_rate=learning_rate)

    assert GAN_cifar.trained_epoch + 1 < num_epoch
    start_epoch = 0
    if GAN_cifar.trained_epoch > 0:
        start_epoch = GAN_cifar.trained_epoch + 1

    for epoch in range(start_epoch, num_epoch):
        print(
            '----------------------------- epoch {:d} -----------------------------'
            .format(epoch + 1))
        GAN_cifar.train_one_epoch(train_loader, val_loader, epoch)
        Unet.train_one_epoch(train_loader, val_loader, epoch)

        visual_result(visual_data[0], visual_data[1], GAN_cifar.G_model,
                      Unet.model, epoch + 1)

    GAN_cifar.plot_loss()
    Unet.plot_loss()

    GAN_cifar.save()
    Unet.save()
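# Hypothetical invocation of train() above; the hyperparameter values are
# illustrative, not taken from the original code:
if __name__ == '__main__':
    train(num_epoch=20, batch_size=64, learning_rate=2e-4, l1_weight=100.0)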
Example #3
def main():
    # For part 1, uncomment the classifi() call below.
    # classifi()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # TODO change data_range to include all train/evaluation/test data.
    # TODO adjust batch_size.
    train_data = FacadeDataset(flag='train', data_range=(0,725), onehot=False)
    train_loader = DataLoader(train_data, batch_size=1)

    train_val = FacadeDataset(flag='train', data_range=(725, 906), onehot=False)
    train_val_loader = DataLoader(train_val, batch_size=1)

    test_data = FacadeDataset(flag='test_dev', data_range=(0,114), onehot=False)
    test_loader = DataLoader(test_data, batch_size=1)
    ap_data = FacadeDataset(flag='test_dev', data_range=(0,114), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)
    name = 'starter_net'
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss() #TODO decide loss
    optimizer = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)

    print('\nStart training')
    train_loss = []
    test_loss = []
    for epoch in range(20): #TODO decide epochs
        print('-----------------Epoch = %d-----------------' % (epoch+1))
        batch_loss = train(train_loader, net, criterion, optimizer, device, epoch+1)
        # TODO create your evaluation set, load the evaluation set and test on evaluation set
        evaluation_loader = train_val_loader
        test_batch, _ = test(evaluation_loader, net, criterion, device)
        result_eval = get_result(evaluation_loader, net, device, folder='output_eval')
        train_loss.append(np.mean(batch_loss))
        test_loss.append(np.mean(test_batch))
    plt.figure()
    plt.plot(np.arange(0, 20), train_loss, label="train_loss")
    plt.plot(np.arange(0, 20), test_loss, label="test_loss")
    plt.title("Training and Evaluation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="lower left")
    plt.savefig("plot_loss.png")

    print('\nFinished Training, Testing on test set')
    test(test_loader, net, criterion, device)
    print('\nGenerating Unlabeled Result')
    result = get_result(test_loader, net, device, folder='output_test')

    torch.save(net.state_dict(), './models/model_{}.pth'.format(name))


    cal_AP(ap_loader, net, criterion, device)
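# A small robustness sketch for the plotting block in main() above: deriving
# the x-axis from the collected losses keeps the plot correct if the
# hard-coded epoch count (20) ever changes:
epochs = np.arange(1, len(train_loss) + 1)
plt.plot(epochs, train_loss, label="train_loss")
plt.plot(epochs, test_loss, label="test_loss")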
Example #4
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # TODO change data_range to include all train/evaluation/test data.
    # TODO adjust batch_size.
    print()
    train_data = FacadeDataset(flag='train', data_range=(0,720), onehot=False)
    train_loader = DataLoader(train_data, batch_size=10)
    test_data = FacadeDataset(flag='test_dev', data_range=(0,114), onehot=False)
    test_loader = DataLoader(test_data, batch_size=1)
    val_data = FacadeDataset(flag='train', data_range=(720,906), onehot=False)
    val_loader = DataLoader(val_data, batch_size=10)  # was test_data: the validation loader must wrap val_data
    ap_data = FacadeDataset(flag='test_dev', data_range=(0,114), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)
    print(ap_data.dataset[0][0].shape)

    name = 'starter_net'
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss() #TODO decide loss
    optimizer = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)

    print('\nStart training')
    val_loss = []
    t_loss = []
    for epoch in range(15): #TODO decide epochs
        print('-----------------Epoch = %d-----------------' % (epoch+1))
        train(train_loader, net, criterion, optimizer, device, epoch+1)
        # TODO create your evaluation set, load the evaluation set and test on evaluation set
        
        valloss = test(val_loader, net, criterion, device)
        trainloss = test(train_loader, net, criterion, device)
        val_loss.append(valloss)
        t_loss.append(trainloss)

    print('\nFinished Training, Testing on test set')
    test(test_loader, net, criterion, device)

    print('\nGenerating Unlabeled Result')
    result = get_result(test_loader, net, device, folder='output_test')

    torch.save(net.state_dict(), './models/model_{}.pth'.format(name))

    cal_AP(ap_loader, net, criterion, device)
    plt.plot(range(1, 16), t_loss, 'b', label='Training loss')
    plt.plot(range(1, 16), val_loss, 'r', label='Validation loss')
    plt.legend()
    plt.show()
    gen_data = FacadeDataset(flag='eval', data_range=(0,1), onehot=True)
    gen_loader = DataLoader(gen_data, batch_size=1)
    rst = get_result(gen_loader, net, device, folder='output_eval')
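# plt.show() in main() above needs an interactive backend; on a headless
# server a common alternative (hypothetical filename) is:
# plt.savefig('loss_curves.png')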
Example #5
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    train_data = FacadeDataset(flag='train', data_range=(0, 905), onehot=False)
    print("finish loading train&val")
    trainS, val = torch.utils.data.random_split(train_data, [724, 181])
    # train_data = FacadeDataset(flag='train', data_range=(0, 5), onehot=False)
    # print("finish loading train&val")
    # trainS, val = torch.utils.data.random_split(train_data, [4, 1])
    train_loader = DataLoader(trainS, batch_size=4)
    eval_loader = DataLoader(val, batch_size=4)

    test_data = FacadeDataset(flag='test_dev', data_range=(0, 5), onehot=False)
    test_loader = DataLoader(test_data, batch_size=1)
    ap_data = FacadeDataset(flag='test_dev', data_range=(0, 5), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)

    name = 'FacadeCNN'
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)

    train_acc = []  # initialized here; the original appended to names that were never defined
    val_acc = []
    pre_train_acc = test(train_loader, net, criterion, device)
    train_acc.append(pre_train_acc)
    pre_val_acc = test(eval_loader, net, criterion, device)
    val_acc.append(pre_val_acc)

    print('\nStart training')
    for epoch in range(1):
        print('-----------------Epoch = %d-----------------' % (epoch + 1))
        tr_acc = train(train_loader, net, criterion, optimizer, device,
                       epoch + 1)
        train_acc.append(tr_acc)

        evaluation_loader = eval_loader
        v_acc = test(evaluation_loader, net, criterion, device)
        val_acc.append(v_acc)

    print('\nFinished Training, Testing on test set')
    testloss = test(test_loader, net, criterion, device)
    print("\nLoss calculated on test set = ", testloss)
    print('\nGenerating Unlabeled Result')
    get_result(test_loader, net, device, folder='output_test')

    torch.save(net.state_dict(), './models/model_{}.pth'.format(name))

    cal_AP(ap_loader, net, criterion, device)
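# The train_acc/val_acc lists collected above are never visualized; a minimal
# sketch, assuming matplotlib.pyplot is imported as plt in this file:
plt.plot(range(len(train_acc)), train_acc, label='train accuracy')
plt.plot(range(len(val_acc)), val_acc, label='validation accuracy')
plt.legend()
plt.savefig('accuracy_{}.png'.format(name))  # hypothetical filename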
Example #6
def test(batch_size, test_range, size=32, only_visual=False):
    if size == 32:
        test_data = FacadeDataset(flag='test', data_range=(0, test_range))
        # GAN_model = Cifar10_GAN()
        # Unet = U_Net32()
        GAN_model = Cifar10_GAN(save_path="trained_model_best.pth.tar")
        Unet = U_Net32(save_path="trained_UNet_model_best.pth.tar")
    else:
        test_data = FacadeDataset_256(flag='train', data_range=(0, test_range))
        GAN_model = GAN_256()
        Unet = Unet_256()

    test_loader = DataLoader(test_data, batch_size=batch_size)
    visual_data = next(
        iter(
            DataLoader(test_data,
                       batch_size=9,
                       sampler=sampler.SubsetRandomSampler(
                           range(test_range)))))
    if not only_visual:
        GAN_model.test(test_loader)
        Unet.test(test_loader)

    # visual_result(visual_data[0], visual_data[1], GAN_model.G_model, Unet.model, mode='test')
    visual_result_four(visual_data[0],
                       visual_data[1],
                       GAN_model.G_model,
                       Unet.model,
                       mode='test')
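# Hypothetical invocation of test() above; the argument values are
# illustrative only:
test(batch_size=16, test_range=100, size=32, only_visual=False)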
Example #7
def eval_pretrained():
    """I trained a network on a VM and then wanted to retest it locally"""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    val_data = FacadeDataset(flag='train', data_range=(700, 900), onehot=False)
    val_loader = DataLoader(val_data, batch_size=1)
    ap_data = FacadeDataset(flag='test_dev', data_range=(0, 100), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)

    criterion = nn.CrossEntropyLoss()

    model_weight_path = 'vm_net.pth'

    net = u_net()
    net.load_state_dict(
        torch.load(model_weight_path, map_location=torch.device('cpu')))
    net.to(device)  # move the loaded weights onto the GPU when one is available
    net.eval()  # inference mode: disables dropout and batchnorm updates

    result = get_result(val_loader, net, device, folder='output_test')
    cal_AP(ap_loader, net, criterion, device)
Example #8
def main():
    wandb.init(project="proj4-part2")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # TODO change data_range to include all train/evaluation/test data.
    # TODO adjust batch_size.
    dataset = FacadeDataset(flag='train', data_range=(0, 905), onehot=False)
    trainProportion = 0.8
    trainSize = int(len(dataset) * trainProportion)
    valSize = len(dataset) - trainSize
    train_data, val_data = torch.utils.data.random_split(
        dataset, [trainSize, valSize])
    train_loader = DataLoader(train_data, batch_size=16)
    val_loader = DataLoader(val_data, batch_size=16)
    test_data = FacadeDataset(flag='test_dev',
                              data_range=(0, 114),
                              onehot=False)
    test_loader = DataLoader(test_data, batch_size=1)
    ap_data = FacadeDataset(flag='test_dev', data_range=(0, 114), onehot=True)
    ap_loader = DataLoader(ap_data, batch_size=1)

    name = 'starter_net'
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), 1e-3, weight_decay=1e-5)
    wandb.watch(net)

    print('\nStart training')
    for epoch in range(60):  #TODO decide epochs
        print('-----------------Epoch = %d-----------------' % (epoch + 1))
        train(train_loader, net, criterion, optimizer, device, epoch + 1)
        # TODO create your evaluation set, load the evaluation set and test on evaluation set
        evaluation_loader = val_loader
        testwb(train_loader, evaluation_loader, net, criterion, device)

    print('\nFinished Training, Testing on test set')
    test(test_loader, net, criterion, device)
    print('\nGenerating Unlabeled Result')
    result = get_result(test_loader, net, device, folder='output_test')

    torch.save(net.state_dict(), './models/model_{}.pth'.format(name))

    cal_AP(ap_loader, net, criterion, device)
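# random_split above draws a fresh permutation on every run; passing a seeded
# generator (supported by torch.utils.data.random_split) makes the
# train/validation split reproducible:
train_data, val_data = torch.utils.data.random_split(
    dataset, [trainSize, valSize],
    generator=torch.Generator().manual_seed(0))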
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize', '-b', type=int, default=50)
    parser.add_argument('--epoch', '-e', type=int, default=1000)
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--out', '-o', default='')
    parser.add_argument('--resume', '-r', default='')
    parser.add_argument('--n_hidden', '-n', type=int, default=100)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--snapshot_interval', type=int, default=100000)
    parser.add_argument('--display_interval', type=int, default=100)
    args = parser.parse_args()

    out_dir = 'result'
    if args.out != '':
        out_dir = '{}/{}'.format(out_dir, args.out)  # was format(out, ...): `out` is undefined
    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# n_hidden: {}'.format(args.n_hidden))
    print('# epoch: {}'.format(args.epoch))
    print('# out: {}'.format(out_dir))
    print('')

    bottom_ch = 512
    unet = UNet([
        DownBlock(None, bottom_ch // 8),
        DownBlock(None, bottom_ch // 4),
        DownBlock(None, bottom_ch // 2)
    ], BottomBlock(None, bottom_ch), [
        UpBlock(None, bottom_ch // 2),
        UpBlock(None, bottom_ch // 4),
        UpBlock(None, bottom_ch // 8)
    ], L.Convolution2D(None, 12, 3, 1, 1))

    model = L.Classifier(unet)

    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    print('Loading Data...')
    images, labels = get_facade()
    print('Transforming Images...')
    images = transfrom_images(images)
    print('Transforming Labels...')
    labels = transform_labels(labels)

    train, test = (labels[:300], images[:300]), (labels[300:], images[300:])
    train, test = FacadeDataset(train[1],
                                train[0]), FacadeDataset(test[1], test[0])
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    generateimage_interval = (args.snapshot_interval // 100, 'iteration')
    display_interval = (args.display_interval, 'iteration')

    print('Setting trainer...')
    updater = training.updater.StandardUpdater(train_iter,
                                               optimizer,
                                               device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=out_dir)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'main/accuracy']),
                   trigger=display_interval)
    trainer.extend(extensions.LogReport())
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch',
                                  file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch',
                file_name='accuracy.png'))
    trainer.extend(extensions.ProgressBar(update_interval=20))

    print('RUN')
    trainer.run()
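# Hypothetical command line for the parser above (the script name is assumed):
# python train_facade_unet.py --gpu 0 --batchsize 50 --epoch 1000 --out run1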
Example #10
def main():
    parser = argparse.ArgumentParser(
        description="chainer implementation of pix2pix")
    parser.add_argument("--batchsize",
                        "-b",
                        type=int,
                        default=1,
                        help="Number of images in each mini-batch")
    parser.add_argument("--epoch",
                        "-e",
                        type=int,
                        default=40000,
                        help="Number of sweeps over the dataset to train")
    parser.add_argument("--gpu",
                        "-g",
                        type=int,
                        default=-1,
                        help="GPU ID (negative value indicates CPU)")
    parser.add_argument("--dataset",
                        "-i",
                        default="./input/png/",
                        help="Directory of image files.")
    parser.add_argument("--out",
                        "-o",
                        default="D:/output/imasUtaConverter/",
                        help="Directory to output the result")
    parser.add_argument("--resume",
                        "-r",
                        default="",
                        help="Resume the training from snapshot")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--snapshot_interval",
                        type=int,
                        default=10000,
                        help="Interval of snapshot")
    parser.add_argument("--display_interval",
                        type=int,
                        default=20,
                        help="Interval of displaying log to console")
    args = parser.parse_args()

    print("GPU: {}".format(args.gpu))
    print("# Minibatch-size: {}".format(args.batchsize))
    print("# epoch: {}".format(args.epoch))
    print("")

    # Set up a neural network to train
    enc = Encoder(in_ch=2)
    dec = Decoder(out_ch=2)
    dis = Discriminator(in_ch=2, out_ch=2)

    if args.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(
            args.gpu).use()  # Make a specified GPU current
        enc.to_gpu()  # Copy the model to the GPU
        dec.to_gpu()
        dis.to_gpu()

    # Setup an optimizer
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), "hook_dec")
        return optimizer

    opt_enc = make_optimizer(enc)
    opt_dec = make_optimizer(dec)
    opt_dis = make_optimizer(dis)

    train_d = FacadeDataset(args.dataset, data_range=(0, 38))
    test_d = FacadeDataset(args.dataset, data_range=(38, 40))
    #train_iter = chainer.iterators.MultiprocessIterator(train_d, args.batchsize, n_processes=4)
    #test_iter = chainer.iterators.MultiprocessIterator(test_d, args.batchsize, n_processes=4)
    train_iter = chainer.iterators.SerialIterator(train_d, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test_d, args.batchsize)

    # Set up a trainer
    updater = FacadeUpdater(models=(enc, dec, dis),
                            iterator={
                                "main": train_iter,
                                "test": test_iter
                            },
                            optimizer={
                                "enc": opt_enc,
                                "dec": opt_dec,
                                "dis": opt_dis
                            },
                            device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.out)

    snapshot_interval = (args.snapshot_interval, "iteration")
    display_interval = (args.display_interval, "iteration")
    trainer.extend(
        extensions.snapshot(filename="snapshot_iter_{.updater.iteration}.npz"),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        enc, "enc_iter_{.updater.iteration}.npz"),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dec, "dec_iter_{.updater.iteration}.npz"),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, "dis_iter_{.updater.iteration}.npz"),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        "epoch",
        "iteration",
        "enc/loss",
        "dec/loss",
        "dis/loss",
    ]),
                   trigger=display_interval)
    #trainer.extend(extensions.PlotReport(["enc/loss", "dis/loss"], x_key="epoch", file_name="loss.png"))
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(out_image(updater, enc, dec, 1, 10, args.seed, args.out),
                   trigger=snapshot_interval)

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
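# Hypothetical command lines for the parser above (script name assumed);
# --resume restores the full trainer state from a snapshot written by the
# extensions.snapshot extension:
# python train_pix2pix.py --gpu 0 --batchsize 1 --epoch 40000
# python train_pix2pix.py --gpu 0 --resume snapshot_iter_10000.npz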
Example #11
def main():
    parser = argparse.ArgumentParser(
        description='chainer implementation of pix2pix')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=4,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=200,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--dataset',
                        '-i',
                        default='./azuren',
                        help='Directory of image files.')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--snapshot_interval',
                        type=int,
                        default=250,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval',
                        type=int,
                        default=100,
                        help='Interval of displaying log to console')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    enc = Encoder(in_ch=1)
    dec = Decoder(out_ch=3)
    dis = Discriminator(in_ch=1, out_ch=3)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        enc.to_gpu()  # Copy the model to the GPU
        dec.to_gpu()
        dis.to_gpu()

    # Setup an optimizer
    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
        return optimizer

    opt_enc = make_optimizer(enc)
    opt_dec = make_optimizer(dec)
    opt_dis = make_optimizer(dis)

    train_d = FacadeDataset(args.dataset, data_range=(1, 2000))
    test_d = FacadeDataset(args.dataset, data_range=(2000, 2400))
    train_iter = chainer.iterators.MultiprocessIterator(train_d,
                                                        args.batchsize,
                                                        n_processes=2)
    test_iter = chainer.iterators.MultiprocessIterator(test_d,
                                                       args.batchsize,
                                                       n_processes=2)
    #train_iter = chainer.iterators.SerialIterator(train_d, args.batchsize)
    #test_iter = chainer.iterators.SerialIterator(test_d, args.batchsize)

    # Set up a trainer
    updater = FacadeUpdater(models=(enc, dec, dis),
                            iterator={
                                'main': train_iter,
                                'test': test_iter
                            },
                            optimizer={
                                'enc': opt_enc,
                                'dec': opt_dec,
                                'dis': opt_dis
                            },
                            device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        enc, 'enc_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dec, 'dec_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch',
        'iteration',
        'enc/loss',
        'dec/loss',
        'dis/loss',
    ]),
                   trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(out_image(updater, enc, dec, 5, 5, args.seed, args.out),
                   trigger=snapshot_interval)

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
Example #12
def load_test_dataset():
    test_dataset = FacadeDataset(os.path.join(opt.dataroot, 'extended'),
                                 root='test')
    print(f'test_dataset size={len(test_dataset)}')
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)
    return test_loader
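# Note: shuffle=True on a test loader reorders images on every pass; for
# deterministic evaluation a common variant is:
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)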