Example #1
def main():
    # parse options
    parser = TestOptions()
    opts = parser.parse()

    # data loader
    print('--- load data ---')
    style = load_image(opts.style_name)
    if opts.gpu != 0:
        style = to_var(style)
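    # c2s == 1 selects stylization (content-to-style), which needs a content
    # image; otherwise the script later runs destylization on the style image alone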
    if opts.c2s == 1:
        content = load_image(opts.content_name, opts.content_type)
        if opts.gpu != 0:
            content = to_var(content)

    # model
    print('--- load model ---')
    tetGAN = TETGAN()
    tetGAN.load_state_dict(torch.load(opts.model))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.eval()

    print('--- testing ---')
    if opts.c2s == 1:
        result = tetGAN(content, style)
    else:
        result = tetGAN.desty_forward(style)
    if opts.gpu != 0:
        result = to_data(result)

    print('--- save ---')
    # directory
    result_filename = os.path.join(opts.result_dir, opts.name)
    if not os.path.exists(opts.result_dir):
        os.mkdir(opts.result_dir)
    save_image(result[0], result_filename)
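
The snippet above relies on small GPU helpers (to_var, to_data) defined elsewhere in the project. A minimal sketch of what such helpers typically do, assuming plain PyTorch tensors (the names match the calls above, the bodies are an assumption, not the project's actual implementation):

def to_var(x):
    # assumed helper: move the input tensor to the GPU before running the model
    return x.cuda()

def to_data(x):
    # assumed helper: bring the result back to CPU memory so it can be saved
    return x.cpu()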
Example #2
def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # training parameters
    print('--- load parameter ---')
    outer_iter = opts.outer_iter
    fade_iter = max(1.0, float(outer_iter / 2))
    epochs = opts.epoch
    batchsize = opts.batchsize
    datasize = opts.datasize
    datarange = opts.datarange
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)
    tetGAN.train()

    print('--- training ---')
    stylenames = os.listdir(opts.train_path)
    print('List of %d styles:' % (len(stylenames)), *stylenames, sep=' ')

    if opts.progressive == 1:
        # progressive training: level 1 (64*64), then level 2 (128*128), then level 3 (256*256)
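        # schedule below: jitter ramps from 0 to 1 over the first fade_iter outer
        # iterations at level 1, while w decays from 1 to 0 at levels 2 and 3,
        # presumably fading the newly added resolution level in gradually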
        # level 1
        for i in range(outer_iter):
            jitter = min(1.0, i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize * 4,
                                               datarange, datasize * 2)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 1, jitter,
                                                 centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0],
                                             None, 1, None)
                print('Level1, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        # level 2
        for i in range(outer_iter):
            w = max(0.0, 1 - i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize * 2,
                                               datarange, datasize * 2)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 2, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], x[1], y[0], y[1], y_real[0],
                                             y_real[1], 2, w)
                print('Level2, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        # level 3
        for i in range(outer_iter):
            w = max(0.0, 1 - i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize,
                                               datarange, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], x[1], y[0], y[1], y_real[0],
                                             y_real[1], 3, w)
                print('Level3, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
    else:
        # train directly at level 3 (256*256) without progressive growing
        for i in range(outer_iter):
            fnames = load_trainset_batchfnames(opts.train_path, batchsize,
                                               datarange, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0],
                                             None, 3, 0)
                print('Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))

    print('--- save ---')
    torch.save(tetGAN.state_dict(), opts.save_model_name)
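
tetGAN.init_networks(weights_init) expects an initializer callback that is not shown in this snippet. A minimal sketch of the kind of DCGAN-style initializer such a hook usually is (an assumption, not the project's actual function):

import torch.nn as nn

def weights_init(m):
    # assumed DCGAN-style init: small-variance normal weights for conv layers,
    # unit-mean normal scale and zero bias for batch-norm layers
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)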
Example #3
def main():
    # parse options
    parser = FinetuneOptions()
    opts = parser.parse()

    # training parameters
    print('--- load parameter ---')
    outer_iter = opts.outer_iter
    epochs = opts.epoch
    batchsize = opts.batchsize
    datasize = opts.datasize
    stylename = opts.style_name

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.load_state_dict(torch.load(opts.load_model_name))
    tetGAN.train()

    print('--- training ---')
    # supervised one-shot learning
    if opts.supervise == 1:
        for i in range(outer_iter):
            fnames = load_oneshot_batchfnames(stylename, batchsize, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0],
                                             None, 3, 0)
                print('Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
    # unsupervised one-shot learning
    else:
        for i in range(outer_iter):
            fnames = load_oneshot_batchfnames(stylename, batchsize, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    # phase 1: no ground-truth x is available, so only the style autoencoder is updated
                    _, y_real, _ = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    Lsrec = tetGAN.update_style_autoencoder(y_real[0])
                for fname in fnames:
                    # phase 2: still no ground-truth x, so the current
                    # destylization branch synthesizes a pseudo content image
                    _, y_real, y = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    with torch.no_grad():
                        x_auxiliary = tetGAN.desty_forward(y_real[0])
                    losses = tetGAN.one_pass(x_auxiliary, None, y[0], None,
                                             y_real[0], None, 3, 0)
                print('Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f, Lsrec: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4],
                       Lsrec))

    print('--- save ---')
    torch.save(tetGAN.state_dict(), opts.save_model_name)
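
Once fine-tuning finishes, the saved weights can be reloaded for inference exactly as in Example #1; a minimal sketch reusing only calls already shown above:

tetGAN = TETGAN()
tetGAN.load_state_dict(torch.load(opts.save_model_name))
tetGAN.eval()
result = tetGAN(content, style)  # content and style prepared as in Example #1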
Example #4
def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # training parameters
    print('--- load parameter ---')
    # outer_iter = opts.outer_iter
    # fade_iter = max(1.0, float(outer_iter / 2))
    epochs = opts.epoch
    batchsize = opts.batchsize
    # datasize = opts.datasize
    # datarange = opts.datarange
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)

    num_params = 0
    for param in tetGAN.parameters():
        num_params += param.numel()
    print('Total number of parameters in TET-GAN: %.3f M' % (num_params / 1e6))

    print('--- training ---')
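    # dataset classes containing 'base_gray_texture' or 'skeleton_gray_texture'
    # take a few-shot path: start from a pretrained model (opts.model), train on
    # a handful of style images, and validate periodically; other classes train
    # from scratch on the full 'train' split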
    texture_class = ('base_gray_texture' in opts.dataset_class
                     or 'skeleton_gray_texture' in opts.dataset_class)
    if texture_class:
        tetGAN.load_state_dict(torch.load(opts.model))
        dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                    'style')
        val_dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                        'val')
        if 'base_gray_texture' in opts.dataset_class:
            few_size = 6
            batchsize = 2
            epochs = 1500
        elif 'skeleton_gray_texture' in opts.dataset_class:
            few_size = 30
            batchsize = 10
            epochs = 300
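        # few_size caps how many style images feed the few-shot training set;
        # the smaller 6-image set gets a smaller batch size and many more epochs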
        fnames = load_trainset_batchfnames_dualnet(dataset_path,
                                                   batchsize,
                                                   few_size=few_size)
        val_fnames = sorted(os.listdir(val_dataset_path))
        # sort the full listing before slicing so the few-shot style set is deterministic
        style_fnames = sorted(os.listdir(dataset_path))[:few_size]
    else:
        dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                    'train')
        fnames = load_trainset_batchfnames_dualnet(dataset_path, batchsize)

    tetGAN.train()

    train_fnames = os.listdir(dataset_path)
    print('List of %d styles:' % (len(train_fnames)))

    result_dir = os.path.join(opts.result_dir, opts.dataset_class)
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    for epoch in range(epochs):
        for idx, fname in enumerate(fnames):
            x, y_real, y = prepare_batch(fname, 1, 1, centercropratio,
                                         augementratio, opts.gpu)
            losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None,
                                     1, 0)
            if (idx + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d]' %
                      (epoch + 1, epochs, idx + 1, len(fnames)))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        if texture_class and ((epoch + 1) % (epochs / 20)) == 0:
            outname = ('save/val_epoch' + str(epoch + 1) + '_' +
                       opts.dataset_class + '_' + opts.save_model_name)
            print('--- save model Epoch [%d/%d] ---' % (epoch + 1, epochs))
            torch.save(tetGAN.state_dict(), outname)

            print('--- validating model [%d/%d] ---' % (epoch + 1, epochs))
            for val_idx, val_fname in enumerate(val_fnames):
                v_fname = os.path.join(val_dataset_path, val_fname)
                random.shuffle(style_fnames)
                sty_fname = style_fnames[0]
                s_fname = os.path.join(dataset_path, sty_fname)
                with torch.no_grad():
                    val_content = load_image_dualnet(v_fname, load_type=1)
                    val_sty = load_image_dualnet(s_fname, load_type=0)
                    if opts.gpu != 0:
                        val_content = val_content.cuda()
                        val_sty = val_sty.cuda()
                    result = tetGAN(val_content, val_sty)
                    if opts.gpu != 0:
                        result = to_data(result)
                    result_filename = os.path.join(
                        result_dir,
                        str(epoch) + '_' + val_fname)
                    print(result_filename)
                    save_image(result[0], result_filename)
        elif not texture_class and ((epoch + 1) % 2) == 0:
            outname = 'save/epoch' + str(epoch + 1) + '_' + opts.save_model_name
            print('--- save model ---')
            torch.save(tetGAN.state_dict(), outname)