def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    dataset = CPDataset(opt)

    # create dataloader
    loader = CPDataLoader(opt, dataset)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.workers,
                                              pin_memory=True,
                                              sampler=None)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = UNet(n_channels=4, n_classes=3)
    model.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    test_residual(opt, data_loader, model, gmm_model, generator_model)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
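
The helpers load_checkpoint and save_checkpoint appear in every example but are defined in none of them. A minimal sketch of what load_checkpoint is assumed to do in these CP-VTON-style repos (restore a state dict from disk; the real helper may also move the model to the GPU):

import os
import torch

def load_checkpoint(model, checkpoint_path):
    # assumed behavior: skip silently if the file is missing, else restore weights
    if not os.path.exists(checkpoint_path):
        return
    model.load_state_dict(torch.load(checkpoint_path))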
Example n. 2
def main():
    opt = get_opt()
    print(opt)

    print('Loading dataset')
    dataset_train = TOMDataset(opt, mode='train', data_list='train_pairs.txt')
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=opt.batch_size,
                                  num_workers=opt.n_worker,
                                  shuffle=True)
    dataset_val = TOMDataset(opt,
                             mode='val',
                             data_list='val_pairs.txt',
                             train=False)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=opt.batch_size,
                                num_workers=opt.n_worker,
                                shuffle=True)

    save_dir = os.path.join(opt.out_dir, opt.name)
    log_dir = os.path.join(opt.out_dir, 'log')
    dirs = [opt.out_dir, save_dir, os.path.join(save_dir, 'train'), log_dir]
    for d in dirs:
        mkdir(d)
    log_name = os.path.join(log_dir, opt.name + '.csv')
    with open(log_name, 'w') as f:
        f.write('epoch,train_loss,val_loss\n')

    print('Building TOM model')
    gen = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    dis = NLayerDiscriminator(28,
                              ndf=64,
                              n_layers=6,
                              norm_layer=nn.InstanceNorm2d,
                              use_sigmoid=True)
    gen.cuda()
    dis.cuda()
    n_step = int(opt.n_epoch * len(dataset_train) / opt.batch_size)
    trainer = TOMTrainer(gen, dis, dataloader_train, dataloader_val,
                         opt.gpu_id, opt.log_freq, save_dir, n_step)

    print('Start training TOM')
    for epoch in tqdm(range(opt.n_epoch)):
        print('Epoch: {}'.format(epoch))
        loss = trainer.train(epoch)
        print('Train loss: {:.3f}'.format(loss))
        with open(log_name, 'a') as f:
            f.write('{},{:.3f},'.format(epoch, loss))
        save_checkpoint(
            gen, os.path.join(save_dir, 'gen_epoch_{:02}.pth'.format(epoch)))
        save_checkpoint(
            dis, os.path.join(save_dir, 'dis_epoch_{:02}.pth'.format(epoch)))

        loss = trainer.val(epoch)
        print('Validation loss: {:.3f}'.format(loss))
        with open(log_name, 'a') as f:
            f.write('{:.3f}\n'.format(loss))
    print('Finished training TOM')
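
save_checkpoint is likewise assumed rather than shown. A sketch matching the single-model calls above (Example n. 20 also passes a dict of modules, which the real helper presumably special-cases):

import os
import torch

def save_checkpoint(model, save_path):
    # assumed behavior: create the parent directory, then dump the state dict
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)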
Example n. 3
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt, 1)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'HPM':
        model = UnetGenerator(25,
                              16,
                              6,
                              ngf=64,
                              norm_layer=nn.InstanceNorm2d,
                              clsf=True)
        d_g = Discriminator_G(opt, 16)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            if not os.path.isdir(opt.checkpoint):
                raise NotImplementedError(
                    'checkpoint should be dir, not file: %s' % opt.checkpoint)
            load_checkpoints(model, d_g, os.path.join(opt.checkpoint,
                                                      "%s.pth"))
        train_hpm(opt, train_loader, model, d_g, board)
        save_checkpoints(
            model, d_g,
            os.path.join(opt.checkpoint_dir,
                         opt.stage + '_' + opt.name + "_final", '%s.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 3, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example n. 4
    def __init__(self, gmm_path, tom_path):
        '''
        Load the pretrained weights for both models.
        '''
        self.gmm = GMM()
        load_checkpoint(self.gmm, gmm_path)
        self.gmm.eval()
        self.tom = UnetGenerator(23, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(self.tom, tom_path)
        self.tom.eval()
        self.gmm.cuda()
        self.tom.cuda()
Example n. 5
def main():
	opt = get_opt()
	print(opt)

	model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
	load_checkpoint(model, opt.checkpoint)
	#model.cuda()
	model.eval()

	mode = 'test'
	print('Run on {} data'.format(mode.upper()))
	dataset = TOMDataset(opt, mode, data_list=mode+'_pairs.txt', train=False)
	dataloader = DataLoader(dataset, batch_size=opt.batch_size, num_workers=opt.n_worker, shuffle=False)   
	with torch.no_grad():
		run(opt, model, dataloader, mode)
	print('Successfully completed')
Example n. 6
def main():
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))
   
    # create dataset 
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))
   
    # create model & train
    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, train_loader, model, board)
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, train_loader, model, board)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
  
    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example n. 7
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)


    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example n. 8
def inference(opt):
    #opt = get_opt()
    #print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # create model & train
    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, train_loader, model)
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, train_loader, model)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example n. 9
    def __init__(self, gmm_path, tom_path, use_cuda=False):

        self.use_cuda = use_cuda
        self.gmm = GMM(use_cuda=use_cuda)
        load_checkpoint(self.gmm, gmm_path, use_cuda=use_cuda)
        self.gmm.eval()
        self.tom = UnetGenerator(23,
                                 4,
                                 6,
                                 ngf=64,
                                 norm_layer=nn.InstanceNorm2d)
        load_checkpoint(self.tom, tom_path, use_cuda=use_cuda)
        self.tom.eval()
        if use_cuda:
            self.gmm.cuda()
            self.tom.cuda()
        print("use_cuda = " + str(self.use_cuda))
Example n. 10
    def __init__(self, gmm_path, tom_path, use_cuda=True):
        '''
        Initialize both models from pretrained weights.
        '''
        self.use_cuda = use_cuda
        self.gmm = GMM(use_cuda=use_cuda)
        load_checkpoint(self.gmm, gmm_path, use_cuda=use_cuda)
        self.gmm.eval()
        self.tom = UnetGenerator(23,
                                 4,
                                 6,
                                 ngf=64,
                                 norm_layer=nn.InstanceNorm2d)
        load_checkpoint(self.tom, tom_path, use_cuda=use_cuda)
        self.tom.eval()
        if use_cuda:
            self.gmm.cuda()
            self.tom.cuda()
        print("use_cuda = " + str(self.use_cuda))
Example n. 11
def main():
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))
    # create model & test
    model = UnetGenerator(26, 4, 6, ngf=64,
                          norm_layer=nn.InstanceNorm2d)  # CP-VTON+
    load_checkpoint(model, opt.checkpoint)
    with torch.no_grad():
        test_tom(opt, model, prepare_inputs(opt))

    print("Finished test %s, named: %s!" % (opt.stage, opt.name))
Example n. 12
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    model = UnetGenerator(22, 1, 6, ngf=64, norm_layer=nn.InstanceNorm2d)

    load_checkpoint(model, opt.checkpoint)
    with torch.no_grad():
        test_mask_gen(opt, train_loader, model)
Example n. 13
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = get_dataset_class(opt.dataset)(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    board = None
    if opt.tensorboard_dir:
        if not os.path.exists(opt.tensorboard_dir):
            os.makedirs(opt.tensorboard_dir)
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == "GMM":
        model = GMM(opt)
        model.opt = opt
        if not opt.checkpoint == "" and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        train_gmm(opt, train_loader, model, board)
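        # note: if model was wrapped in nn.DataParallel above, its state-dict keys gain a 'module.' prefix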
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, "gmm_final.pth"))
    elif opt.stage == "TOM":
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.opt = opt
        if not opt.checkpoint == "" and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, "tom_final.pth"))
    else:
        raise NotImplementedError("Model [%s] is not implemented" % opt.stage)

    print("Finished training %s, nameed: %s!" % (opt.stage, opt.name))
Example n. 14
def CPVTON_wrapper(name="GMM", stage="GMM", workers=1, checkpoint=""):
    opt = get_opt()
    opt.name = name
    opt.stage = stage
    opt.workers = workers
    opt.checkpoint = checkpoint

    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    test_dataset = CPDataset(opt)

    # create dataloader
    test_loader = CPDataLoader(opt, test_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
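    # 'logdir' is the tensorboardX keyword; torch.utils.tensorboard spells it 'log_dir' (as in the other examples)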
    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & test
    if stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, test_loader, model, board)
    elif stage == 'TOM':
        # model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON
        model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, test_loader, model, board)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example n. 15
class CPVTON(object):
    def __init__(self, gmm_path, tom_path):
        '''
        Initialize both models from pretrained weights.
        '''
        self.gmm = GMM()
        load_checkpoint(self.gmm, gmm_path)
        self.gmm.eval()
        self.tom = UnetGenerator(23,
                                 4,
                                 6,
                                 ngf=64,
                                 norm_layer=nn.InstanceNorm2d)
        load_checkpoint(self.tom, tom_path)
        self.tom.eval()
        self.gmm.cuda()
        self.tom.cuda()

    def predict(self, parse_array, pose_map, human, c):
        '''
        All four inputs are numpy arrays with shape (*, 256, 192).
        '''
        im = transformer(human)
        c = transformer(c)  # [-1,1]

        # parse -> shape

        parse_shape = (parse_array > 0).astype(np.float32)

        # blur the shape mask: downsample then upsample
        parse_shape = Image.fromarray((parse_shape * 255).astype(np.uint8))
        parse_shape = parse_shape.resize((192 // 16, 256 // 16),
                                         Image.BILINEAR)
        parse_shape = parse_shape.resize((192, 256), Image.BILINEAR)
        shape = transformer(parse_shape)

        parse_head = (parse_array == 1).astype(np.float32) + \
            (parse_array == 2).astype(np.float32) + \
            (parse_array == 4).astype(np.float32) + \
            (parse_array == 13).astype(np.float32) + \
            (parse_array == 9).astype(np.float32)
        phead = torch.from_numpy(parse_head)  # [0,1]
        im_h = im * phead - (1 - phead)

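        # cloth-agnostic person representation: blurred body shape + preserved head + pose map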
        agnostic = torch.cat([shape, im_h, pose_map], 0)

        # batch==1
        agnostic = agnostic.unsqueeze(0).cuda()
        c = c.unsqueeze(0).cuda()

        # warp result (agnostic and c are already on the GPU)
        grid, theta = self.gmm(agnostic, c)
        c_warp = F.grid_sample(c, grid, padding_mode='border')

        tensor = (c_warp.detach().clone() + 1) * 0.5 * 255
        tensor = tensor.cpu().clamp(0, 255)
        array = tensor.numpy().astype('uint8')

        c_warp = transformer(np.transpose(array[0], axes=(1, 2, 0)))
        c_warp = c_warp.unsqueeze(0)

        outputs = self.tom(torch.cat([agnostic.cuda(), c_warp.cuda()], 1))
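        # the generator emits 4 channels: a 3-channel rendered person and a 1-channel composition mask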
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c_warp.cuda() * m_composite + p_rendered * (1 - m_composite)

        return (p_tryon, c_warp)
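
predict also relies on a module-level transformer that none of the snippets define. A sketch under the usual CP-VTON assumption (ToTensor plus a [-1, 1] normalization that broadcasts over 1- or 3-channel inputs), followed by a hypothetical call; the checkpoint paths are placeholders:

import torchvision.transforms as transforms

# assumed preprocessing: PIL image or ndarray -> float tensor scaled to [-1, 1]
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# hypothetical usage; parse_array, pose_map, human and cloth come from the dataset
vton = CPVTON('checkpoints/gmm_final.pth', 'checkpoints/tom_final.pth')
p_tryon, c_warp = vton.predict(parse_array, pose_map, human, cloth)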
Example n. 16
def main():
    opt = get_opt()

    if opt.mode == 'test':
        opt.datamode = "test"
        opt.data_list = "test_pairs.txt"
        opt.shuffle = False
    elif opt.mode == 'val':
        opt.shuffle = False
    elif opt.mode != 'train':
        print(opt.mode)

    print(opt)

    if opt.mode != 'train':
        opt.batch_size = 1


    if opt.mode != 'train' and not opt.checkpoint:
        print("You need to have a checkpoint for: "+opt.mode)
        return None

    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))
   
    # create dataset 
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))
    
    # create model & train & save the final checkpoint
    if opt.stage == 'HPM':
        model = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d, clsf=True)
        d_g = Discriminator_G(opt, 16)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            load_checkpoint(d_g, opt.checkpoint[:-9] + "dg.pth")

        if opt.mode == "train":
            train_hpm(opt, train_loader, model, d_g, board)
        else:
            test_hpm(opt, train_loader, model)

        save_checkpoints(model, d_g, os.path.join(opt.checkpoint_dir, opt.stage +'_'+ opt.name+"_final", '%s.pth'))

    elif opt.stage == 'GMM':
        #seg_unet = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d, clsf=True)
        model = GMM(opt, 1)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        
        if opt.mode == "train":
            train_gmm(opt, train_loader, model, board)
        else:
            test_gmm(opt, train_loader, model)
        
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    
    elif opt.stage == 'TOM':
        model = UnetGenerator(31, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        if opt.mode == "train":
            train_tom(opt, train_loader, model, board)
        else:
            test_tom(opt, train_loader, model)
        
        save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
  
    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example n. 17
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        if opt.model == 'RefinedGMM':
            model = RefinedGMM(opt)
        elif opt.model == 'OneRefinedGMM':
            model = OneRefinedGMM(opt)
        else:
            raise TypeError()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'VariGMM':
        model = VariGMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'semanticGMM':
        model = RefinedGMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_semantic_parsing_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'refined_gmm_final.pth'))
    elif opt.stage == 'no_background_GMM':
        model = RefinedGMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_no_background_refined_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model,
            os.path.join(opt.checkpoint_dir, opt.name,
                         'no_background_refined_gmm_final.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'DeepTom':
        norm_layer = 'instance'
        use_dropout = True
        with_tanh = False
        model = Define_G(25,
                         4,
                         64,
                         'treeresnet',
                         norm_layer,
                         use_dropout,
                         'normal',
                         0.02,
                         opt.gpu_ids,
                         with_tanh=with_tanh)
        train_deep_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example n. 18
        test_sampler = sampler.RandomSampler(test_dataset)
    test_loader = DataLoader(test_dataset,
                             batch_size=opt.batch_size,
                             shuffle=opt.shuffle,
                             num_workers=opt.workers,
                             pin_memory=True,
                             sampler=test_sampler)

    # visualization
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name,
                                               opt.data_list.split('.')[0]))

    # create model & train
    if opt.stage == 'GMM':
        print('Dataset size: %05d!' % (len(test_dataset)), flush=True)
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, test_loader, model, board)
    elif opt.stage == 'TOM':
        print('Dataset size: %05d!' % (len(test_dataset)), flush=True)
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, test_loader, model, board)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example n. 19
def test_Unet_size():
    import torch
    from torch.autograd import Variable
    from networks import UnetGenerator

    # the U-Net halves H and W once per down-block, so the input size
    # must be divisible by 2**num_downs (224 is not divisible by 128)
    model = UnetGenerator(input_nc=3, output_nc=3, num_downs=7, ngf=64)
    x = Variable(torch.rand(1, 3, 256, 256))
    y = model(x)
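
The 256x256 input above matters: UnetGenerator halves the spatial size at each of its num_downs down-blocks, so height and width must be divisible by 2**num_downs. A small guard, as a sketch:

def assert_unet_input(height, width, num_downs=7):
    # each down-block halves H and W, so both must be multiples of 2**num_downs
    factor = 2 ** num_downs
    assert height % factor == 0 and width % factor == 0, \
        'U-Net with %d downs needs H and W divisible by %d' % (num_downs, factor)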
Example n. 20
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

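    # keep a handle on the raw module: DDP wraps the model below, and checkpoints should save the unwrapped weights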
    model_module = model
    if opt.use_gan:
        discriminator_module = discriminator
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

    if opt.use_gan:
        train_residual_old(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module)
        if single_gpu_flag(opt):
            save_checkpoint(
                {
                    "generator": model_module,
                    "discriminator": discriminator_module
                }, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        train_residual_old(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example n. 21
# create dataset
train_dataset = CPDataset(opt)

# create dataloader
train_loader = CPDataLoader(opt, train_dataset)
data_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=opt.batch_size,
                                          shuffle=False,
                                          num_workers=opt.workers,
                                          pin_memory=True)

gmm_model = GMM(opt)
load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
gmm_model.cuda()

generator = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
load_checkpoint(generator, "checkpoints/tom_train_new_2/step_070000.pth")
generator.cuda()

embedder_model = Embedder()
load_checkpoint(embedder_model,
                "checkpoints/identity_embedding_for_test/step_045000.pth")
image_embedder = embedder_model.embedder_b.cuda()
prod_embedder = embedder_model.embedder_a.cuda()

model = G()
if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
    load_checkpoint(model, opt.checkpoint)
model.cuda()

model.eval()
Example n. 22
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':

        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'TOM+WARP':

        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            gmm_model = torch.nn.parallel.DistributedDataParallel(
                gmm_model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "identity":
        model = Embedder()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'residual':

        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new/step_038000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

            acc_discriminator = AccDiscriminator()
            acc_discriminator.apply(utils.weights_init('gaussian'))
            acc_discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            if opt.use_gan:
                load_checkpoint(discriminator,
                                opt.checkpoint.replace("step_", "step_disc_"))

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

                acc_discriminator = torch.nn.parallel.DistributedDataParallel(
                    acc_discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)

            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    elif opt.stage == "residual_old":
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new_2/step_070000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt,
                               train_loader,
                               model,
                               model_module,
                               gmm_model,
                               generator_model,
                               embedder_model,
                               board,
                               discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
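
single_gpu_flag, used throughout Examples n. 20 and n. 22 to gate TensorBoard writing and checkpoint saving, is also assumed rather than shown. A sketch of its presumed behavior:

def single_gpu_flag(opt):
    # assumed helper: true when not distributed, or on the rank-0 process,
    # so logging and saving happen exactly once per run
    return not opt.distributed or opt.local_rank == 0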