Example #1
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt, 1)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'HPM':
        model = UnetGenerator(25,
                              16,
                              6,
                              ngf=64,
                              norm_layer=nn.InstanceNorm2d,
                              clsf=True)
        d_g = Discriminator_G(opt, 16)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            if not os.path.isdir(opt.checkpoint):
                raise NotImplementedError(
                    'checkpoint should be dir, not file: %s' % opt.checkpoint)
            load_checkpoints(model, d_g, os.path.join(opt.checkpoint,
                                                      "%s.pth"))
        train_hpm(opt, train_loader, model, d_g, board)
        save_checkpoints(
            model, d_g,
            os.path.join(opt.checkpoint_dir,
                         opt.stage + '_' + opt.name + "_final", '%s.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 3, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
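# The checkpoint helpers called above (load_checkpoint, save_checkpoint and
# their multi-model variants) are not shown in these examples. Below is a
# minimal sketch consistent with how they are called here; the '%s' slot
# names ('model', 'dg', 'dl') are assumptions — only 'dg.pth' is hinted at
# by Example #2.
import os
import torch


def save_checkpoint(model, save_path):
    # create the parent directory on first save
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    torch.save(model.state_dict(), save_path)


def load_checkpoint(model, checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path))
    model.cuda()


def save_checkpoints(*args):
    # the last argument is a path template with a '%s' slot, e.g. '.../%s.pth'
    *models, template = args
    for name, m in zip(('model', 'dg', 'dl'), models):
        save_checkpoint(m, template % name)


def load_checkpoints(model, d_g, template):
    # mirror of save_checkpoints for the HPM generator/discriminator pair
    load_checkpoint(model, template % 'model')
    load_checkpoint(d_g, template % 'dg')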
Example #2
def main():
    opt = get_opt()

    if opt.mode == 'test':
        opt.datamode = "test"
        opt.data_list = "test_pairs.txt"
        opt.shuffle = False
    elif opt.mode == 'val':
        opt.shuffle = False
    elif opt.mode != 'train':
        print("Unknown mode: %s" % opt.mode)

    print(opt)

    if opt.mode != 'train':
        opt.batch_size = 1


    if opt.mode != 'train' and not opt.checkpoint:
        print("You need to have a checkpoint for: "+opt.mode)
        return None

    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))
   
    # create dataset 
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))
    
    # create model & train & save the final checkpoint
    if opt.stage == 'HPM':
        model = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d, clsf=True)
        d_g = Discriminator_G(opt, 16)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            # assumes the checkpoint filename is 9 characters (e.g. 'model.pth'),
            # so slicing it off leaves the prefix for the matching 'dg.pth' weights
            load_checkpoint(d_g, opt.checkpoint[:-9] + "dg.pth")

        if opt.mode == "train":
            train_hpm(opt, train_loader, model, d_g, board)
        else:
            test_hpm(opt, train_loader, model)

        save_checkpoints(model, d_g, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name+"_final", '%s.pth'))

    elif opt.stage == 'GMM':
        #seg_unet = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d, clsf=True)
        model = GMM(opt, 1)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        
        if opt.mode == "train":
            train_gmm(opt, train_loader, model, board)
        else:
            test_gmm(opt, train_loader, model)
        
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    
    elif opt.stage == 'TOM':
        model = UnetGenerator(31, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if opt.mode == "train":
            train_tom(opt, train_loader, model, board)
        else:
            test_tom(opt, train_loader, model)
        
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
  
    print('Finished %s stage: %s, named: %s!' % (opt.mode, opt.stage, opt.name))
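# CPDataLoader is constructed from (opt, dataset) and exposes next_batch(),
# which the training loops below call once per step. A minimal sketch under
# those assumptions follows; 'workers' is a guessed opt field, batch_size
# and shuffle are read by the examples above.
import torch


class CPDataLoader(object):

    def __init__(self, opt, dataset):
        self.data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=getattr(opt, 'shuffle', True),
            num_workers=getattr(opt, 'workers', 1))
        self.data_iter = iter(self.data_loader)

    def next_batch(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            # restart the iterator when an epoch ends, so training can run
            # for keep_step + decay_step steps regardless of dataset size
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)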
Example #3
def train_tom(opt, train_loader, model, d_g, d_l, board):
    model.cuda()
    model.train()
    d_g.cuda()
    d_g.train()
    d_l.cuda()
    d_l.train()

    # flipped GAN labels: real -> 0, fake -> 1; the generator targets 0 ('real')
    dis_label_G = Variable(torch.FloatTensor(opt.batch_size,
                                             1)).fill_(0.).cuda()
    dis_label_real = Variable(torch.FloatTensor(opt.batch_size,
                                                1)).fill_(0.).cuda()
    dis_label_fake = Variable(torch.FloatTensor(opt.batch_size,
                                                1)).fill_(1.).cuda()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    criterionGAN = nn.BCELoss()  # swap in nn.MSELoss for an LSGAN-style objective

    # optimizer
    optimizerG = torch.optim.Adam(model.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))
    optimizerDG = torch.optim.Adam(d_g.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizerDL = torch.optim.Adam(d_l.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    schedulerG = torch.optim.lr_scheduler.LambdaLR(
        optimizerG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))
    schedulerDG = torch.optim.lr_scheduler.LambdaLR(
        optimizerDG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))
    schedulerDL = torch.optim.lr_scheduler.LambdaLR(
        optimizerDL,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        # label noise: re-jitter the real/fake targets each step
        dis_label_real = dis_label_real.data.fill_(0.0 +
                                                   random.random() * opt.noise)
        dis_label_fake = dis_label_fake.data.fill_(1.0 -
                                                   random.random() * opt.noise)

        #prep
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()  # size: b x 3 x 256 x 192
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        batch_size = im.size(0)
        if batch_size != opt.batch_size:
            continue  # label tensors are sized for full batches; skip ragged ones

        #D_real
        errDg_real = criterionGAN(d_g(torch.cat([agnostic, c, im], 1)),
                                  dis_label_real)

        #generate image
        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        real_crop, fake_crop = random_crop(im, p_tryon, opt.winsize)
        errDl_real = criterionGAN(d_l(real_crop), dis_label_real)

        #tom_train
        errGg_fake = criterionGAN(d_g(torch.cat([agnostic, c, p_tryon], 1)),
                                  dis_label_G)
        errGl_fake = criterionGAN(d_l(fake_crop), dis_label_G)

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss_GAN = (errGg_fake + errGl_fake * opt.alpha) / batch_size
        loss = loss_l1 + loss_vgg + loss_mask + loss_GAN

        #D_fake
        errDg_fake = criterionGAN(
            d_g(torch.cat([agnostic, c, p_tryon], 1).detach()), dis_label_fake)
        loss_Dg = (errDg_fake + errDg_real) / 2

        errDl_fake = criterionGAN(d_l(fake_crop.detach()), dis_label_fake)
        loss_Dl = (errDl_fake + errDl_real) / 2

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()

        optimizerDL.zero_grad()
        loss_Dl.backward()
        optimizerDL.step()

        optimizerDG.zero_grad()
        loss_Dg.backward()
        optimizerDG.step()

        # step the linear LR-decay schedules
        schedulerG.step()
        schedulerDG.step()
        schedulerDL.step()

        # tensorboardX
        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time

            loss_dict = {
                "TOT": loss.item(),
                "L1": loss_l1.item(),
                "VG": loss_vgg.item(),
                "Mk": loss_mask.item(),
                "G": loss_GAN.item(),
                "DG": loss_Dg.item(),
                "DL": loss_Dl.item()
            }
            print('step: %d|time: %.3f' % (step + 1, t), end="")

            sm_image(combine_images(im, p_tryon, real_crop, fake_crop),
                     "combined%d.jpg" % step, opt.debug)
            board_add_images(board, 'combine', visuals, step + 1)
            for k, v in loss_dict.items():
                print('|%s: %.3f' % (k, v), end="")
                board.add_scalar(k, v, step + 1)
            print()

        if (step + 1) % opt.save_count == 0:
            save_checkpoints(
                model, d_g, d_l,
                os.path.join(opt.checkpoint_dir, opt.stage + '_' + opt.name,
                             "step%06d" % step, '%s.pth'))
Example #4
def train_hpm(opt, train_loader, model, d_g, board):
    model.cuda()
    model.train()
    d_g.cuda()
    d_g.train()

    dis_label = Variable(torch.FloatTensor(opt.batch_size)).cuda()
    
    # criterion
    criterionMCE = nn.CrossEntropyLoss()  # alternative: nn.BCEWithLogitsLoss()
    criterionGAN = nn.BCELoss()
    
    # optimizer
    optimizerG = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    schedulerG = torch.optim.lr_scheduler.LambdaLR(
        optimizerG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    optimizerD = torch.optim.Adam(d_g.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    schedulerD = torch.optim.lr_scheduler.LambdaLR(
        optimizerD,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    
    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        #prep
        inputs = train_loader.next_batch()
            
        im = inputs['image'].cuda()  # size: b x 3 x 256 x 192
        sem_gt = inputs['seg'].cuda()
        seg_enc = inputs['seg_enc'].cuda()
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        batch_size = im.size(0)
        if batch_size != opt.batch_size:
            continue  # dis_label is sized for full batches; skip ragged ones

        optimizerD.zero_grad()
        #D_real
        dis_label.data.fill_(1.)
        dis_g_output = d_g(seg_enc)
        
        errDg_real = criterionGAN(dis_g_output, dis_label)
        errDg_real.backward()

        #generate image
        segmentation = model(torch.cat([agnostic, c],1))

        #D_fake
        dis_label.data.fill_(0.)

        dis_g_output = d_g(segmentation.detach())
        errDg_fake = criterionGAN(dis_g_output, dis_label)
        errDg_fake.backward()
        optimizerD.step()


        # generator update
        optimizerG.zero_grad()
        dis_label.data.fill_(1.)
        dis_g_output = d_g(segmentation)
        errG_fake = criterionGAN(dis_g_output, dis_label)
        loss_mce = criterionMCE(segmentation, sem_gt)

        loss = loss_mce + errG_fake
        loss.backward()
        optimizerG.step()

        # step the linear LR-decay schedules
        schedulerG.step()
        schedulerD.step()
                    
        if (step+1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            
            loss_dict = {"TOT":loss.item(), "MCE":loss_mce.item(), "GAN":errG_fake.item(), 
                         "DG":((errDg_fake+errDg_real)/2).item()}
            print('step: %d|time: %.3f'%(step+1, t), end="")
            
            for k, v in loss_dict.items():
                print('|%s: %.3f'%(k, v), end="")
                board.add_scalar(k, v, step+1)
            print()
            
        if (step+1) % opt.save_count == 0:
            save_checkpoints(model, d_g,
                os.path.join(opt.checkpoint_dir, opt.stage + '_' + opt.name,
                             "step%06d" % step, '%s.pth'))
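# get_opt() is shared by all four examples but never shown. An argparse
# sketch covering only the fields the examples actually read; every default
# below is an assumption:
import argparse


def get_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='GMM')
    parser.add_argument('--stage', default='GMM')    # GMM | HPM | TOM
    parser.add_argument('--mode', default='train')   # train | val | test
    parser.add_argument('--datamode', default='train')
    parser.add_argument('--data_list', default='train_pairs.txt')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--checkpoint_dir', default='checkpoints')
    parser.add_argument('--tensorboard_dir', default='tensorboard')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--keep_step', type=int, default=100000)
    parser.add_argument('--decay_step', type=int, default=100000)
    parser.add_argument('--display_count', type=int, default=20)
    parser.add_argument('--save_count', type=int, default=5000)
    parser.add_argument('--noise', type=float, default=0.1)   # label jitter amplitude
    parser.add_argument('--alpha', type=float, default=1.0)   # local-D loss weight
    parser.add_argument('--winsize', type=int, default=64)    # local-D crop size
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--debug', action='store_true')
    return parser.parse_args()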