def test_tom(opt, test_loader, model, board):
    """Evaluate the TOM (try-on module) on every batch of ``test_loader``.

    The TOM input is assembled from two pretrained networks loaded from
    ``opt.checkpoint_dir`` and kept in eval mode: the segmentation U-Net
    (``SEG/segm_final.pth``) and the geometric matching module
    (``GMM/gmm_final.pth``). For each batch the try-on image is composed,
    visuals are logged to ``board`` and the try-on images are written to
    ``<opt.dataroot>/<opt.datamode>/final-output`` (created if missing).

    Args:
        opt: parsed options; this function reads ``name``,
            ``checkpoint_dir``, ``dataroot`` and ``datamode`` (``GMM(opt)``
            may read more).
        test_loader: loader exposing ``data_loader`` (used for the progress
            bar length) and ``next_batch()``, which is polled until it
            returns ``None``.
        model: TOM network; moved to ``device`` and put in eval mode here.
        board: tensorboard writer handed to ``board_add_images``.
    """
    print('----Testing of module {} started----'.format(opt.name))
    model.to(device)
    model.eval()

    # Segmentation network: 25 input channels -> 20 output channels.
    unet_mask = UnetGenerator(25, 20, ngf=64)
    load_checkpoint(unet_mask,
                    os.path.join(opt.checkpoint_dir, 'SEG', 'segm_final.pth'))
    unet_mask.to(device)
    unet_mask.eval()

    # Geometric matching module whose grid is used to warp the cloth.
    gmm = GMM(opt)
    load_checkpoint(gmm,
                    os.path.join(opt.checkpoint_dir, 'GMM', 'gmm_final.pth'))
    gmm.to(device)
    gmm.eval()

    # Create the output folder up front so the first save_images call
    # cannot fail on a missing directory.
    save_dir = osp.join(opt.dataroot, opt.datamode, 'final-output')
    os.makedirs(save_dir, exist_ok=True)

    step = 0
    pbar = tqdm(total=len(test_loader.data_loader))
    try:
        inputs = test_loader.next_batch()
        while inputs is not None:
            im_name = inputs['im_name']
            im_h = inputs['head'].to(device)
            im = inputs['image'].to(device)
            agnostic = inputs['agnostic'].to(device)
            c = inputs['cloth'].to(device)
            im_c = inputs['parse_cloth'].to(device)
            im_c_mask = inputs['parse_cloth_mask'].to(device)
            im_ttp = inputs['texture_t_prior'].to(device)

            # Pure inference: no autograd graph is needed for any of this.
            with torch.no_grad():
                output_segm = unet_mask(torch.cat([agnostic, c], 1))
                _, _, grid_one, _ = gmm(agnostic, c)
                c_warp = F.grid_sample(c, grid_one, padding_mode='border')
                output_segm = F.log_softmax(output_segm, dim=1)

                # One-hot encode the per-pixel argmax over the channel axis.
                # zeros_like allocates directly on output_segm's device.
                output_argm = torch.max(output_segm, dim=1, keepdim=True)[1]
                final_segm = torch.zeros_like(output_segm).scatter(
                    1, output_argm, 1.0)
                input_tom = torch.cat([final_segm, c_warp, im_ttp], 1)

                output_tom = model(input_tom)
                # Channels 0-2: rendered person; remaining channel(s): mask.
                person_r = torch.tanh(output_tom[:, :3, :, :])
                mask_c = torch.sigmoid(output_tom[:, 3:, :, :])
                mask_c = (mask_c >= 0.5).type(torch.float)
                # Binary mask selects warped cloth vs. rendered person.
                img_tryon = mask_c * c_warp + (1 - mask_c) * person_r

            visuals = [[im, c, img_tryon], [im_c, c_warp, person_r],
                       [im_c_mask, mask_c, im_h]]
            board_add_images(board, 'combine', visuals, step + 1)
            save_images(img_tryon, im_name, save_dir)

            inputs = test_loader.next_batch()
            step += 1
            pbar.update(1)
    finally:
        # Release the progress bar even if a batch raises.
        pbar.close()
def main():
    """Entry point for evaluating a trained GMM or TOM module.

    Builds the test dataset and loader from the parsed options, then creates
    a tensorboard writer under ``<tensorboard_dir>/<name>/<datamode>``. It
    loads ``<checkpoint_dir>/<name>/{gmm,tom}_final.pth`` into the matching
    model and runs ``test_gmm`` or ``test_tom``. The writer is always closed
    afterwards so buffered events are flushed. Any other ``opt.name`` is
    reported and nothing is run.
    """
    opt = parser()

    test_dataset = SieveDataset(opt)
    # create dataloader
    test_loader = SieveDataLoader(opt, test_dataset)

    if opt.name == 'GMM':
        model = GMM(opt)
        checkpoint_name = 'gmm_final.pth'
        run_test = test_gmm
    elif opt.name == 'TOM':
        model = UnetGenerator(26, 4, ngf=64)
        checkpoint_name = 'tom_final.pth'
        run_test = test_tom
    else:
        print('Unknown module name {!r}: expected GMM or TOM'.format(opt.name))
        return

    # visualization (exist_ok avoids the exists-then-create race)
    board_dir = os.path.join(opt.tensorboard_dir, opt.name, opt.datamode)
    os.makedirs(board_dir, exist_ok=True)
    board = SummaryWriter(log_dir=board_dir)
    try:
        checkpoint_path = osp.join(opt.checkpoint_dir, opt.name,
                                   checkpoint_name)
        load_checkpoint(model, checkpoint_path)
        run_test(opt, test_loader, model, board)
    finally:
        # Flush pending tensorboard events and release the event file.
        board.close()
# Example #3
# 0
def main():
    """Entry point for training one SieveNet module (GMM, SEG or TOM).

    Builds the training dataset and loader and instantiates the model chosen
    by ``opt.name``. It creates ``<tensorboard_dir>/<name>`` and
    ``<checkpoint_dir>/<name>``, and optionally resumes from
    ``opt.checkpoint``. It then trains the model and saves the final weights
    as ``<checkpoint_dir>/<name>/{gmm,segm,tom}_final.pth``. The tensorboard
    writer is always closed so buffered events are flushed. Any other
    ``opt.name`` is reported and nothing is trained.
    """
    opt = parser()

    train_dataset = SieveDataset(opt)
    # create dataloader
    train_loader = SieveDataLoader(opt, train_dataset)

    # create model
    model_triloss = None
    if opt.name == 'GMM':
        model = GMM(opt)
        final_name = 'gmm_final.pth'
    elif opt.name == 'SEG':
        # input channel = agnostic(22) + cloth(3)
        # output channel = segmentation output(20)
        model = UnetGenerator(25, 20, ngf=64)
        final_name = 'segm_final.pth'
    elif opt.name == 'TOM':
        # input channel = generated seg mask(20) + texture translation prior(3) + warped cloth(3)
        # output channel = cloth mask(1) + rendered person(3)
        model = UnetGenerator(26, 4, ngf=64)
        # for Duelling Triplet Loss Strategy
        model_triloss = UnetGenerator(26, 4, ngf=64)
        final_name = 'tom_final.pth'
    else:
        print('Unknown module name {!r}: expected GMM, SEG or TOM'.format(opt.name))
        return

    # visualization (exist_ok avoids the exists-then-create race)
    board_dir = os.path.join(opt.tensorboard_dir, opt.name)
    os.makedirs(board_dir, exist_ok=True)
    board = SummaryWriter(log_dir=board_dir)

    # for checkpoint saving
    checkpoint_dir = os.path.join(opt.checkpoint_dir, opt.name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    try:
        # Resume only when a checkpoint path was given and actually exists.
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            print('Checkpoints loaded!')

        if opt.name == 'GMM':
            train_gmm(opt, train_loader, model, board)
        elif opt.name == 'SEG':
            train_segm(opt, train_loader, model, board)
        else:
            train_tom(opt, train_loader, model, model_triloss, board)

        save_checkpoint(model, os.path.join(checkpoint_dir, final_name))
    finally:
        # Flush pending tensorboard events and release the event file.
        board.close()
# Example #4
# 0
def main():
    """Run the complete try-on pipeline on one person image and one cloth.

    Generates pose keypoints for the person image and builds the network
    inputs with ``generate_data``. It then runs the pretrained GMM,
    segmentation U-Net and TOM (loaded from ``opt.checkpoint_dir``). Finally
    it writes the warped cloth, segmentation map, cloth mask, rendered person
    and final try-on image into ``opt.save_dir``, creating it if missing.
    """
    opt = parser()

    im_path = opt.input_image_path                  # person image path
    cloth_path = opt.cloth_image_path               # cloth image path
    # NOTE(review): assumes generate_pose_keypoints writes
    # '<image>_keypoints.json' for a '.jpg' input — confirm for other formats.
    pose_path = opt.input_image_path.replace('.jpg', '_keypoints.json')
    generate_pose_keypoints(im_path)                # generating pose keypoints
    segm_path = opt.human_parsing_image_path        # segmented mask path
    # Output-file prefix: file name without directory or extension.
    # os.path handles the platform separator and keeps inner dots intact.
    img_name = os.path.splitext(os.path.basename(im_path))[0] + '_'

    agnostic, c, im_ttp = generate_data(opt, im_path, cloth_path, pose_path, segm_path)

    agnostic = agnostic.to(device)
    c = c.to(device)
    im_ttp = im_ttp.to(device)

    gmm = GMM(opt)
    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, 'GMM', 'gmm_final.pth'))
    gmm.to(device)
    gmm.eval()

    unet_mask = UnetGenerator(25, 20, ngf=64)
    load_checkpoint(unet_mask, os.path.join(opt.checkpoint_dir, 'SEG', 'segm_final.pth'))
    unet_mask.to(device)
    unet_mask.eval()

    tom = UnetGenerator(26, 4, ngf=64)
    load_checkpoint(tom, os.path.join(opt.checkpoint_dir, 'TOM', 'tom_final.pth'))
    tom.to(device)
    tom.eval()

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        output_segm = unet_mask(torch.cat([agnostic, c], 1))
        _, _, grid_one, _ = gmm(agnostic, c)
        c_warp = F.grid_sample(c, grid_one, padding_mode='border')
        output_segm = F.log_softmax(output_segm, dim=1)

        # One-hot encode the per-pixel argmax over the channel axis;
        # zeros_like allocates directly on output_segm's device.
        output_argm = torch.max(output_segm, dim=1, keepdim=True)[1]
        final_segm = torch.zeros_like(output_segm).scatter(1, output_argm, 1.0)

        input_tom = torch.cat([final_segm, c_warp, im_ttp], 1)

        output_tom = tom(input_tom)
        # Channels 0-2: rendered person; remaining channel(s): cloth mask.
        person_r = torch.tanh(output_tom[:, :3, :, :])
        mask_c = torch.sigmoid(output_tom[:, 3:, :, :])
        mask_c = (mask_c >= 0.5).type(torch.float)
        img_tryon = mask_c * c_warp + (1 - mask_c) * person_r
    print('Output generated!')

    # Rescale from [-1, 1] to [0, 1] for saving (person_r comes from tanh;
    # c_warp is presumably normalised the same way — confirm in generate_data).
    c_warp = c_warp * 0.5 + 0.5
    output_argm = output_argm.type(torch.float)
    person_r = person_r * 0.5 + 0.5
    img_tryon = img_tryon * 0.5 + 0.5

    # Make sure the destination exists before writing any image.
    os.makedirs(opt.save_dir, exist_ok=True)

    def out_path(suffix):
        # Full path of an output file: <save_dir>/<img_name><suffix>.
        return osp.join(opt.save_dir, img_name + suffix)

    tensortoimage(c_warp[0].cpu(), out_path('w_cloth.png'))
    tensortoimage(output_argm[0][0].cpu(), out_path('seg_mask.png'))
    tensortoimage(mask_c[0].cpu(), out_path('c_mask.png'))
    tensortoimage(person_r[0].cpu(), out_path('ren_person.png'))
    tensortoimage(img_tryon[0].cpu(), out_path('final_output.png'))
    print('Output saved at {}'.format(opt.save_dir))
# Example #5
# 0
def train_tom(opt, train_loader, model, model_triloss, board):
    """Train the TOM (try-on module) network.

    The TOM input is built from two pretrained networks loaded from
    ``opt.checkpoint_dir`` and kept in eval mode under ``no_grad``: the
    segmentation U-Net (``SEG/segm_final.pth``) and the GMM
    (``GMM/gmm_final.pth``). Training runs with Adam for steps
    ``opt.previous_step`` .. ``opt.keep_step + opt.decay_step - 1``.

    For steps after ``opt.keep_step``, ``tom_loss`` additionally receives the
    try-on image produced by ``model_triloss``. That network's weights are
    (re)loaded from ``<checkpoint_dir>/<name>/step_%06d.pth`` for the largest
    multiple of ``opt.save_count`` not exceeding the current step.

    Args:
        opt: parsed options; reads ``name``, ``checkpoint_dir``, ``lr``,
            ``previous_step``, ``keep_step``, ``decay_step``, ``save_count``
            and ``display_count`` (``GMM(opt)`` may read more).
        train_loader: loader whose ``next_batch()`` returns a dict of tensors.
        model: TOM network being trained (moved to ``device`` here).
        model_triloss: network with the same architecture as ``model``, run
            under ``no_grad`` to produce the extra loss input.
        board: tensorboard writer for images and scalars.

    Every ``opt.display_count`` steps images, the loss and GPU memory are
    logged. Every ``opt.save_count`` steps the weights of ``model`` are saved
    as ``<checkpoint_dir>/<name>/step_%06d.pth`` (numbered ``step + 1``).
    """

    print('----Traning of module {} started----'.format(opt.name))
    model.to(device)
    model.train()

    # Segmentation network (25 -> 20 channels); only used for inference here.
    unet_mask = UnetGenerator(25, 20, ngf=64)
    load_checkpoint(unet_mask, os.path.join(opt.checkpoint_dir, 'SEG', 'segm_final.pth'))
    unet_mask.to(device)
    unet_mask.eval()

    # Geometric matching module whose grid warps the cloth; inference only.
    gmm = GMM(opt)
    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, 'GMM', 'gmm_final.pth'))
    gmm.to(device)
    gmm.eval()

    # The LR scheduler is not used during training (kept below for reference).
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    # lr_lambda = lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lr_lambda)

    # Step number of the checkpoint currently loaded into model_triloss
    # (None until the first load); avoids reloading it every iteration.
    update_checkpoint = None

    for step in range(opt.previous_step, opt.keep_step + opt.decay_step):

        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im_h = inputs['head'].to(device)
        im = inputs['image'].to(device)
        agnostic = inputs['agnostic'].to(device)
        c = inputs['cloth'].to(device)
        im_c =  inputs['parse_cloth'].to(device)
        im_c_mask = inputs['parse_cloth_mask'].to(device)
        im_ttp = inputs['texture_t_prior'].to(device)


        with torch.no_grad():
            output_segm = unet_mask(torch.cat([agnostic, c], 1))
            grid_zero, theta, grid_one, delta_theta = gmm(agnostic, c)
        c_warp = F.grid_sample(c, grid_one, padding_mode='border')
        output_segm = F.log_softmax(output_segm, dim=1)

        # One-hot encode the per-pixel argmax over the channel axis.
        output_argm = torch.max(output_segm, dim=1, keepdim=True)[1]
        final_segm = torch.zeros(output_segm.shape).to(device).scatter(1, output_argm, 1.0)

        input_tom = torch.cat([final_segm, c_warp, im_ttp], 1)

        # Channels 0-2: rendered person; remaining channel(s): cloth mask.
        # Unlike at test time, the soft (sigmoid) mask is used for the blend.
        output_tom = model(input_tom)
        person_r = torch.tanh(output_tom[:,:3,:,:])
        mask_c = torch.sigmoid(output_tom[:,3:,:,:])
        img_tryon = mask_c * c_warp + (1 - mask_c) * person_r

        # If the step is past keep_step, add the triplet term to the main TOM loss.
        # NOTE(review): model_triloss is prepared from step == keep_step, but is
        # only used when step > keep_step (strict) below — confirm which is intended.
        if step >= opt.keep_step:
            model_triloss.to(device)
            model_triloss.eval()

        if step > opt.keep_step:
            # Largest multiple of save_count that is <= step.
            i_prev = int(step/opt.save_count)*opt.save_count

            if update_checkpoint!=i_prev:
                # NOTE(review): assumes step_<i_prev>.pth was already written to
                # this run's checkpoint dir (e.g. not guaranteed if
                # keep_step < save_count, or after resuming via previous_step).
                path = osp.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (i_prev))
                load_checkpoint(model_triloss, path)
                update_checkpoint = i_prev

            with torch.no_grad():
                output_tom_tri = model_triloss(input_tom)
            person_r_tri = torch.tanh(output_tom_tri[:,:3,:,:])
            mask_c_tri = torch.sigmoid(output_tom_tri[:,3:,:,:])
            img_tryon_tri = mask_c_tri * c_warp + (1 - mask_c_tri) * person_r_tri
            loss = tom_loss(img_tryon, mask_c, im, im_c_mask, img_tryon_tri)
        else:
            loss = tom_loss(img_tryon, mask_c, im, im_c_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        visuals = [[im, c, img_tryon],
                    [im_c, c_warp, person_r],
                    [im_c_mask,mask_c, im_h]]

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('loss', loss.item(), step+1)
            # GPU memory in GiB, rounded to one decimal; these calls assume CUDA.
            # NOTE(review): torch.cuda.memory_cached is deprecated in recent
            # PyTorch (renamed memory_reserved) — confirm against the pinned version.
            board.add_scalar('GPU Allocated', round(torch.cuda.memory_allocated(0)/1024**3,1), step+1)
            board.add_scalar('GPU Cached', round(torch.cuda.memory_cached(0)/1024**3,1), step+1)
            # Wall time of this single iteration only.
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f' % (step+1, t, loss.item()), flush=True)

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))