Example #1
if single_gpu_flag(opt):
    board = SummaryWriter(os.path.join('runs', opt.name))

prev_model = create_model(opt)
prev_model.cuda()


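# Main generator: a U-Net with 4 input channels and 3 output channels;
# BatchNorm layers are converted to SyncBatchNorm when running distributed.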
model = UNet(n_channels=4, n_classes=3)
if opt.distributed:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.apply(weights_init('kaiming'))
model.cuda()


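# Optional GAN branch: a discriminator plus a pretrained ResNet-18
# ("adv_embedder") kept in train mode, all moved to the GPU.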
if opt.use_gan:
    discriminator = Discriminator()
    discriminator.apply(utils.weights_init('gaussian'))
    discriminator.cuda()
    adv_embedder = resnet18(pretrained=True)
    if opt.distributed:
        adv_embedder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(adv_embedder)
    adv_embedder.train()
    adv_embedder.cuda()

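# Resume from a checkpoint if one is given; the discriminator and embedder
# weights live in sibling files with "step_disc_" / "step_adv_embed_" prefixes.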
if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
    load_checkpoint(model, opt.checkpoint)
    if opt.use_gan:
        load_checkpoint(discriminator, opt.checkpoint.replace("step_", "step_disc_"))
        load_checkpoint(adv_embedder, opt.checkpoint.replace("step_", "step_adv_embed_"))

model_module = model
Example #2
    'draft_path': 'STL path',  # "/data4/wangpengxiao/danbooru2017/original_STL"
    'save_path': 'result path',  # "/data4/wangpengxiao/danbooru2017/result"
    'img_size': 270,
    're_size': 256,
    'learning_rate': 1e-5,  # changed
    'gpus': '[0,1,2,3]',
    'lr_steps': [5, 10, 15, 20],
    'lr_decay': 0.1,
    'lamda_L1': 0.01,  # changed
    'workers': 16,
    'weight_decay': 1e-4
})


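# Generator U-Net (4 input channels -> 3 output channels) and a discriminator
# mapping 3-channel images to a single-channel output.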
Unet = UNet(in_channels=4, out_channels=3)
D = Discriminator(in_channels=3, out_channels=1)


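# Trace the generator graph for TensorBoard using dummy input tensors.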
writer.add_graph(Unet, (
    Variable(torch.randn(1, 2, 4, 256, 256), requires_grad=True)[0],
    Variable(torch.randn(1, 2, 3, 224, 224), requires_grad=True)[0]))

Unet = torch.nn.DataParallel(Unet, device_ids=eval(args.gpus)).cuda()

D = torch.nn.DataParallel(D, device_ids=eval(args.gpus)).cuda()

cudnn.benchmark = True # faster convolutions, but more memory


train_loader = torch.utils.data.DataLoader(
    ClothDataSet(
        args.train_path,
        args.sketch_path,
Example #3
prev_model.cuda()

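# Identity embedder loaded from a cp-vton identity checkpoint; only the
# embedder_b sub-module is used here.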
embedder_model = Embedder()
load_checkpoint(
    embedder_model,
    "../cp-vton/checkpoints/identity_train_64_dim/step_020000.pth")
image_embedder = embedder_model.embedder_b.cuda()

model = UNet(n_channels=4, n_classes=3)
if opt.distributed:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.apply(weights_init('kaiming'))
model.cuda()

if opt.use_gan:
    discriminator = Discriminator()
    discriminator.apply(utils.weights_init('gaussian'))
    discriminator.cuda()

if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
    load_checkpoint(model, opt.checkpoint)
    if opt.use_gan:
        load_checkpoint(discriminator,
                        opt.checkpoint.replace("step_", "step_disc_"))

model_module = model
if opt.use_gan:
    discriminator_module = discriminator
if opt.distributed:
    model = torch.nn.parallel.DistributedDataParallel(
        model,
Example #4
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

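    # torch.distributed.launch sets WORLD_SIZE; more than one process means a
    # distributed run, so bind the local GPU and initialize the NCCL process group.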
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

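    # Pretrained components loaded from fixed checkpoints: the GMM module, the
    # TOM generator (a UnetGenerator), and the identity embedder (only
    # embedder_b is kept).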
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

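    # Keep handles to the underlying modules so checkpoints are saved without
    # the DistributedDataParallel wrapper.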
    model_module = model
    if opt.use_gan:
        discriminator_module = discriminator
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

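    # Train with or without GAN losses; only a process where
    # single_gpu_flag(opt) holds writes the final checkpoint.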
    if opt.use_gan:
        train_residual_old(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module)
        if single_gpu_flag(opt):
            save_checkpoint(
                {
                    "generator": model_module,
                    "discriminator": discriminator_module
                }, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        train_residual_old(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #5
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':

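        # TOM stage: load a pretrained GMM from a fixed checkpoint and train
        # the UnetGenerator try-on model.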
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'TOM+WARP':

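        # TOM+WARP stage: the GMM starts untrained (no checkpoint is loaded)
        # and is wrapped in DistributedDataParallel together with the generator
        # for train_tom_gmm.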
        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            gmm_model = torch.nn.parallel.DistributedDataParallel(
                gmm_model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "identity":
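        # identity stage: train the Embedder network on its own.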
        model = Embedder()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'residual':

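        # residual stage: the pretrained GMM, TOM generator, and identity
        # embedder are loaded from earlier checkpoints; a new U-Net is trained,
        # optionally with a Discriminator and an AccDiscriminator for GAN losses.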
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new/step_038000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

            acc_discriminator = AccDiscriminator()
            acc_discriminator.apply(utils.weights_init('gaussian'))
            acc_discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            if opt.use_gan:
                load_checkpoint(discriminator,
                                opt.checkpoint.replace("step_", "step_disc_"))

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

                acc_discriminator = torch.nn.parallel.DistributedDataParallel(
                    acc_discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)

            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    elif opt.stage == "residual_old":
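        # residual_old stage: same setup as 'residual' but trained with
        # train_residual_old and without the AccDiscriminator.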
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new_2/step_070000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt,
                               train_loader,
                               model,
                               model_module,
                               gmm_model,
                               generator_model,
                               embedder_model,
                               board,
                               discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #6
def main(args):
    input_path = os.path.join(args.data, "input")
    trimap_path = os.path.join(args.data, "trimap")
    target_path = os.path.join(args.data, "target")
    output_path = os.path.join(args.data, "output")

    train_data_update_freq = args.batch_size
    test_data_update_freq = 50 * args.batch_size
    sess_save_freq = 100 * args.batch_size

    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    if not os.path.isdir(args.logdir):
        os.makedirs(args.logdir)

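    # Parse (i, j) id pairs from the input filenames, shuffle them, and store
    # the train/validation split as non-trainable variables so the split is
    # saved and restored with checkpoints.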
    ids = [[int(i) for i in os.path.splitext(filename)[0].split('_')]
           for filename in os.listdir(input_path)]
    np.random.shuffle(ids)
    split_point = int(round(
        0.85 * len(ids)))  # use 85% for training and 15% for validation
    train_ids = tf.get_variable('train_ids',
                                initializer=ids[0:split_point],
                                trainable=False)
    valid_ids = tf.get_variable('valid_ids',
                                initializer=ids[split_point:len(ids)],
                                trainable=False)

    global_step = tf.get_variable('global_step',
                                  initializer=0,
                                  trainable=False)

    g_iter = int(args.gen_epoch * int(train_ids.shape[0]))
    d_iter = int(args.disc_epoch * int(train_ids.shape[0]))
    a_iter = int(args.adv_epoch * int(train_ids.shape[0]))
    n_iter = g_iter + d_iter + a_iter

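    # Placeholders for 480x360 inputs (RGB + trimap = 4 channels) and 4-channel
    # targets; the last target channel is split off as alpha.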
    input_images = tf.placeholder(tf.float32, shape=[None, 480, 360, 4])
    target_images = tf.placeholder(tf.float32, shape=[None, 480, 360, 4])
    alpha = target_images[:, :, :, 3][..., np.newaxis]

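    # Generator: a U-Net that predicts the 4-channel target (sigmoid output);
    # the discriminator scores real targets against generator output.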
    with tf.variable_scope("Gen"):
        gen = UNet(4, 4)
        output = tf.sigmoid(gen(input_images))
        g_loss = tf.losses.mean_squared_error(target_images, output)
    with tf.variable_scope("Disc"):
        disc = Discriminator(4)
        d_real = disc(target_images)
        d_fake = disc(output)
        d_loss = tf.reduce_mean(tf.log(d_real) + tf.log(1 - d_fake))

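    # Combined objective for the adversarial phase: reconstruction loss plus
    # the weighted discriminator term.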
    a_loss = g_loss + args.d_coeff * d_loss

    g_loss_summary = tf.summary.scalar("g_loss", g_loss)
    d_loss_summary = tf.summary.scalar("d_loss", d_loss)
    a_loss_summary = tf.summary.scalar("a_loss", a_loss)

    summary_op = tf.summary.merge(
        [g_loss_summary, d_loss_summary, a_loss_summary])

    summary_image = tf.summary.image("result", output)

    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gen')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Disc')

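    # Three Adadelta optimizers: the generator on the reconstruction loss, the
    # generator on the combined adversarial loss, and the discriminator on
    # -d_loss (i.e. maximizing d_loss).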
    g_optimizer = tf.train.AdadeltaOptimizer(args.lr).minimize(
        g_loss, global_step=global_step, var_list=g_vars)
    a_optimizer = tf.train.AdadeltaOptimizer(args.lr).minimize(
        a_loss, global_step=global_step, var_list=g_vars)
    d_optimizer = tf.train.AdadeltaOptimizer(args.lr).minimize(
        -d_loss, global_step=global_step, var_list=d_vars)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    train_writer = tf.summary.FileWriter(args.logdir + '/train')
    test_writer = tf.summary.FileWriter(args.logdir + '/test')
    saver = tf.train.Saver()
    if args.checkpoint is not None and os.path.exists(
            os.path.join(args.logdir, 'checkpoint')):
        if args.checkpoint == -1:  #latest checkpoint
            saver.restore(sess, tf.train.latest_checkpoint(args.logdir))
        else:  #Specified checkpoint
            saver.restore(
                sess,
                os.path.join(args.logdir,
                             model_name + ".ckpt-" + str(args.checkpoint)))
        logging.debug('Model restored to step ' + str(global_step.eval(sess)))

    train_ids = list(train_ids.eval(sess))
    valid_ids = list(valid_ids.eval(sess))

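    # Load a batch: each input is the RGB image concatenated with its trimap
    # (4 channels, scaled to [0, 1]); the target is the 4-channel ground-truth
    # image, also scaled to [0, 1].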
    def load_batch(batch_ids):
        images, targets = [], []
        for i, j in batch_ids:
            input_filename = os.path.join(input_path,
                                          str(i) + '_' + str(j) + '.jpg')
            trimap_filename = os.path.join(trimap_path, str(i) + '_trimap.jpg')
            target_filename = os.path.join(target_path, str(i) + '.png')
            logging.debug(input_filename)
            logging.debug(trimap_filename)
            logging.debug(target_filename)
            image = resize(Image.open(input_filename), 2)
            trimap = resize(Image.open(trimap_filename), 2)
            target = resize(Image.open(target_filename), 2)

            image = np.array(image)
            trimap = np.array(trimap)[..., np.newaxis]
            image = np.concatenate(
                (image, trimap), axis=2).astype(np.float32) / 255

            target = np.array(target).astype(np.float32) / 255

            images.append(image)
            targets.append(target)

        return np.asarray(images), np.asarray(targets)

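    # Validation step: compute the reconstruction loss and an image summary on
    # a random batch from the validation split and log them to the test writer.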
    def test_step(batch_idx, summary_fct):
        batch_range = random.sample(valid_ids, args.batch_size)

        images, targets = load_batch(batch_range)

        loss, demo, summary = sess.run([g_loss, summary_image, summary_fct],
                                       feed_dict={
                                           input_images: images,
                                           target_images: targets,
                                       })

        test_writer.add_summary(summary, batch_idx)
        test_writer.add_summary(demo, batch_idx)

        logging.info('Validation Loss: {:.8f}'.format(loss))

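    # Training schedule: g_iter steps of generator pre-training, then d_iter
    # steps of discriminator training, then a_iter steps of joint adversarial
    # training; batch_idx is derived from the shared global_step.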
    try:
        batch_idx = 0
        while batch_idx < n_iter:
            batch_idx = global_step.eval(sess) * args.batch_size

            loss_fct = None
            label = None
            optimizers = []
            if batch_idx < g_iter:
                loss_fct = g_loss
                summary_fct = g_loss_summary
                label = 'Gen train'
                optimizers = [g_optimizer]
            elif batch_idx < g_iter + d_iter:
                loss_fct = d_loss
                summary_fct = d_loss_summary
                label = 'Disc train'
                optimizers = [d_optimizer]
            else:
                loss_fct = a_loss
                summary_fct = summary_op
                label = 'Adv train'
                optimizers = [a_optimizer]

            batch_range = random.sample(train_ids, args.batch_size)
            images, targets = load_batch(batch_range)

            loss, summary = sess.run([loss_fct, summary_fct] + optimizers,
                                     feed_dict={
                                         input_images: np.array(images),
                                         target_images: np.array(targets)
                                     })[0:2]

            if batch_idx % train_data_update_freq == 0:
                logging.info('{}: [{}/{} ({:.0f}%)]\tGen Loss: {:.8f}'.format(
                    label, batch_idx, n_iter, 100. * (batch_idx + 1) / n_iter,
                    loss))

                train_writer.add_summary(summary, batch_idx)

            if batch_idx % test_data_update_freq == 0:
                test_step(batch_idx, summary_fct)

            if batch_idx % sess_save_freq == 0:
                logging.debug('Saving model')
                saver.save(sess,
                           os.path.join(args.logdir, model_name + ".ckpt"),
                           global_step=batch_idx)

    except Exception:
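        # Save a crash checkpoint; the exception is not re-raised, so the
        # final save below still runs.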
        saver.save(sess,
                   os.path.join(args.logdir,
                                'crash_save_' + model_name + ".ckpt"),
                   global_step=batch_idx)

    saver.save(sess,
               os.path.join(args.logdir, model_name + ".ckpt"),
               global_step=batch_idx)