Example #1
0
def train_tryon(opt, train_loader, model):
    """Train the try-on generator, saving result images every step and a
    checkpoint every 100 steps.

    Args:
        opt: options namespace; uses keep_step, decay_step, result, name,
            checkpoint_dir.
        train_loader: loader exposing next_batch() -> dict with at least
            the keys read below ("pair", plus whatever model.set_input needs).
        model: module exposing set_input / optimize_parameters /
            current_results (CycleGAN-style interface).
    """
    model.cuda()
    model.train()

    # Bug fix: the output directory was checked/created on every iteration
    # (racy exists-then-create); create it once up front instead.
    result_dir = os.path.join(opt.result, opt.name)
    os.makedirs(result_dir, exist_ok=True)

    for step in range(opt.keep_step + opt.decay_step):
        inputs = train_loader.next_batch()
        print(step, '++++++')
        print("-------------------------------------")
        pairs = inputs["pair"]
        model.set_input(inputs)
        model.optimize_parameters()
        results = model.current_results()

        print('step: %8d, G_loss: %4f, c_loss: %4f' %
              (step + 1, results['G_loss'].item(),
               results['content_loss'].item()),
              flush=True)

        # Dump the generated images for visual inspection of progress.
        save_images(results['gen_B'], pairs, result_dir)

        if (step + 1) % 100 == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
Example #2
0
def train_gmm(opt, train_loader, model, board):
    """Train the GMM warping stage with an L1 loss on the warped cloth.

    Args:
        opt: options; uses cuda, lr, keep_step, decay_step, display_count,
            save_count, checkpoint_dir, name.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: GMM returning (grid, theta) from (agnostic, cloth).
        board: tensorboard SummaryWriter.
    """
    if opt.cuda:
        model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) /
        float(opt.decay_step + 1))

    def to_device(t):
        # Move a tensor to the GPU only when CUDA training was requested;
        # collapses the duplicated cuda / non-cuda assignment branches.
        return t.cuda() if opt.cuda else t

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = to_device(inputs['image'])
        im_pose = to_device(inputs['pose_image'])
        im_h = to_device(inputs['head'])
        shape = to_device(inputs['shape'])
        agnostic = to_device(inputs['agnostic'])
        c = to_device(inputs['cloth'])
        cm = to_device(inputs['cloth_mask'])
        im_c = to_device(inputs['parse_cloth'])
        im_g = to_device(inputs['grid_image'])

        grid, theta = model(agnostic, c)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        loss = criterionL1(warped_cloth, im_c)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bug fix: the scheduler was created but never stepped, so the
        # configured linear lr decay after keep_step never took effect.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f' %
                  (step + 1, t, loss.item()), flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model,
                            os.path.join(opt.checkpoint_dir, opt.name,
                                         'step_%06d.pth' % (step + 1)))
Example #3
0
def main():
    """Build the dataset, loader and tensorboard writer, then train the
    WUTON generator (with its discriminator) and save the final weights."""
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # dataset and loader
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # tensorboard visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # generator + discriminator; optionally resume the generator
    generator = WUTON(opt)
    discriminator = Discriminator()
    if opt.checkpoint != '' and os.path.exists(opt.checkpoint):  # TODO
        load_checkpoint(generator, opt.checkpoint)

    train(opt, loader, generator, discriminator, board)

    save_checkpoint(
        generator,
        os.path.join(opt.checkpoint_dir, opt.name, 'wuton_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #4
0
def main():
    """Train the stage selected by opt.stage ('GMM' or 'TOM') and save the
    final checkpoint.

    Raises:
        NotImplementedError: if opt.stage is neither 'GMM' nor 'TOM'.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':
        # UnetGenerator(input_nc=25, output_nc=4, num_downs=6)
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Bug fix: the closing message misspelled "named" as "nameed".
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #5
0
def train_tom(opt, train_loader, model, board):
    """Train the TOM (try-on module) stage.

    The model renders a person image and a composition mask; the try-on
    result composites the warped cloth with the rendered person.  Loss is
    L1 + VGG perceptual + mask L1.

    Args:
        opt: options; uses lr, keep_step, decay_step, display_count,
            save_count, checkpoint_dir, name.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: UNet taking cat(agnostic, cloth) -> 4-channel output.
        board: tensorboard SummaryWriter.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) /
        float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']   # visualization only, stays on CPU
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        # Bug fix: F.tanh / F.sigmoid are deprecated (removed in recent
        # PyTorch); torch.tanh / torch.sigmoid are the equivalents.
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        # masks are rescaled from [0,1] to [-1,1] for display
        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bug fix: the scheduler was never stepped, so the configured lr
        # decay never took effect.
        scheduler.step()

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('metric', loss.item(), step+1)
            board.add_scalar('L1', loss_l1.item(), step+1)
            board.add_scalar('VGG', loss_vgg.item(), step+1)
            board.add_scalar('MaskL1', loss_mask.item(), step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), 
                    loss_vgg.item(), loss_mask.item()), flush=True)

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))
Example #6
0
def main():
    """Train the TOM stage for opt.n_epoch epochs, logging per-epoch
    train/val losses to a CSV file and checkpointing both networks."""
    opt = get_opt()
    print(opt)

    print('Loading dataset')
    dataset_train = TOMDataset(opt, mode='train', data_list='train_pairs.txt')
    dataloader_train = DataLoader(dataset_train, batch_size=opt.batch_size,
                                  num_workers=opt.n_worker, shuffle=True)
    dataset_val = TOMDataset(opt, mode='val', data_list='val_pairs.txt',
                             train=False)
    dataloader_val = DataLoader(dataset_val, batch_size=opt.batch_size,
                                num_workers=opt.n_worker, shuffle=True)

    # output directories and CSV log header
    save_dir = os.path.join(opt.out_dir, opt.name)
    log_dir = os.path.join(opt.out_dir, 'log')
    for d in (opt.out_dir, save_dir, os.path.join(save_dir, 'train'), log_dir):
        mkdir(d)
    log_name = os.path.join(log_dir, opt.name + '.csv')
    with open(log_name, 'w') as f:
        f.write('epoch,train_loss,val_loss\n')

    print('Building TOM model')
    gen = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    dis = NLayerDiscriminator(28, ndf=64, n_layers=6,
                              norm_layer=nn.InstanceNorm2d, use_sigmoid=True)
    gen.cuda()
    dis.cuda()
    n_step = int(opt.n_epoch * len(dataset_train) / opt.batch_size)
    trainer = TOMTrainer(gen, dis, dataloader_train, dataloader_val,
                         opt.gpu_id, opt.log_freq, save_dir, n_step)

    print('Start training TOM')
    for epoch in tqdm(range(opt.n_epoch)):
        print('Epoch: {}'.format(epoch))
        train_loss = trainer.train(epoch)
        print('Train loss: {:.3f}'.format(train_loss))
        with open(log_name, 'a') as f:
            f.write('{},{:.3f},'.format(epoch, train_loss))
        save_checkpoint(
            gen, os.path.join(save_dir, 'gen_epoch_{:02}.pth'.format(epoch)))
        save_checkpoint(
            dis, os.path.join(save_dir, 'dis_epoch_{:02}.pth'.format(epoch)))

        val_loss = trainer.val(epoch)
        print('Validation loss: {:.3f}'.format(val_loss))
        with open(log_name, 'a') as f:
            f.write('{:.3f}\n'.format(val_loss))
    print('Finish training TOM')
Example #7
0
def train_identity_embedding(opt, train_loader, model, board):
    model.cuda()
    model.train()

    # criterion
    mse_criterion = torch.nn.MSELoss()
    triplet_criterion = torch.nn.TripletMarginLoss(margin=0.3)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    pbar = range(opt.keep_step + opt.decay_step)
    if single_gpu_flag(opt):
        pbar = tqdm(pbar)

    for step in pbar:
        inputs_1, inputs_2 = train_loader.next_batch()

        img_1 = inputs_1['cloth'].cuda()
        img_ou_1 = inputs_1['image'].cuda()
        img_2 = inputs_2['cloth'].cuda()
        img_ou_2 = inputs_2['image'].cuda()

        pred_prod_embedding_1, pred_outfit_embedding_1 = model(img_1, img_ou_1)
        pred_prod_embedding_2, pred_outfit_embedding_2 = model(img_2, img_ou_2)

        # msee loss
        mean_squared_loss = (
            mse_criterion(pred_outfit_embedding_1, pred_prod_embedding_1) +
            mse_criterion(pred_outfit_embedding_2, pred_prod_embedding_2)) / 2

        # triplet loss
        triplet_loss = triplet_criterion(pred_outfit_embedding_1, pred_prod_embedding_1, pred_outfit_embedding_2) + \
                       triplet_criterion(pred_outfit_embedding_2, pred_prod_embedding_2, pred_outfit_embedding_1) + \
                       triplet_criterion(pred_outfit_embedding_1, pred_prod_embedding_1, pred_prod_embedding_2) + \
                       triplet_criterion(pred_outfit_embedding_2, pred_prod_embedding_2, pred_prod_embedding_1)

        loss = mean_squared_loss + triplet_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if single_gpu_flag(opt):
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('MSE', mean_squared_loss.item(), step + 1)
            board.add_scalar('trip', triplet_loss.item(), step + 1)

            pbar.set_description(
                'step: %8d, loss: %.4f, mse: %.4f, trip: %.4f' %
                (step + 1, loss.item(), mean_squared_loss.item(),
                 triplet_loss.item()))

        if (step + 1) % opt.save_count == 0 and single_gpu_flag(opt):
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
Example #8
0
def train_gmm(opt, train_loader, model, board):
    """Train the GMM warping stage against the target upper-cloth image.

    The agnostic input is built from the predicted target mask and the
    target pose; the loss is L1 between the warped cloth and the target
    upper cloth.

    Args:
        opt: options; uses lr, keep_step, decay_step, display_count,
            save_count, checkpoint_dir, name.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: GMM returning (grid, theta) from (agnostic, cloth).
        board: tensorboard SummaryWriter (currently not written to here).
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        c_image = inputs['c_image'].cuda()
        t_pose = inputs['t_pose'].cuda()
        c_head = inputs['c_head'].cuda()
        Pre_target_mask = inputs['Pre_target_mask'].cuda()
        cloth = inputs['cloth'].cuda()
        cloth_mask = inputs['cloth_mask'].cuda()
        t_upper_mask = inputs['t_upper_mask'].cuda()
        t_upper_cloth = inputs['t_upper_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        # agnostic = predicted target mask + target pose (no body features)
        agnostic = torch.cat([Pre_target_mask, t_pose], 1)
        grid, theta = model(agnostic, cloth)
        warped_cloth = F.grid_sample(cloth, grid, padding_mode='border')
        warped_mask = F.grid_sample(cloth_mask, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        # NOTE(review): visuals are assembled but never logged to `board`
        # here — presumably intended for board_add_images; confirm.
        visuals = [[c_head, Pre_target_mask, t_pose],
                   [cloth_mask, warped_cloth, t_upper_mask],
                   [warped_grid, (warped_cloth + c_image) * 0.5, c_image]]

        loss = criterionL1(warped_cloth, t_upper_cloth)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bug fix: the scheduler was never stepped, so the configured lr
        # decay never took effect.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
Example #9
0
def main():
    """Train the stage selected by opt.stage ('GMM', 'HPM' or 'TOM') and
    save the final checkpoint(s).

    Raises:
        NotImplementedError: if opt.stage is unrecognized, or an HPM
            checkpoint path is a file instead of a directory.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt, 1)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'HPM':
        model = UnetGenerator(25,
                              16,
                              6,
                              ngf=64,
                              norm_layer=nn.InstanceNorm2d,
                              clsf=True)
        d_g = Discriminator_G(opt, 16)
        # HPM checkpoints are stored per-network in a directory.
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            if not os.path.isdir(opt.checkpoint):
                raise NotImplementedError(
                    'checkpoint should be dir, not file: %s' % opt.checkpoint)
            load_checkpoints(model, d_g, os.path.join(opt.checkpoint,
                                                      "%s.pth"))
        train_hpm(opt, train_loader, model, d_g, board)
        save_checkpoints(
            model, d_g,
            os.path.join(opt.checkpoint_dir,
                         opt.stage + '_' + opt.name + "_final", '%s.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 3, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Bug fix: the closing message misspelled "named" as "nameed".
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #10
0
def train_mask_gen(opt, train_loader, model):
    """Train the mask-generation network with an L1 loss against the
    person parse.

    Args:
        opt: options; uses lr, keep_step, decay_step, display_count,
            save_count.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: network mapping cat(mesh, pose_map, cloth) -> mask logits.
    """
    # Bug fix: start_time was referenced in the final print but never
    # defined, raising NameError at the end of training.
    start_time = time.time()

    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        c = inputs['cloth'].cuda()

        mesh = inputs['mesh'].cuda()
        pose_map = inputs['pose_map'].cuda()
        person_parse = inputs['person_parse'].cuda()

        outputs = model(torch.cat([mesh, pose_map, c], 1))
        m_composite = F.sigmoid(outputs)

        loss_l1 = criterionL1(m_composite, person_parse)

        loss = loss_l1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            # NOTE(review): 'model_' joins as a directory component here
            # ('./mask_gen/model_/step_*.pth') — confirm that is intended.
            save_checkpoint(
                model,
                os.path.join('./mask_gen/', 'model_',
                             'step_%06d.pth' % (step + 1)))

    print('Finished in ' + str(time.time() - start_time))
Example #11
0
def train_gmm(opt, train_loader, model, board):
    """Train the GMM warping stage with an L1 loss between the warped
    cloth mask and the precomputed target warped mask `wm`.

    Args:
        opt: options; uses lr, keep_step, decay_step, display_count,
            save_count, checkpoint_dir, name.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: GMM returning (grid, theta) from (warped_mask, cloth).
        board: tensorboard SummaryWriter.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        wm = inputs['warped_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        grid, theta = model(wm, c)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        loss = criterionL1(warped_mask, wm)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bug fixes: the scheduler was never stepped, so lr decay never
        # took effect; an unused loss_sum accumulator was removed.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
def main():
    """Train the GMM stage for opt.n_epoch epochs, logging per-epoch
    train/val losses to a CSV file and checkpointing every epoch."""
    opt = get_opt()
    print(opt)

    print('Loading dataset')
    dataset_train = GMMDataset(opt, mode='train', data_list='train_pairs.txt')
    dataloader_train = DataLoader(dataset_train, batch_size=opt.batch_size,
                                  num_workers=opt.n_worker, shuffle=True)
    dataset_val = GMMDataset(opt, mode='val', data_list='val_pairs.txt',
                             train=False)
    dataloader_val = DataLoader(dataset_val, batch_size=opt.batch_size,
                                num_workers=opt.n_worker, shuffle=True)

    # output directories and CSV log header
    save_dir = os.path.join(opt.out_dir, opt.name)
    log_dir = os.path.join(opt.out_dir, 'log')
    for d in (opt.out_dir, save_dir, os.path.join(save_dir, 'train'), log_dir):
        mkdir(d)
    log_name = os.path.join(log_dir, opt.name + '.csv')
    with open(log_name, 'w') as f:
        f.write('epoch,train_loss,val_loss\n')

    print('Building GMM model')
    model = GMM(opt)
    model.cuda()
    trainer = GMMTrainer(model, dataloader_train, dataloader_val, opt.gpu_id,
                         opt.log_freq, save_dir)

    print('Start training GMM')
    for epoch in tqdm(range(opt.n_epoch)):
        print('Epoch: {}'.format(epoch))
        train_loss = trainer.train(epoch)
        print('Train loss: {:.3f}'.format(train_loss))
        with open(log_name, 'a') as f:
            f.write('{},{:.3f},'.format(epoch, train_loss))
        save_checkpoint(
            model, os.path.join(save_dir, 'epoch_{:02}.pth'.format(epoch)))

        val_loss = trainer.val(epoch)
        print('Validation loss: {:.3f}'.format(val_loss))
        with open(log_name, 'a') as f:
            f.write('{:.3f}\n'.format(val_loss))
    print('Finish training GMM')
Example #13
0
def main():
    """Build the dataset/loader, optionally resume the cyclegan generator,
    train the try-on stage, and save the final generator checkpoint."""
    opt = get_opt()
    print(opt)
    print("named: %s!" % (opt.name))

    # dataset and generator
    dataset = CPDataset(opt)
    generator = cyclegan(opt)
    print('/////////////////////////////////')

    # dataloader
    loader = CPDataLoader(opt, dataset)
    print('/////////////////////////////////')

    # resume from checkpoint when one is provided and exists
    if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
        load_checkpoint(generator, opt.checkpoint)
        print('/////////////////////////////////')

    train_tryon(opt, loader, generator)
    save_checkpoint(generator,
                    os.path.join(opt.checkpoint_dir, opt.name, 'try_on.pth'))
def main():
    """Train the stage selected by opt.stage ("GMM" or "TOM"), optionally
    resuming from a checkpoint and wrapping the model in DataParallel on
    multi-GPU hosts.

    Raises:
        NotImplementedError: if opt.stage is neither "GMM" nor "TOM".
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = get_dataset_class(opt.dataset)(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization (removed a dead `board = None` — the writer is always
    # created; the guard only ensures a configured log dir exists)
    if opt.tensorboard_dir and not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == "GMM":
        model = GMM(opt)
        model.opt = opt
        if not opt.checkpoint == "" and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, "gmm_final.pth"))
    elif opt.stage == "TOM":
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.opt = opt
        if not opt.checkpoint == "" and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, "tom_final.pth"))
    else:
        raise NotImplementedError("Model [%s] is not implemented" % opt.stage)

    # Bug fix: the closing message misspelled "named" as "nameed".
    print("Finished training %s, named: %s!" % (opt.stage, opt.name))
Example #15
0
def main():
    """Train the stage selected by opt.stage ('PGP' or 'GMM') and save the
    final checkpoint.

    Raises:
        NotImplementedError: if opt.stage is neither 'PGP' nor 'GMM'.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)
    print('//////////////////////')

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'PGP':
        model = PGP(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_pgp(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'pgp_final.pth'))
    elif opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Bug fix: the closing message misspelled "named" as "nameed".
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #16
0
def train_tom(opt, train_loader, model, board):
    """Train the TOM stage with L1 + VGG + mask losses, resuming the step
    counter from a checkpoint filename when one is given, and writing a
    plain-text train log alongside the checkpoints.

    Args:
        opt: options; uses checkpoint, checkpoint_dir, name, lr, keep_step,
            decay_step, display_count, save_count.
        train_loader: an iterable (e.g. torch DataLoader) of input dicts.
        model: UNet taking cat(agnostic, cloth) -> 4-channel output.
        board: tensorboard SummaryWriter.
    """
    # load model
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda step: 1.0 -
            max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    # train log: append when resuming, otherwise create with a header
    if not opt.checkpoint == '':
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'a')
    else:
        os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'w')
        train_log.write('='*30 + ' Training Option ' + '='*30 + '\n')
        train_log.write(str(opt) + '\n\n')
        train_log.write('='*30 + ' Network Architecture ' + '='*30 + '\n')
        print(str(model) + '\n', file=train_log)
        train_log.write('='*30 + ' Training Log ' + '='*30 + '\n')

    # resume the step counter from the checkpoint filename (step_XXXXXX.pth)
    checkpoint_step = 0
    if not opt.checkpoint == '':
        checkpoint_step += int(opt.checkpoint.split('/')[-1][5:11])

    # Bug fix: the original built a fresh iterator every step and called the
    # Python-2-only .next(), which re-spawned loader workers each iteration
    # and never traversed a full epoch.  Build the iterator once and restart
    # it when an epoch is exhausted.
    dl_iter = iter(train_loader)
    for step in range(checkpoint_step, opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        try:
            inputs = next(dl_iter)
        except StopIteration:
            dl_iter = iter(train_loader)
            inputs = next(dl_iter)

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']   # visualization only, stays on CPU
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        outputs = model(torch.cat([agnostic, c],1))
        p_rendered, m_composite = torch.split(outputs, 3,1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite+ p_rendered * (1 - m_composite)

        # masks rescaled from [0,1] to [-1,1] for display
        visuals = [ [im_h, shape, im_pose], 
                   [c, cm*2-1, m_composite*2-1], 
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('metric', loss.item(), step+1)
            board.add_scalar('L1', loss_l1.item(), step+1)
            board.add_scalar('VGG', loss_vgg.item(), step+1)
            board.add_scalar('MaskL1', loss_mask.item(), step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), 
                    loss_vgg.item(), loss_mask.item()), flush=True)
            train_log.write('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()) + '\n')

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))

    # Bug fix: the log file handle was never closed.
    train_log.close()
def train_gmm(opt, train_loader, model, board):
    """Train the two-stage (perspective + TPS) GMM warping model.

    The model first warps with an affine/perspective transform (Ipersp),
    then refines with a TPS grid; the loss averages the L1 of both stages.
    A grid-consistency term (LgicTPS) is computed but deliberately left out
    of the loss (see the commented weight below).

    Args:
        opt: options; uses lr, keep_step, decay_step, display_count,
            save_count, checkpoint_dir, name.
        train_loader: loader exposing next_batch() -> dict of tensors.
        model: returns (gridtps, thetatps, Ipersp, gridaffine, thetaaffine).
        board: tensorboard SummaryWriter.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    gicloss = GicLoss(opt)

    # optimizer: constant lr for keep_step steps, then linear decay to 0
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        gridtps, thetatps, Ipersp, gridaffine, thetaaffine = model(agnostic, c)
        # final warp: TPS refinement applied on top of the perspective warp
        warped_cloth = F.grid_sample(Ipersp, gridtps, padding_mode='border')

        warped_maskaffine = F.grid_sample(cm, gridaffine, padding_mode='zeros')
        warped_mask = F.grid_sample(warped_maskaffine,
                                    gridtps,
                                    padding_mode='zeros')

        warped_gridaffine = F.grid_sample(im_g,
                                          gridaffine,
                                          padding_mode='zeros')
        warped_grid = F.grid_sample(warped_gridaffine,
                                    gridtps,
                                    padding_mode='zeros')

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth * 0.7 + im * 0.3), im],
                   [Ipersp, (Ipersp * 0.7 + im * 0.3), warped_gridaffine]]

        Lwarp = criterionL1(warped_cloth, im_c)
        Lpersp = criterionL1(Ipersp, im_c)

        # normalize by N * Hout * Wout (grid shape: N x Hout x Wout x 2)
        LgicTPS = gicloss(gridtps)
        LgicTPS = LgicTPS / (gridtps.shape[0] * gridtps.shape[1] *
                             gridtps.shape[2])

        loss = 0.5 * Lpersp + 0.5 * Lwarp  # + 40 * LgicTPS
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bug fix: the scheduler was never stepped, so the configured lr
        # decay never took effect.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('0.5*Lpersp', (0.5 * Lpersp).item(), step + 1)
            board.add_scalar('0.5*Lwarp', (0.5 * Lwarp).item(), step + 1)
            t = time.time() - iter_start_time
            print(
                'step: %8d, time: %.3f, loss: %4f, 0.5*Lpersp: %.6f, 0.5*Lwarp: %.6f'
                % (step + 1, t, loss.item(), (0.5 * Lpersp).item(),
                   (0.5 * Lwarp).item()),
                flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
def _load_checkpoint_if_available(model, checkpoint_path):
    """Load weights into *model* when *checkpoint_path* names an existing file.

    A blank path or a missing file is silently ignored so that training can
    start from randomly initialised weights.
    """
    if checkpoint_path != '' and os.path.exists(checkpoint_path):
        load_checkpoint(model, checkpoint_path)


def main():
    """Entry point: build dataset/loader/board, then train the selected stage.

    ``opt.stage`` selects one of: GMM, VariGMM, semanticGMM,
    no_background_GMM, TOM, DeepTom; any other value raises
    ``NotImplementedError``.  A final checkpoint is written after training
    (periodic checkpoints are handled inside the per-stage loops).
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        if opt.model == 'RefinedGMM':
            model = RefinedGMM(opt)
        elif opt.model == 'OneRefinedGMM':
            model = OneRefinedGMM(opt)
        else:
            # bugfix: the bare TypeError() carried no diagnostic at all
            raise TypeError('Unknown GMM model: %s' % opt.model)
        _load_checkpoint_if_available(model, opt.checkpoint)
        train_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'VariGMM':
        model = VariGMM(opt)
        _load_checkpoint_if_available(model, opt.checkpoint)
        train_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'semanticGMM':
        model = RefinedGMM(opt)
        _load_checkpoint_if_available(model, opt.checkpoint)
        train_semantic_parsing_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'no_background_GMM':
        model = RefinedGMM(opt)
        _load_checkpoint_if_available(model, opt.checkpoint)
        train_no_background_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'no_background_refined_gmm_final.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        _load_checkpoint_if_available(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'DeepTom':
        # NOTE: this stage never loaded opt.checkpoint in the original code;
        # that behaviour is preserved.  (Unused locals norm_layer/use_dropout/
        # with_tanh were removed; their values are passed literally below.)
        model = Define_G(25,
                         4,
                         64,
                         'treeresnet',
                         'instance',
                         True,
                         'normal',
                         0.02,
                         opt.gpu_ids,
                         with_tanh=False)
        train_deep_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # bugfix: 'nameed' -> 'named' (matches the start-of-training message)
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def train_refined_gmm(opt, train_loader, model, board):
    """Train the refined GMM: warp cloth with a predicted TPS grid, then
    refine the result by compositing a rendered cloth with the warp.

    The total loss is a sum of warped-cloth L1, composition L1 and mask L1
    terms, plus optional terms each gated by an ``opt.add_*`` flag:
    keypoint L1, VGG perceptual, Gram style, rendered-cloth L1, a
    mask-constraint term and a warped-mask L1.

    Progress is logged to *board* every ``opt.display_count`` steps and a
    checkpoint is written every ``opt.save_count`` steps.
    """
    model.cuda()
    model.train()

    # weight of the optional keypoint loss in the total objective
    loss_weight = opt.loss_weight

    # criterion
    warped_criterionL1 = nn.L1Loss()
    result_criterionL1 = nn.L1Loss()
    point_criterionL1 = nn.L1Loss()
    criterionMask = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionGram = GramLoss()
    rendered_criterionL1 = nn.L1Loss()

    center_mask_critetionL1 = nn.L1Loss()

    warped_mask_criterionL1 = nn.L1Loss()

    # optimizer: constant lr for keep_step iterations, then linear decay
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        densepose_shape = inputs['densepose_shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        parse_cloth_mask = inputs['parse_cloth_mask'].cuda()
        # bugfix: target_shape was the only batch tensor left on the CPU but
        # is compared against the CUDA warped_mask below, which would raise a
        # device-mismatch error whenever opt.add_warped_mask_loss is set.
        target_shape = inputs['target_shape'].cuda()

        c_point_plane = inputs['cloth_points'].cuda()
        p_point_plane = inputs['person_points'].cuda()

        grid, theta, warped_cloth, outputs = model(agnostic, c)
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        compute_c_point_plane = compute_grid_point(p_point_plane, grid)

        warped_mask_loss = 0
        if opt.add_warped_mask_loss:
            warped_mask_loss += warped_mask_criterionL1(
                warped_mask, target_shape)

        # split network output into rendered cloth (3ch) and composition mask
        c_rendered, m_composite = torch.split(outputs, 3, 1)
        # bugfix: F.tanh / F.sigmoid are deprecated (removed in recent
        # PyTorch); use the torch.* equivalents, which compute the same thing.
        c_rendered = torch.tanh(c_rendered)
        m_composite = torch.sigmoid(m_composite)
        c_result = warped_cloth * m_composite + c_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im],
                   [m_composite, (c_result + im) * 0.5, c_result]]

        loss_warped_cloth = warped_criterionL1(warped_cloth, im_c)
        loss_point = 0
        if opt.add_point_loss:
            loss_point = point_criterionL1(compute_c_point_plane,
                                           c_point_plane)
        loss_c_result = result_criterionL1(c_result, im_c)
        loss_mask = criterionMask(m_composite, warped_mask)
        loss_vgg = 0
        if opt.add_vgg_loss:
            loss_vgg = criterionVGG(c_result, im_c)
        loss_gram = 0
        if opt.add_gram_loss:
            loss_gram += criterionGram(c_result, im_c)

        loss_render = 0
        if opt.add_render_loss:
            loss_render += rendered_criterionL1(c_rendered, im_c)

        loss_mask_constrain = 0
        if opt.add_mask_constrain:
            # penalise disagreement between the predicted composition mask
            # and the warped mask, restricted to the parse-cloth region
            center_mask = m_composite * parse_cloth_mask
            ground_mask = torch.ones_like(parse_cloth_mask, dtype=torch.float)
            ground_mask = ground_mask * warped_mask * parse_cloth_mask
            loss_mask_constrain = center_mask_critetionL1(
                center_mask, ground_mask)
            loss_mask_constrain = loss_mask_constrain * opt.mask_constrain_weight

        loss = loss_warped_cloth + loss_weight * loss_point + loss_c_result + loss_mask + loss_vgg + loss_render + loss_mask_constrain + warped_mask_loss + loss_gram

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # bugfix: the LambdaLR scheduler was created but never advanced, so
        # the configured linear lr decay after keep_step never took effect.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
# Example #20
# 0
def train_gmm(opt, train_loader, model, board):
    """Train the geometric matching module (GMM).

    The model predicts a TPS grid from the person representation
    (``agnostic``) and the cloth mask; the loss is an L1 between the warped
    cloth and the ground-truth parse cloth plus a grid-regularisation (GIC)
    term weighted by 40.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    gicloss = GicLoss(opt)
    # optimizer: constant lr for keep_step iterations, then linear decay
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        grid, theta = model(agnostic,
                            cm)  # can be added c too for new training
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        Lwarp = criterionL1(warped_cloth, im_c)  # loss for warped cloth

        # grid regularization loss, normalised per grid point
        Lgic = gicloss(grid)
        # 200x200 = 40.000 * 0.001
        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2])

        loss = Lwarp + 40 * Lgic  # total GMM loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # bugfix: the LambdaLR scheduler was created but never advanced, so
        # the configured linear lr decay after keep_step never took effect.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            board.add_scalar('40*Lgic', (40 * Lgic).item(), step + 1)
            board.add_scalar('Lwarp', Lwarp.item(), step + 1)
            t = time.time() - iter_start_time
            print(
                'step: %8d, time: %.3f, loss: %4f, (40*Lgic): %.8f, Lwarp: %.6f'
                % (step + 1, t, loss.item(), (40 * Lgic).item(), Lwarp.item()),
                flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
def train_gmm(opt, train_loader, model, board):
    """Train a mask-based GMM: warp the cloth *mask* toward the person shape.

    The model takes (cloth_mask, shape) and predicts a TPS grid; supervision
    is an L1 between the warped cloth mask and a ground-truth mask
    (``cloth_mask_gt``) plus a grid regularisation (GIC) term, combined with
    fixed weights 0.0244 / 0.9756.

    NOTE(review): multiple ``train_gmm`` definitions exist in this file;
    later definitions shadow earlier ones.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    gicloss = GicLoss(opt)
    # optimizer with a linear-decay schedule after opt.keep_step iterations.
    # NOTE(review): scheduler.step() is never called in this loop, so the
    # configured decay never actually takes effect.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        cloth_mask = inputs['cloth_mask'].cuda()
        cloth_mask_gt = inputs['cloth_mask_gt'].cuda()
        shape = inputs['shape'].cuda()
        c_name = inputs['c_name']
        im_names = inputs['im_names']
        im_namet = inputs['im_namet']
        im_g = inputs['grid_image'].cuda()
        parse_im = inputs['image'].cuda()
        cloth = inputs['c'].cuda()

        grid, theta = model(cloth_mask, shape)

        # warp mask, grid image and cloth with the predicted TPS grid
        warped_mask = F.grid_sample(cloth_mask, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        warped_cloth = F.grid_sample(cloth, grid, padding_mode='zeros')

        # masks are rescaled with *2-1 for display alongside [-1, 1] images
        visuals = [[im_g, cloth_mask * 2 - 1, shape],
                   [cloth, warped_cloth, parse_im],
                   [warped_grid, warped_mask * 2 - 1, cloth_mask_gt * 2 - 1]]

        # L1 between warped mask and ground-truth mask
        Lwarp = criterionL1(warped_mask, cloth_mask_gt)

        # grid regularisation, normalised per grid point
        Lgic = gicloss(grid)
        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2]
                       )  #200x200 = 40.000 * 0.001
        # fixed convex combination of warp and regularisation terms
        loss = 0.0244 * Lwarp + 0.9756 * Lgic

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('0.9756*Lgic', (0.9756 * Lgic).item(), step + 1)
            board.add_scalar('0.0244*Lwarp', (0.0244 * Lwarp).item(), step + 1)
            t = time.time() - iter_start_time
            print(
                'step: %8d, time: %.3f, loss: %4f, (0.9756*Lgic): %.6f, 0.0244*Lwarp: %.4f'
                % (step + 1, t, loss.item(), (0.9756 * Lgic).item(),
                   (0.0244 * Lwarp).item()),
                flush=True)
        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
# Example #22
# 0
def train_residual_old(opt,
                       train_loader,
                       model,
                       model_module,
                       gmm_model,
                       generator,
                       image_embedder,
                       board,
                       discriminator=None,
                       discriminator_module=None):
    """Train a residual refinement model on top of frozen GMM + generator.

    Two cloth items are warped and "tried on" per person with the frozen
    ``gmm_model`` and ``generator``; ``model`` then refines both transfers
    given a ground-truth residual computed from the real image.  The loss
    combines identity (embedding MSE), visual regularisation (L1 + VGG
    style/perceptual vs. the un-refined transfer), match-to-ground-truth L1
    and a consistency term; an optional LSGAN adversarial term is added when
    ``opt.use_gan``.

    ``model_module`` / ``discriminator_module`` are the unwrapped modules
    used for checkpointing (relevant under DistributedDataParallel).
    """

    # loss weights: inner dict for the visual-regularisation composite,
    # outer dict for the total objective
    lambdas_vis_reg = {'l1': 1.0, 'prc': 0.05, 'style': 100.0}
    lambdas = {
        'adv': 0.1,
        'identity': 1000,
        'match_gt': 50,
        'vis_reg': .1,
        'consist': 50
    }

    model.train()
    # frozen pretrained stages: evaluated but never updated here
    gmm_model.eval()
    image_embedder.eval()
    generator.eval()

    # criterion
    l1_criterion = nn.L1Loss()
    mse_criterion = nn.MSELoss()
    vgg_extractor = VGGExtractor().cuda().eval()
    adv_criterion = utils.AdversarialLoss('lsgan').cuda()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999),
                                 weight_decay=1e-4)
    if opt.use_gan:
        D_optim = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999),
                                   weight_decay=1e-4)

    # progress bar only on the single-gpu / designated process
    pbar = range(opt.keep_step + opt.decay_step)
    if single_gpu_flag(opt):
        pbar = tqdm(pbar)

    for step in pbar:
        iter_start_time = time.time()
        # loader yields two batches: same person, two different cloths
        inputs, inputs_2 = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()

        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        c_2 = inputs_2['cloth'].cuda()
        cm_2 = inputs_2['cloth_mask'].cuda()

        # frozen GMM + generator produce the two try-on transfers (no grads)
        with torch.no_grad():
            grid, theta = gmm_model(agnostic, c)
            c = F.grid_sample(c, grid, padding_mode='border')
            cm = F.grid_sample(cm, grid, padding_mode='zeros')

            grid_2, theta_2 = gmm_model(agnostic, c_2)
            c_2 = F.grid_sample(c_2, grid_2, padding_mode='border')
            cm_2 = F.grid_sample(cm_2, grid_2, padding_mode='zeros')

            outputs = generator(torch.cat([agnostic, c], 1))
            p_rendered, m_composite = torch.split(outputs, 3, 1)
            p_rendered = F.tanh(p_rendered)
            m_composite = F.sigmoid(m_composite)
            transfer_1 = c * m_composite + p_rendered * (1 - m_composite)

            outputs_2 = generator(torch.cat([agnostic, c_2], 1))
            p_rendered_2, m_composite_2 = torch.split(outputs_2, 3, 1)
            p_rendered_2 = F.tanh(p_rendered_2)
            m_composite_2 = F.sigmoid(m_composite_2)
            transfer_2 = c_2 * m_composite_2 + p_rendered_2 * (1 -
                                                               m_composite_2)

        # per-pixel channel-mean gap between real image and transfer_1,
        # used as the residual hint fed to the model for both transfers
        gt_residual = (torch.mean(im, dim=1) -
                       torch.mean(transfer_1, dim=1)).unsqueeze(1)

        # detach: gradients must not flow back into the frozen stages
        output_1 = model(transfer_1.detach(), gt_residual.detach())
        output_2 = model(transfer_2.detach(), gt_residual.detach())

        embedding_1 = image_embedder(output_1)
        embedding_2 = image_embedder(output_2)

        embedding_1_t = image_embedder(transfer_1)
        embedding_2_t = image_embedder(transfer_2)

        if opt.use_gan:
            # train discriminator
            real_L_logit, real_L_cam_logit, real_G_logit, real_G_cam_logit = discriminator(
                im)
            fake_L_logit_1, fake_L_cam_logit_1, fake_G_logit_1, fake_G_cam_logit_1 = discriminator(
                output_1.detach())
            fake_L_logit_2, fake_L_cam_logit_2, fake_G_logit_2, fake_G_cam_logit_2 = discriminator(
                output_2.detach())

            D_true_loss = adv_criterion(real_L_logit, True) + \
                          adv_criterion(real_G_logit, True) + \
                          adv_criterion(real_L_cam_logit, True) + \
                          adv_criterion(real_G_cam_logit, True)
            D_fake_loss = adv_criterion(torch.cat([fake_L_cam_logit_1, fake_L_cam_logit_2], dim=0), False) + \
                          adv_criterion(torch.cat([fake_G_cam_logit_1, fake_G_cam_logit_2], dim=0), False) + \
                          adv_criterion(torch.cat([fake_L_logit_1, fake_L_logit_2], dim=0), False) + \
                          adv_criterion(torch.cat([fake_G_logit_1, fake_G_logit_2], dim=0), False)

            D_loss = D_true_loss + D_fake_loss
            D_optim.zero_grad()
            D_loss.backward()
            D_optim.step()

            # train generator (re-run discriminator without detach)
            fake_L_logit_1, fake_L_cam_logit_1, fake_G_logit_1, fake_G_cam_logit_1 = discriminator(
                output_1)
            fake_L_logit_2, fake_L_cam_logit_2, fake_G_logit_2, fake_G_cam_logit_2 = discriminator(
                output_2)

            G_adv_loss = adv_criterion(torch.cat([fake_L_logit_1, fake_L_logit_2], dim=0), True) + \
                         adv_criterion(torch.cat([fake_G_logit_1, fake_G_logit_2], dim=0), True) + \
                         adv_criterion(torch.cat([fake_L_cam_logit_1, fake_L_cam_logit_2], dim=0), True) + \
                         adv_criterion(torch.cat([fake_G_cam_logit_1, fake_G_cam_logit_2], dim=0), True)

        # identity loss: refined output should keep the transfer's embedding
        identity_loss = mse_criterion(embedding_1,
                                      embedding_1_t) + mse_criterion(
                                          embedding_2, embedding_2_t)

        # vis reg loss
        output_1_feats = vgg_extractor(output_1)
        transfer_1_feats = vgg_extractor(transfer_1)
        output_2_feats = vgg_extractor(output_2)
        transfer_2_feats = vgg_extractor(transfer_2)
        # gt_feats = vgg_extractor(data['image'].cuda())

        style_reg = utils.compute_style_loss(
            output_1_feats,
            transfer_1_feats, l1_criterion) + utils.compute_style_loss(
                output_2_feats, transfer_2_feats, l1_criterion)
        perceptual_reg = utils.compute_perceptual_loss(
            output_1_feats, transfer_1_feats,
            l1_criterion) + utils.compute_perceptual_loss(
                output_2_feats, transfer_2_feats, l1_criterion)
        l1_reg = l1_criterion(output_1, transfer_1) + l1_criterion(
            output_2, transfer_2)

        vis_reg_loss = l1_reg * lambdas_vis_reg[
            "l1"] + style_reg * lambdas_vis_reg[
                "style"] + perceptual_reg * lambdas_vis_reg["prc"]

        # match gt loss: only output_1 is compared to the real image (the
        # real image wears cloth 1, not cloth 2)
        match_gt_loss = l1_criterion(
            output_1, im
        )  #* lambdas_vis_reg["l1"] + utils.compute_style_loss(output_1_feats, gt_feats, l1_criterion) * lambdas_vis_reg["style"] + utils.compute_perceptual_loss(output_1_feats, gt_feats, l1_criterion) * lambdas_vis_reg["prc"]

        # consistency loss: both transfers should receive the same residual
        consistency_loss = l1_criterion(transfer_1 - output_1,
                                        transfer_2 - output_2)

        visuals = [[im_h, shape, im],
                   [
                       c, c_2,
                       torch.cat([gt_residual, gt_residual, gt_residual],
                                 dim=1)
                   ], [transfer_1, output_1, (output_1 - transfer_1) / 2],
                   [transfer_2, output_2, (output_2 - transfer_2) / 2]]

        total_loss = lambdas['identity'] * identity_loss + \
                     lambdas['match_gt'] * match_gt_loss + \
                     lambdas['vis_reg'] * vis_reg_loss + \
                     lambdas['consist'] * consistency_loss

        if opt.use_gan:
            total_loss += lambdas['adv'] * G_adv_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if single_gpu_flag(opt):
            if (step + 1) % opt.display_count == 0:
                board_add_images(board, str(step + 1), visuals, step + 1)
            board.add_scalar('loss/total', total_loss.item(), step + 1)
            board.add_scalar('loss/identity', identity_loss.item(), step + 1)
            board.add_scalar('loss/vis_reg', vis_reg_loss.item(), step + 1)
            board.add_scalar('loss/match_gt', match_gt_loss.item(), step + 1)
            board.add_scalar('loss/consist', consistency_loss.item(), step + 1)
            if opt.use_gan:
                board.add_scalar('loss/Dadv', D_loss.item(), step + 1)
                board.add_scalar('loss/Gadv', G_adv_loss.item(), step + 1)

            pbar.set_description(
                'step: %8d, loss: %.4f, identity: %.4f, vis_reg: %.4f, match_gt: %.4f, consist: %.4f'
                % (step + 1, total_loss.item(), identity_loss.item(),
                   vis_reg_loss.item(), match_gt_loss.item(),
                   consistency_loss.item()))

        if (step + 1) % opt.save_count == 0 and single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
            if opt.use_gan:
                save_checkpoint(
                    discriminator_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'step_disc_%06d.pth' % (step + 1)))
# Example #23
# 0
def train_hpm(opt, train_loader, model, d_g, board):
    """Train the human parsing module with a GAN objective.

    ``model`` predicts a 16-channel segmentation from (agnostic, cloth);
    ``d_g`` is a discriminator on segmentation encodings.  The generator
    loss combines a BCE adversarial term with a cross-entropy segmentation
    term weighted by ``opt.alpha``.
    """
    model.cuda()
    model.train()
    d_g.cuda()
    d_g.train()

    # reusable label buffer for the discriminator targets.
    # NOTE(review): torch.autograd.Variable is deprecated; a plain tensor
    # would behave the same here.
    dis_label = Variable(torch.FloatTensor(opt.batch_size)).cuda()
    
    # criterion
    criterionMCE = nn.CrossEntropyLoss()#nn.BCEWithLogitsLoss()
    #criterionMSE = nn.MSELoss()
    criterionGAN = nn.BCELoss()
    
    # optimizer (one per network, each with a linear-decay schedule)
    optimizerG = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda = lambda step: 1.0 -
            max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    optimizerD = torch.optim.Adam(d_g.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda = lambda step: 1.0 -
            max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    
    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        #prep
        inputs = train_loader.next_batch()
            
        im = inputs['image'].cuda()#sz=b*3*256*192
        seg_gt = inputs['seg'].cuda()
        seg_enc = inputs['seg_enc'].cuda()

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        batch_size = im.size(0)

        optimizerD.zero_grad()
        #D_real: "real" labels smoothed into (0.9, 1.0]
        dis_label.data.fill_(1.0-random.random()*0.1)
        dis_g_output = d_g(seg_enc)
        
        errDg_real = criterionGAN(dis_g_output, dis_label)
        errDg_real.backward()

        #generate image
        segmentation = model(torch.cat([agnostic, c],1))

        #D_fake: "fake" labels smoothed into [0.0, 0.1)
        dis_label.data.fill_(0.0+random.random()*0.1)

        # detach so the discriminator step doesn't update the generator
        dis_g_output = d_g(segmentation.detach())
        errDg_fake = criterionGAN(dis_g_output, dis_label)
        errDg_fake.backward()
        optimizerD.step()


        #model_train: generator tries to make d_g output "real" labels
        optimizerG.zero_grad()
        dis_label.data.fill_(1.0-random.random()*0.1)
        dis_g_output = d_g(segmentation)
        errG_fake = criterionGAN(dis_g_output, dis_label)
        # cross-entropy over 16 segmentation classes, spatially flattened
        loss_mce = criterionMCE(segmentation.view(batch_size, 16, -1), seg_gt.view(batch_size, -1))

        loss = errG_fake + loss_mce * opt.alpha
        loss.backward()
        optimizerG.step()
                    
        if (step+1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            
            loss_dict = {"GAN":errG_fake.item(), "TOT":loss.item(), "MCE":loss_mce.item(), 
                         "DG":((errDg_fake+errDg_real)/2).item()}
            print('step: %d|time: %.3f'%(step+1, t), end="")
            
            for k, v in loss_dict.items():
                print('|%s: %.3f'%(k, v), end="")
                board.add_scalar(k, v, step+1)
            print()
            
        if (step+1) % opt.save_count == 0:
            """
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name, 'step_%06d.pth' % (step+1)))
            save_checkpoint(d_g, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name, 'step_%06d_dg.pth' % (step+1)))
            """
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name, 'model.pth'))
            save_checkpoint(d_g, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name, 'dg.pth'))
# Example #24
# 0
    # NOTE(review): this is the body of a main()-style driver whose header is
    # not visible in this chunk.
    if opt.shuffle :
        train_sampler = None
    else:
        # NOTE(review): RandomSampler draws a random order, so data is
        # shuffled even when opt.shuffle is False — confirm this is intended.
        train_sampler = sampler.RandomSampler(train_dataset)
    train_loader = DataLoader(
                train_dataset, batch_size=opt.batch_size, shuffle=opt.shuffle,
                num_workers=opt.workers, pin_memory=True, sampler=train_sampler)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir = os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        # resume from a checkpoint only when the path is set and exists
        if not opt.checkpoint =='' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint =='' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)


    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
# Example #25
# 0
def train_gmm(opt, train_loader, model, board):
    """Train the GMM adversarially with a freshly created discriminator.

    The actual parameter updates are delegated to the external
    ``discriminator_train_step`` / ``generator_train_step`` helpers; this
    loop prepares the batch, builds the visuals and logs the returned losses
    (L1, PSC, GAN and total).

    NOTE(review): multiple ``train_gmm`` definitions exist in this file;
    later definitions shadow earlier ones.
    """
    model.cuda()
    discriminator = Discriminator()
    discriminator.cuda()

    model.train()
    discriminator.train()

    # criterion
    criterion = nn.BCELoss()
    criterionL1 = nn.L1Loss()
    criterionPSC = PSCLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    # NOTE(review): scheduler.step() is never called in this loop, so the
    # configured lr decay never takes effect.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    #count = 0

    #base_name = os.path.basename(opt.checkpoint)
    # save_dir = os.path.join(opt.result_dir, opt.datamode)
    # if not os.path.exists(save_dir):
    #     os.makedirs(save_dir)
    # warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
    # if not os.path.exists(warp_cloth_dir):
    #    os.makedirs(warp_cloth_dir)

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        c_names = inputs['c_name']
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        blank = inputs['blank'].cuda()

        # warp cloth, mask and grid image with the predicted TPS grid
        grid, theta = model(agnostic, c)
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        #if (count < 14222):
        #    save_images(warped_cloth, c_names, warp_cloth_dir)
        #    print(warped_cloth.size()[0])
        #    count+=warped_cloth.size()[0]

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        # one discriminator update followed by one generator update
        discriminator_train_step(opt.batch_size, discriminator, model,
                                 optimizer_d, criterion, im_c, agnostic, c)
        res, loss, lossL1, lossPSC, lossGAN = generator_train_step(
            opt.batch_size, discriminator, model, optimizer, criterion,
            criterionL1, criterionPSC, blank, im_c, agnostic, c)

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('lossL1', lossL1.item(), step + 1)
            board.add_scalar('lossPSC', lossPSC.item(), step + 1)
            board.add_scalar('lossGAN', lossGAN.item(), step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, lossL1: %4f' %
                  (step + 1, t, lossL1.item()),
                  flush=True)
            print('step: %8d, time: %.3f, lossPSC: %4f' %
                  (step + 1, t, lossPSC.item()),
                  flush=True)
            print('step: %8d, time: %.3f, lossGAN: %4f' %
                  (step + 1, t, lossGAN.item()),
                  flush=True)
            print('step: %8d, time: %.3f, loss: %4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
# Example #26
# 0
def main():
    """Entry point: parse options, set up (optionally distributed) training,
    and dispatch to the stage-specific training routine.

    Supported ``opt.stage`` values: 'GMM', 'TOM', 'TOM+WARP', 'identity',
    'residual', 'residual_old'; anything else raises NotImplementedError.
    The final checkpoint is saved after training (rank 0 only when
    distributed).
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # torch.distributed launchers export WORLD_SIZE; default to a single GPU.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization: only the single-GPU / rank-0 process writes tensorboard
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':
        # TOM needs a frozen, pre-trained GMM to warp the cloth.
        # NOTE(review): hard-coded checkpoint path — TODO make configurable.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        # Keep a handle on the underlying module so checkpoints store the
        # bare model rather than the DDP wrapper.
        model_module = model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'TOM+WARP':
        # Joint training: the GMM is trained from scratch alongside TOM
        # (no pre-trained warp checkpoint is loaded here).
        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            gmm_model = torch.nn.parallel.DistributedDataParallel(
                gmm_model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "identity":
        model = Embedder()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        # NOTE(review): the identity stage saves as 'gmm_final.pth' — looks
        # like a copy-paste of the GMM branch; confirm the intended filename.
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'residual':
        # Residual refinement on top of frozen GMM + TOM generators, guided
        # by a frozen identity embedder.
        # NOTE(review): hard-coded checkpoint paths — TODO make configurable.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new/step_038000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        # Only the "b" branch of the embedder is used from here on.
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

            acc_discriminator = AccDiscriminator()
            acc_discriminator.apply(utils.weights_init('gaussian'))
            acc_discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            if opt.use_gan:
                # Discriminator checkpoints sit alongside the generator's,
                # with a 'step_disc_' prefix instead of 'step_'.
                # NOTE(review): acc_discriminator is never restored here —
                # confirm whether it should also resume from a checkpoint.
                load_checkpoint(discriminator,
                                opt.checkpoint.replace("step_", "step_disc_"))

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

                acc_discriminator = torch.nn.parallel.DistributedDataParallel(
                    acc_discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)

            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    elif opt.stage == "residual_old":
        # Legacy residual pipeline: same frozen backbones, single
        # discriminator, different TOM checkpoint.
        # NOTE(review): hard-coded checkpoint paths — TODO make configurable.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new_2/step_070000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt,
                               train_loader,
                               model,
                               model_module,
                               gmm_model,
                               generator_model,
                               embedder_model,
                               board,
                               discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Fixed typo in the original message ("nameed" -> "named").
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #27
0
def main():
    """Entry point: configure options for the selected mode (train/test/val),
    build the dataset and loader, then dispatch to the stage routine
    ('HPM', 'GMM', or 'TOM'); any other stage raises NotImplementedError.
    """
    opt = get_opt()

    # Mode-specific overrides: 'test' switches to the test split and pair
    # list; neither 'test' nor 'val' shuffles the data.
    if opt.mode == 'test':
        opt.datamode = "test"
        opt.data_list = "test_pairs.txt"
        opt.shuffle = False
    elif opt.mode == 'val':
        opt.shuffle = False
    elif opt.mode != 'train':
        # Unrecognized mode: report it and continue with defaults.
        print(opt.mode)

    print(opt)

    # Evaluation processes one sample at a time.
    if opt.mode != 'train':
        opt.batch_size = 1

    # Evaluation requires a checkpoint to load weights from.
    if opt.mode != 'train' and not opt.checkpoint:
        print("You need to have a checkpoint for: " + opt.mode)
        return None

    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'HPM':
        # Human parsing module: segmentation U-Net plus its discriminator.
        model = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d,
                              clsf=True)
        d_g = Discriminator_G(opt, 16)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            # The paired discriminator checkpoint replaces the trailing
            # 'XXXXXX.pth' of the model checkpoint name with 'dg.pth'.
            load_checkpoint(d_g, opt.checkpoint[:-9] + "dg.pth")

        if opt.mode == "train":
            train_hpm(opt, train_loader, model, d_g, board)
        else:
            test_hpm(opt, train_loader, model)

        save_checkpoints(
            model, d_g,
            os.path.join(opt.checkpoint_dir,
                         opt.stage + '_' + opt.name + "_final", '%s.pth'))

    elif opt.stage == 'GMM':
        model = GMM(opt, 1)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if opt.mode == "train":
            train_gmm(opt, train_loader, model, board)
        else:
            test_gmm(opt, train_loader, model)

        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))

    elif opt.stage == 'TOM':
        model = UnetGenerator(31, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if opt.mode == "train":
            train_tom(opt, train_loader, model, board)
        else:
            test_tom(opt, train_loader, model)

        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Fixed typo in the original message ("nameed" -> "named").
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #28
0
def train_gmm(opt, train_loader, model, board):
    """Train the GMM (geometric matching) stage with an L1 warp loss.

    Each step warps the in-shop cloth onto the person via the predicted TPS
    grid and minimizes L1 between the warped cloth and the cloth region of
    the person image. Visuals/scalars go to tensorboard every
    ``opt.display_count`` steps; checkpoints every ``opt.save_count`` steps.

    Args:
        opt: parsed options (lr, keep_step, decay_step, checkpoint,
            checkpoint_dir, name, display_count, save_count, ...).
        train_loader: iterable data loader yielding batch dicts of tensors.
        model: GMM network; ``model(agnostic, c)`` returns ``(grid, theta)``.
        board: tensorboard SummaryWriter.
    """
    # load model
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()

    # optimizer: constant LR for keep_step iterations, then linear decay
    # to ~0 over decay_step iterations.
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    # train log: append when resuming from a checkpoint, otherwise start a
    # fresh log with the options and architecture dumped at the top.
    if not opt.checkpoint == '':
        train_log = open(
            os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'a')
    else:
        os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
        train_log = open(
            os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'w')
        train_log.write('=' * 30 + ' Training Option ' + '=' * 30 + '\n')
        train_log.write(str(opt) + '\n\n')
        train_log.write('=' * 30 + ' Network Architecture ' + '=' * 30 + '\n')
        print(str(model) + '\n', file=train_log)
        train_log.write('=' * 30 + ' Training Log ' + '=' * 30 + '\n')

    # Resume step count from the checkpoint filename ('step_%06d.pth').
    checkpoint_step = 0
    if not opt.checkpoint == '':
        checkpoint_step += int(opt.checkpoint.split('/')[-1][5:11])
    # NOTE(review): the LR scheduler is not fast-forwarded to
    # checkpoint_step on resume, so the schedule restarts — confirm intended.

    # Create the loader iterator ONCE. The original rebuilt it inside the
    # loop every step (re-initializing the loader each iteration and always
    # fetching its first batch) and used the Python-2-style .next() call.
    dl_iter = iter(train_loader)
    for step in range(checkpoint_step, opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        try:
            inputs = next(dl_iter)
        except StopIteration:
            # Epoch exhausted: restart from a fresh iterator.
            dl_iter = iter(train_loader)
            inputs = next(dl_iter)

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        grid, theta = model(agnostic, c)
        # 'border' keeps cloth texture at the edges; mask/grid pad with zeros.
        warped_cloth = F.grid_sample(c, grid, padding_mode='border',
                                     align_corners=True)
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros',
                                    align_corners=True)
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros',
                                    align_corners=True)

        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        loss = criterionL1(warped_cloth, im_c)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f'
                  % (step + 1, t, loss.item()), flush=True)
            train_log.write('step: %8d, time: %.3f, loss: %.4f'
                            % (step + 1, t, loss.item()) + '\n')

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))

    # The original leaked this file handle.
    train_log.close()
Example #29
0
def train(opt, train_loader, G, D, board):
    """Adversarial try-on training loop (relativistic average GAN style).

    Each of opt.n_iter steps runs opt.Diters discriminator updates (with a
    gradient penalty on interpolated samples) followed by opt.Giters
    generator updates. Generator updates use a paired pass (reconstruct the
    person in their own cloth: warp + perceptual + L1 losses) and an
    unpaired pass (another cloth: adversarial loss). After
    opt.human_parser_step steps, a frozen human parser segments the
    generated image and an extra cloth-consistency L1 loss is added.

    Args:
        opt: parsed options (cuda, lr, n_iter, Diters, Giters, penalty,
            human_parser_step, batch_size, display_count, save_count, ...).
        train_loader: loader exposing next_batch() returning a dict of
            tensors ('image', 'agnostic', 'cloth', 'another_cloth',
            'parse_cloth').
        G: generator; G(cloth, agnostic) returns (warped_cloth, fake_person).
        D: discriminator scoring person images.
        board: tensorboard SummaryWriter.
    """
    human_parser = HumanParser(opt)
    human_parser.eval()  # frozen: used only to supervise the generator
    G.train()
    D.train()

    # palette = get_palette()

    # Criterion
    criterionWarp = nn.L1Loss()
    criterionPerceptual = VGGLoss()
    criterionL1 = nn.L1Loss()
    BCE_stable = nn.BCEWithLogitsLoss()
    criterionCloth = nn.L1Loss()

    # Variables: reusable label / interpolation buffers, resized in-place to
    # the current batch size inside the loop.
    # NOTE(review): .data.resize_ and Variable are legacy PyTorch APIs —
    # confirm they behave as expected on the version in use. Also note the
    # tuple passed to u creates a 1-D tensor of 4 values, not a (N,1,1,1)
    # buffer; it is resized before use, so this appears harmless.
    ya = torch.FloatTensor(opt.batch_size)
    yb = torch.FloatTensor(opt.batch_size)
    u = torch.FloatTensor((opt.batch_size, 1, 1, 1))
    grad_outputs = torch.ones(opt.batch_size)

    # Everything cuda
    if opt.cuda:
        G.cuda()
        D.cuda()
        human_parser.cuda()
        criterionWarp = criterionWarp.cuda()
        criterionPerceptual = criterionPerceptual.cuda()
        criterionL1 = criterionL1.cuda()
        BCE_stable.cuda()
        criterionCloth = criterionCloth.cuda()

        ya = ya.cuda()
        yb = yb.cuda()
        u = u.cuda()
        grad_outputs = grad_outputs.cuda()

        # DataParallel
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)
        human_parser = nn.DataParallel(human_parser)

    # Optimizers
    optimizerD = torch.optim.Adam(D.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))
    optimizerG = torch.optim.Adam(G.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))

    # Fitting model
    step_start_time = time.time()
    for step in range(opt.n_iter):
        ########################
        # (1) Update D network #
        ########################

        for p in D.parameters():
            p.requires_grad = True

        for t in range(opt.Diters):
            D.zero_grad()

            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            cb = inputs['another_cloth'].cuda()
            del inputs

            current_batch_size = pa.size(0)
            ya_pred = D(pa)
            _, pb_fake = G(cb, ap)

            # Detach y_pred_fake from the neural network G and put it inside D
            yb_pred_fake = D(pb_fake.detach())
            # Labels: 1 for real (ya), 0 for fake (yb).
            ya.data.resize_(current_batch_size).fill_(1)
            yb.data.resize_(current_batch_size).fill_(0)

            # Relativistic average discriminator loss: compare each score
            # against the mean of the opposite class.
            errD = (BCE_stable(ya_pred - torch.mean(yb_pred_fake), ya) +
                    BCE_stable(yb_pred_fake - torch.mean(ya_pred), yb)) / 2.0
            errD.backward()

            # Gradient penalty
            with torch.no_grad():
                u.resize_(current_batch_size, 1, 1, 1).uniform_(0, 1)
                grad_outputs.data.resize_(current_batch_size)
            # Random interpolation between real and fake samples.
            x_both = pa * u + pb_fake * (1. - u)

            # We only want the gradients with respect to x_both
            x_both = Variable(x_both, requires_grad=True)
            grad = torch.autograd.grad(outputs=D(x_both),
                                       inputs=x_both,
                                       grad_outputs=grad_outputs,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            # We need to norm 3 times (over n_colors x image_size x image_size) to get only a vector of size
            # "batch_size"
            grad_penalty = opt.penalty * (
                (grad.norm(2, 1).norm(2, 1).norm(2, 1) - 1)**2).mean()
            grad_penalty.backward()

            optimizerD.step()

        ########################
        # (2) Update G network #
        ########################

        # Freeze D while updating G.
        for p in D.parameters():
            p.requires_grad = False

        for t in range(opt.Giters):
            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            ca = inputs['cloth'].cuda()
            cb = inputs['another_cloth'].cuda()
            parse_cloth = inputs['parse_cloth'].cuda()
            del inputs

            current_batch_size = pa.size(0)

            # paired data
            G.zero_grad()

            warped_cloth_a, pa_fake = G(ca, ap)
            if step >= opt.human_parser_step:  # enable the parser loss only once generated images look realistic enough
                parse_pa_fake = human_parser(pa_fake)  # (N,H,W)
                # Labels 5/6/7 select the upper-clothes region of the parse.
                parse_ca_fake = (parse_pa_fake == 5) + \
                                (parse_pa_fake == 6) + \
                                (parse_pa_fake == 7)  # [0,1] (N,H,W)
                parse_ca_fake = parse_ca_fake.unsqueeze(1).type_as(
                    pa_fake)  # (N,1,H,W)
                # Cloth region of the fake image on a white background.
                ca_fake = pa_fake * parse_ca_fake + (1 - parse_ca_fake
                                                     )  # [-1,1]
                with torch.no_grad():
                    parse_pa_fake_vis = visualize_seg(parse_pa_fake)
                l_cloth_p = criterionCloth(ca_fake, warped_cloth_a)
            else:
                # Before the parser kicks in: zero placeholders so logging
                # and the loss sum still work.
                with torch.no_grad():
                    ca_fake = torch.zeros_like(pa_fake)
                    parse_pa_fake_vis = torch.zeros_like(pa_fake)
                    l_cloth_p = torch.zeros(1).cuda()

            l_warp = 20 * criterionWarp(warped_cloth_a, parse_cloth)
            l_perceptual = criterionPerceptual(pa_fake, pa)
            l_L1 = criterionL1(pa_fake, pa)
            loss_p = l_warp + l_perceptual + l_L1 + l_cloth_p

            loss_p.backward()
            optimizerG.step()

            # unpaired data
            G.zero_grad()

            warped_cloth_b, pb_fake = G(cb, ap)
            if step >= opt.human_parser_step:  # enable the parser loss only once generated images look realistic enough
                parse_pb_fake = human_parser(pb_fake)
                parse_cb_fake = (parse_pb_fake == 5) + \
                                (parse_pb_fake == 6) + \
                                (parse_pb_fake == 7)  # [0,1] (N,H,W)
                parse_cb_fake = parse_cb_fake.unsqueeze(1).type_as(
                    pb_fake)  # (N,1,H,W)
                cb_fake = pb_fake * parse_cb_fake + (1 - parse_cb_fake
                                                     )  # [-1,1]
                with torch.no_grad():
                    parse_pb_fake_vis = visualize_seg(parse_pb_fake)
                l_cloth_up = criterionCloth(cb_fake, warped_cloth_b)
            else:
                with torch.no_grad():
                    cb_fake = torch.zeros_like(pb_fake)
                    parse_pb_fake_vis = torch.zeros_like(pb_fake)
                    l_cloth_up = torch.zeros(1).cuda()

            with torch.no_grad():
                ya.data.resize_(current_batch_size).fill_(1)
                yb.data.resize_(current_batch_size).fill_(0)
            ya_pred = D(pa)
            yb_pred_fake = D(pb_fake)

            # Non-saturating
            # Relativistic generator loss: labels are swapped vs. the D loss.
            l_adv = 0.1 * (
                BCE_stable(ya_pred - torch.mean(yb_pred_fake), yb) +
                BCE_stable(yb_pred_fake - torch.mean(ya_pred), ya)) / 2
            loss_up = l_adv + l_cloth_up
            loss_up.backward()
            optimizerG.step()

            # visuals = [
            #     [cb, warped_cloth_b, pb_fake],
            #     [ca, warped_cloth_a, pa_fake],
            #     [ap, parse_cloth, pa]
            # ]
            visuals = [[
                cb, warped_cloth_b, pb_fake, cb_fake, parse_pb_fake_vis
            ], [ca, warped_cloth_a, pa_fake, ca_fake, parse_pa_fake_vis],
                       [ap, parse_cloth, pa]]

            if (step + 1) % opt.display_count == 0:
                board_add_images(board, 'combine', visuals, step + 1)
                board.add_scalar('loss_p', loss_p.item(), step + 1)
                board.add_scalar('l_warp', l_warp.item(), step + 1)
                board.add_scalar('l_perceptual', l_perceptual.item(), step + 1)
                board.add_scalar('l_L1', l_L1.item(), step + 1)
                board.add_scalar('l_cloth_p', l_cloth_p.item(), step + 1)
                board.add_scalar('loss_up', loss_up.item(), step + 1)
                board.add_scalar('l_adv', l_adv.item(), step + 1)
                board.add_scalar('l_cloth_up', l_cloth_up.item(), step + 1)
                board.add_scalar('errD', errD.item(), step + 1)

                t = time.time() - step_start_time
                print(
                    'step: %8d, time: %.3f, loss_p: %4f, loss_up: %.4f, l_adv: %.4f, errD: %.4f'
                    % (step + 1, t, loss_p.item(), loss_up.item(),
                       l_adv.item(), errD.item()),
                    flush=True)
                step_start_time = time.time()

            if (step + 1) % opt.save_count == 0:
                save_checkpoint(
                    G,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'step_%06d.pth' % (step + 1)))
Example #30
0
def main():
    """Entry point for residual-refinement training: load frozen GMM, TOM
    generator, and identity embedder backbones, build the residual generator
    (plus an optional GAN discriminator), and run train_residual_old.
    Saves the final checkpoint on the rank-0 / single-GPU process.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # torch.distributed launchers export WORLD_SIZE; default to one GPU.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization: only the single-GPU / rank-0 process writes tensorboard
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Frozen backbones.
    # NOTE(review): hard-coded checkpoint paths — TODO make configurable.
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    # Only the "b" branch of the embedder is used from here on.
    embedder_model = embedder_model.embedder_b.cuda()

    # Trainable residual generator.
    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    # Keep handles on the underlying modules so checkpoints store the bare
    # models rather than the DDP wrappers.
    model_module = model
    if opt.use_gan:
        discriminator_module = discriminator
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

    if opt.use_gan:
        train_residual_old(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module)
        if single_gpu_flag(opt):
            save_checkpoint(
                {
                    "generator": model_module,
                    "discriminator": discriminator_module
                }, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        train_residual_old(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    # Fixed typo in the original message ("nameed" -> "named").
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))