Example #1
0
def main():
    """Train the stage selected by ``opt.stage`` and save a final checkpoint.

    Options are read from ``get_opt()``. Stage ``'GMM'`` trains a ``GMM``
    model with ``train_gmm``; stage ``'TOM'`` trains a ``UnetGenerator`` with
    ``train_tom``. When ``opt.checkpoint`` is set and exists on disk, the
    model is initialised from it before training.

    Raises:
        NotImplementedError: if ``opt.stage`` is neither ``'GMM'`` nor
            ``'TOM'``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok avoids the check-then-create race when several runs start
    # at the same time and share one tensorboard directory.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & pick the matching training routine / output file name
    if opt.stage == 'GMM':
        model = GMM(opt)
        train_fn = train_gmm
        final_name = 'gmm_final.pth'
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        train_fn = train_tom
        final_name = 'tom_final.pth'
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Resume only when a checkpoint path was given (non-empty, not None)
    # and the file is actually present.
    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    train_fn(opt, train_loader, model, board)
    save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, final_name))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #2
0
def main():
    """Train a ``WUTON`` generator together with a ``Discriminator``.

    Reads options via ``get_opt()``, builds the dataset/loader and a
    ``SummaryWriter``, optionally initialises the generator from
    ``opt.checkpoint``, calls ``train`` and finally saves the generator
    as ``wuton_final.pth``. Only the generator is loaded and saved here.
    """
    opt = get_opt()
    # Debug overrides kept for reference (disabled):
    #     opt.cuda = False; opt.batch_size = 1; opt.name = "test"
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # tensorboard output
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    log_path = os.path.join(opt.tensorboard_dir, opt.name)
    writer = SummaryWriter(log_dir=log_path)

    # networks
    generator = WUTON(opt)
    discriminator = Discriminator()

    # TODO (carried over from the original): checkpoint handling
    resume = opt.checkpoint != '' and os.path.exists(opt.checkpoint)
    if resume:
        load_checkpoint(generator, opt.checkpoint)

    train(opt, loader, generator, discriminator, writer)
    # alternative entry point left disabled: train2(opt, loader, generator, writer)

    final_path = os.path.join(opt.checkpoint_dir, opt.name, 'wuton_final.pth')
    save_checkpoint(generator, final_path)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def inference(opt):
    """Run the test routine for ``opt.stage`` with weights from ``opt.checkpoint``.

    Args:
        opt: options object; this function reads ``stage``, ``name`` and
            ``checkpoint`` and forwards ``opt`` to the dataset, loader,
            model and test helpers.

    Raises:
        NotImplementedError: if ``opt.stage`` is neither ``'GMM'`` nor
            ``'TOM'``.
    """
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # Reject unknown stages before building any model.
    if opt.stage not in ('GMM', 'TOM'):
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, loader, model)
    else:
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, loader, model)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example #4
0
def main():
    """Evaluate the stage named by ``opt.stage`` and log to a SummaryWriter.

    Builds the dataset/loader from ``get_opt()``, loads ``opt.checkpoint``
    into a ``GMM`` (stage ``'GMM'``) or ``UnetGenerator`` (stage ``'TOM'``)
    and runs the matching test routine under ``torch.no_grad()``.

    Raises:
        NotImplementedError: for any other ``opt.stage``.
    """
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # tensorboard output
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    log_path = os.path.join(opt.tensorboard_dir, opt.name)
    writer = SummaryWriter(log_dir=log_path)

    # Reject unknown stages before building any model.
    if opt.stage not in ('GMM', 'TOM'):
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    if opt.stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, loader, model, writer)
    else:
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, loader, model, writer)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example #5
0
def main():
    """Train the ``'GMM'``, ``'HPM'`` or ``'TOM'`` stage and save the result.

    Options come from ``get_opt()``. For ``'GMM'`` and ``'TOM'``,
    ``opt.checkpoint`` is a single file loaded with ``load_checkpoint``.
    For ``'HPM'`` it must be a directory, and the generator and
    ``Discriminator_G`` are loaded together through ``load_checkpoints``
    using a ``"%s.pth"`` path template.

    Raises:
        NotImplementedError: if ``opt.stage`` is not one of the three
            stages above, or if the ``'HPM'`` checkpoint path exists but is
            not a directory.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # A checkpoint is used only when a path was given (non-empty, not None)
    # and that path exists.
    resume = bool(opt.checkpoint) and os.path.exists(opt.checkpoint)

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt, 1)
        if resume:
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'HPM':
        model = UnetGenerator(25,
                              16,
                              6,
                              ngf=64,
                              norm_layer=nn.InstanceNorm2d,
                              clsf=True)
        d_g = Discriminator_G(opt, 16)
        if resume:
            # HPM stores several networks, so the checkpoint is a directory.
            if not os.path.isdir(opt.checkpoint):
                raise NotImplementedError(
                    'checkpoint should be dir, not file: %s' % opt.checkpoint)
            load_checkpoints(model, d_g, os.path.join(opt.checkpoint,
                                                      "%s.pth"))
        train_hpm(opt, train_loader, model, d_g, board)
        save_checkpoints(
            model, d_g,
            os.path.join(opt.checkpoint_dir,
                         opt.stage + '_' + opt.name + "_final", '%s.pth'))
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 3, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if resume:
            load_checkpoint(model, opt.checkpoint)
        train_tom(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #6
0
def main():
    """Load a ``GMM`` checkpoint and time a single ``predict_gmm`` call.

    Inputs are prepared with ``prepare_inputs(opt)``; only the
    ``predict_gmm`` call itself is inside the timed region. The elapsed
    wall-clock seconds are printed.
    """
    opt = get_opt()
    print(opt)

    # build the network and restore its weights
    model = GMM(opt)
    load_checkpoint(model, opt.checkpoint)

    with torch.no_grad():
        inputs = prepare_inputs(opt)
        started = time.time()
        predict_gmm(opt, model, inputs, save_intermediate=True)
        elapsed = time.time() - started
        print("time: ", elapsed)
Example #7
0
def main():
    """Load a ``UnetGenerator`` checkpoint and run ``test_tom`` on prepared inputs.

    The generator is built with the CP-VTON+ configuration noted in the
    original source (first argument 26) and evaluated under
    ``torch.no_grad()``.
    """
    opt = get_opt()
    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # CP-VTON+ generator configuration
    unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
    model = UnetGenerator(26, 4, 6, **unet_kwargs)
    load_checkpoint(model, opt.checkpoint)

    with torch.no_grad():
        inputs = prepare_inputs(opt)
        test_tom(opt, model, inputs)

    print("Finished test %s, named: %s!" % (opt.stage, opt.name))
Example #8
0
 def __init__(self, gmm_path, tom_path):
     '''
     Build both networks from their pretrained checkpoints.

     gmm_path: checkpoint loaded into the GMM model.
     tom_path: checkpoint loaded into the UnetGenerator model.

     Both models are switched to eval mode and moved to CUDA.
     '''
     # geometric matching network
     self.gmm = GMM()
     load_checkpoint(self.gmm, gmm_path)
     self.gmm.eval()

     # try-on generator
     unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
     self.tom = UnetGenerator(23, 4, 6, **unet_kwargs)
     load_checkpoint(self.tom, tom_path)
     self.tom.eval()

     # move both networks to the GPU, GMM first
     for net in (self.gmm, self.tom):
         net.cuda()
def main():
    """Run ``test_residual`` with fixed GMM / generator checkpoints.

    Although the log messages say "train", this entry point only calls
    ``test_residual``. The GMM, generator and embedder weights come from
    hard-coded checkpoint paths; the residual ``UNet`` is optionally
    initialised from ``opt.checkpoint``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # distributed setup: WORLD_SIZE > 1 switches on the NCCL process group
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = world_size > 1
    local_rank = opt.local_rank  # read but not used further in this function

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # data pipeline
    dataset = CPDataset(opt)
    # The project loader is still constructed, but the plain DataLoader
    # below is what gets handed to test_residual.
    loader = CPDataLoader(opt, dataset)
    loader_kwargs = dict(batch_size=opt.batch_size,
                         shuffle=False,
                         num_workers=opt.workers,
                         pin_memory=True,
                         sampler=None)
    data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    # tensorboard directory (no writer is created here)
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    # pretrained geometric matching network
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    # pretrained try-on generator
    generator_model = UnetGenerator(25, 4, 6, ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    # pretrained embedder; only its ``embedder_b`` sub-module is kept
    # (it is not passed to test_residual below)
    full_embedder = Embedder()
    load_checkpoint(full_embedder,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = full_embedder.embedder_b.cuda()

    # residual network under test
    model = UNet(n_channels=4, n_classes=3)
    model.cuda()

    have_checkpoint = opt.checkpoint != '' and os.path.exists(opt.checkpoint)
    if have_checkpoint:
        load_checkpoint(model, opt.checkpoint)

    test_residual(opt, data_loader, model, gmm_model, generator_model)

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #10
0
def main():
    """Load a ``UnetGenerator`` from ``opt.checkpoint`` and run ``test_mask_gen``.

    The log line says "train", but the only work done is the
    ``test_mask_gen`` call under ``torch.no_grad()``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # mask generator network
    unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
    model = UnetGenerator(22, 1, 6, **unet_kwargs)
    load_checkpoint(model, opt.checkpoint)

    with torch.no_grad():
        test_mask_gen(opt, loader, model)
Example #11
0
def main():
    """Run the ``UnetGenerator`` over the ``'test'`` split of ``TOMDataset``.

    Weights are loaded from ``opt.checkpoint``; the model is put in eval
    mode and ``run`` is called under ``torch.no_grad()``.
    """
    opt = get_opt()
    print(opt)

    unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
    model = UnetGenerator(25, 4, 6, **unet_kwargs)
    load_checkpoint(model, opt.checkpoint)
    # model.cuda() is left disabled, as in the original
    model.eval()

    mode = 'test'
    print('Run on {} data'.format(mode.upper()))

    pairs_file = mode + '_pairs.txt'
    dataset = TOMDataset(opt, mode, data_list=pairs_file, train=False)
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            num_workers=opt.n_worker,
                            shuffle=False)

    with torch.no_grad():
        run(opt, model, dataloader, mode)
    print('Successfully completed')
Example #12
0
def main():
    """Run the ``GMM`` model over the train, val and test splits in turn.

    Weights are loaded from ``opt.checkpoint``; for each split a
    ``GMMDataset`` and ``DataLoader`` are built and ``run`` is called
    under ``torch.no_grad()``.
    """
    opt = get_opt()
    print(opt)

    model = GMM(opt)
    load_checkpoint(model, opt.checkpoint)
    # model.cuda() is left disabled, as in the original
    model.eval()

    for mode in ('train', 'val', 'test'):
        print('Run on {} data'.format(mode.upper()))

        pairs_file = mode + '_pairs.txt'
        dataset = GMMDataset(opt, mode, data_list=pairs_file, train=False)
        dataloader = DataLoader(dataset,
                                batch_size=opt.batch_size,
                                num_workers=opt.n_worker,
                                shuffle=False)

        with torch.no_grad():
            run(opt, model, dataloader, mode)

    print('Successfully completed')
Example #13
0
    def __init__(self, gmm_path, tom_path, use_cuda=False):
        """Build the GMM and UnetGenerator models from their checkpoints.

        Args:
            gmm_path: checkpoint loaded into the ``GMM`` model.
            tom_path: checkpoint loaded into the ``UnetGenerator`` model.
            use_cuda: forwarded to ``GMM`` and ``load_checkpoint``; when
                true both models are also moved to CUDA.
        """
        self.use_cuda = use_cuda

        # geometric matching network
        self.gmm = GMM(use_cuda=use_cuda)
        load_checkpoint(self.gmm, gmm_path, use_cuda=use_cuda)
        self.gmm.eval()

        # try-on generator
        unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
        self.tom = UnetGenerator(23, 4, 6, **unet_kwargs)
        load_checkpoint(self.tom, tom_path, use_cuda=use_cuda)
        self.tom.eval()

        if use_cuda:
            for net in (self.gmm, self.tom):
                net.cuda()

        print("use_cuda = " + str(self.use_cuda))
def main():
    """Train the ``'GMM'`` or ``'TOM'`` stage, using DataParallel when possible.

    Options come from ``get_opt()``; the dataset class is resolved with
    ``get_dataset_class(opt.dataset)``. When more than one CUDA device is
    visible the model is wrapped in ``nn.DataParallel`` for training. The
    final checkpoint is written from the unwrapped model so that it can be
    loaded back by this script, which calls ``load_checkpoint`` before
    wrapping.

    Raises:
        NotImplementedError: if ``opt.stage`` is neither ``"GMM"`` nor
            ``"TOM"``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = get_dataset_class(opt.dataset)(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if opt.tensorboard_dir:
        # exist_ok avoids the check-then-create race between concurrent runs.
        os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & pick the matching training routine / output file name
    if opt.stage == "GMM":
        model = GMM(opt)
        train_fn = train_gmm
        final_name = "gmm_final.pth"
    elif opt.stage == "TOM":
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        train_fn = train_tom
        final_name = "tom_final.pth"
    else:
        raise NotImplementedError("Model [%s] is not implemented" % opt.stage)

    model.opt = opt
    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    # Keep a handle on the bare network: nn.DataParallel exposes its
    # parameters under a "module." prefix, so a checkpoint written from the
    # wrapper would presumably not load into the unwrapped model above
    # (assumes save_checkpoint serialises the state_dict - TODO confirm).
    bare_model = model
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    train_fn(opt, train_loader, model, board)
    # DataParallel shares parameters with the wrapped module, so bare_model
    # holds the trained weights.
    save_checkpoint(
        bare_model, os.path.join(opt.checkpoint_dir, opt.name, final_name))

    print("Finished training %s, named: %s!" % (opt.stage, opt.name))
 def __init__(self, gmm_path, tom_path, use_cuda=True):
     '''
     Initialise the two models from their pretrained checkpoints.

     gmm_path: checkpoint loaded into the GMM model.
     tom_path: checkpoint loaded into the UnetGenerator model.
     use_cuda: forwarded to GMM and load_checkpoint; when true both
         models are also moved to CUDA.
     '''
     self.use_cuda = use_cuda

     # geometric matching network
     self.gmm = GMM(use_cuda=use_cuda)
     load_checkpoint(self.gmm, gmm_path, use_cuda=use_cuda)
     self.gmm.eval()

     # try-on generator
     unet_kwargs = dict(ngf=64, norm_layer=nn.InstanceNorm2d)
     self.tom = UnetGenerator(23, 4, 6, **unet_kwargs)
     load_checkpoint(self.tom, tom_path, use_cuda=use_cuda)
     self.tom.eval()

     if use_cuda:
         for net in (self.gmm, self.tom):
             net.cuda()

     print("use_cuda = " + str(self.use_cuda))
Example #16
0
def main():
    """Train the ``'PGP'`` or ``'GMM'`` stage and save a final checkpoint.

    Options come from ``get_opt()``. The model is optionally initialised
    from ``opt.checkpoint`` when that path is non-empty and exists.

    Raises:
        NotImplementedError: for any other ``opt.stage``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)
    print('//////////////////////')

    # tensorboard output
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    log_path = os.path.join(opt.tensorboard_dir, opt.name)
    writer = SummaryWriter(log_dir=log_path)

    # Reject unknown stages before building any model.
    if opt.stage not in ('PGP', 'GMM'):
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    if opt.stage == 'PGP':
        model = PGP(opt)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_pgp(opt, loader, model, writer)
        final_path = os.path.join(opt.checkpoint_dir, opt.name,
                                  'pgp_final.pth')
    else:
        model = GMM(opt)
        if opt.checkpoint != '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, loader, model, writer)
        final_path = os.path.join(opt.checkpoint_dir, opt.name,
                                  'gmm_final.pth')
    save_checkpoint(model, final_path)

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #17
0
def main():
    """Run ``test_identity_embedding`` on an ``Embedder`` model.

    Sets ``opt.distributed`` from the ``WORLD_SIZE`` environment variable
    and initialises an NCCL process group when it is greater than one. A
    ``SummaryWriter`` is created only when ``single_gpu_flag(opt)`` is
    true; otherwise ``None`` is passed as the board.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # distributed setup
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = world_size > 1
    local_rank = opt.local_rank  # read but not used further in this function

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # data pipeline (plain torch DataLoader, no shuffling)
    dataset = CPDataset(opt)
    loader_kwargs = dict(batch_size=opt.batch_size,
                         shuffle=False,
                         num_workers=opt.workers,
                         pin_memory=True)
    loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    # tensorboard output
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    if single_gpu_flag(opt):
        writer = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))
    else:
        writer = None

    # model under test
    model = Embedder()
    have_checkpoint = opt.checkpoint != '' and os.path.exists(opt.checkpoint)
    if have_checkpoint:
        load_checkpoint(model, opt.checkpoint)

    test_identity_embedding(opt, loader, model, writer)
    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #18
0
def CPVTON_wrapper(name="GMM", stage="GMM", workers=1, checkpoint=""):
    """Run the GMM or TOM test routine with options overridden by arguments.

    Args:
        name: value written to ``opt.name``.
        stage: ``'GMM'`` or ``'TOM'``; written to ``opt.stage`` and used to
            select the model and test routine.
        workers: value written to ``opt.workers``.
        checkpoint: value written to ``opt.checkpoint``; loaded into the
            model before testing.

    Raises:
        NotImplementedError: if ``stage`` is neither ``'GMM'`` nor ``'TOM'``.
    """
    opt = get_opt()
    # override the parsed options with the caller's arguments
    opt.name = name
    opt.stage = stage
    opt.workers = workers
    opt.checkpoint = checkpoint

    print(opt)
    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))

    # data pipeline
    dataset = CPDataset(opt)
    loader = CPDataLoader(opt, dataset)

    # tensorboard output
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    log_path = os.path.join(opt.tensorboard_dir, opt.name)
    writer = SummaryWriter(logdir=log_path)

    # Reject unknown stages before building any model.
    if stage not in ('GMM', 'TOM'):
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    if stage == 'GMM':
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, loader, model, writer)
    else:
        # CP-VTON+ generator; the original CP-VTON used 25 as first argument
        model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, loader, model, writer)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
Example #19
0
def main():
    """Train one of several stages, with optional distributed data parallelism.

    ``opt.stage`` selects the branch: ``'GMM'``, ``'TOM'``, ``'TOM+WARP'``,
    ``"identity"``, ``'residual'`` or ``"residual_old"``. ``opt.distributed``
    is derived from the ``WORLD_SIZE`` environment variable; when it is
    true an NCCL process group is initialised and the trainable networks of
    the TOM / TOM+WARP / residual branches are wrapped in
    ``DistributedDataParallel``. A ``SummaryWriter`` is created only when
    ``single_gpu_flag(opt)`` is true; otherwise ``board`` stays ``None``.

    Raises:
        NotImplementedError: if ``opt.stage`` matches none of the branches.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # WORLD_SIZE > 1 switches the run into distributed mode.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    # The writer exists only where single_gpu_flag(opt) is true; the
    # training helpers receive None otherwise.
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        # This branch does not wrap the model for distributed training and
        # saves without checking single_gpu_flag.
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'TOM':

        # GMM weights come from a hard-coded path, not from opt.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        # model_module always refers to the unwrapped network; it is what
        # gets saved below.
        model_module = model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    elif opt.stage == 'TOM+WARP':

        # Here the GMM is not loaded from a checkpoint and is wrapped for
        # distributed training alongside the generator.
        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            gmm_model = torch.nn.parallel.DistributedDataParallel(
                gmm_model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        # Only the generator is saved here; gmm_model_module is not.
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "identity":
        model = Embedder()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        # NOTE(review): the Embedder is saved as 'gmm_final.pth' - looks
        # copied from the GMM branch; confirm the intended file name.
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
    elif opt.stage == 'residual':

        # Three pretrained helper networks, all from hard-coded paths.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new/step_038000.pth")
        generator_model.cuda()

        # Only the embedder_b sub-module of the Embedder is kept.
        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        # The residual UNet is the network trained in this branch.
        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

            acc_discriminator = AccDiscriminator()
            acc_discriminator.apply(utils.weights_init('gaussian'))
            acc_discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            if opt.use_gan:
                # The discriminator checkpoint path is derived from the
                # generator path by replacing "step_" with "step_disc_".
                # acc_discriminator is not restored here.
                load_checkpoint(discriminator,
                                opt.checkpoint.replace("step_", "step_disc_"))

        # *_module names refer to the unwrapped networks.
        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

                acc_discriminator = torch.nn.parallel.DistributedDataParallel(
                    acc_discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)

            # With the GAN enabled a dict of modules is passed to
            # save_checkpoint instead of a single model; the
            # acc_discriminator is not included.
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    elif opt.stage == "residual_old":
        # Same layout as the 'residual' branch but with a different
        # generator checkpoint, no AccDiscriminator, and no discriminator
        # restore from opt.checkpoint.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25,
                                        4,
                                        6,
                                        ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new_2/step_070000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    find_unused_parameters=True)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt,
                               train_loader,
                               model,
                               model_module,
                               gmm_model,
                               generator_model,
                               embedder_model,
                               board,
                               discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #20
0
def main():
    """Train or evaluate the ``'HPM'``, ``'GMM'`` or ``'TOM'`` stage.

    ``opt.mode`` selects the behaviour:

    * ``'train'``: run the stage's training routine, then write the final
      checkpoint.
    * ``'test'``: switch ``opt.datamode`` / ``opt.data_list`` to the test
      split, disable shuffling, and run the stage's test routine.
    * ``'val'``: disable shuffling and run the stage's test routine.
    * anything else: the mode is printed and then handled like a
      non-training mode.

    Non-training modes force ``opt.batch_size = 1`` and require
    ``opt.checkpoint``. The final checkpoint is written only after
    training, so an evaluation run no longer overwrites it.

    Returns:
        None. Returns early, after printing a message, when a non-training
        mode is requested without a checkpoint.

    Raises:
        NotImplementedError: if ``opt.stage`` is not one of the three
            stages above.
    """
    opt = get_opt()

    if opt.mode == 'test':
        opt.datamode = "test"
        opt.data_list = "test_pairs.txt"
        opt.shuffle = False
    elif opt.mode == 'val':
        opt.shuffle = False
    elif opt.mode != 'train':
        print(opt.mode)

    # Options are printed before batch_size is forced to 1 below.
    print(opt)

    is_training = opt.mode == 'train'
    if not is_training:
        opt.batch_size = 1
        if not opt.checkpoint:
            print("You need to have a checkpoint for: "+opt.mode)
            return None

    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # A checkpoint is used only when a path was given and it exists.
    has_checkpoint = bool(opt.checkpoint) and os.path.exists(opt.checkpoint)

    # create model, then train (and save) or test
    if opt.stage == 'HPM':
        model = UnetGenerator(25, 16, 6, ngf=64, norm_layer=nn.InstanceNorm2d, clsf=True)
        d_g = Discriminator_G(opt, 16)
        if has_checkpoint:
            load_checkpoint(model, opt.checkpoint)
            # NOTE(review): the discriminator path is built by dropping the
            # last 9 characters of the generator path - presumably a fixed
            # file-name suffix; confirm against save_checkpoints.
            load_checkpoint(d_g, opt.checkpoint[:-9] + "dg.pth")

        if is_training:
            train_hpm(opt, train_loader, model, d_g, board)
            save_checkpoints(model, d_g, os.path.join(opt.checkpoint_dir, opt.stage + '_' + opt.name + "_final", '%s.pth'))
        else:
            test_hpm(opt, train_loader, model)

    elif opt.stage == 'GMM':
        model = GMM(opt, 1)
        if has_checkpoint:
            load_checkpoint(model, opt.checkpoint)

        if is_training:
            train_gmm(opt, train_loader, model, board)
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
        else:
            test_gmm(opt, train_loader, model)

    elif opt.stage == 'TOM':
        model = UnetGenerator(31, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        if has_checkpoint:
            load_checkpoint(model, opt.checkpoint)

        if is_training:
            train_tom(opt, train_loader, model, board)
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
        else:
            test_tom(opt, train_loader, model)
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def main():
    """Train the model selected by ``opt.stage`` and save its final weights.

    The options come from ``get_opt()``. A ``CPDataset`` / ``CPDataLoader``
    pair is built from them and a ``SummaryWriter`` is opened under
    ``<opt.tensorboard_dir>/<opt.name>``. The stage then decides which model
    is constructed, which training function is run and under which file name
    the final checkpoint is written to ``<opt.checkpoint_dir>/<opt.name>/``.

    For every stage except ``'DeepTom'`` the weights in ``opt.checkpoint``
    are loaded first when that path is non-empty and exists.

    Raises:
        TypeError: ``opt.stage`` is ``'GMM'`` but ``opt.model`` is neither
            ``'RefinedGMM'`` nor ``'OneRefinedGMM'``.
        NotImplementedError: ``opt.stage`` is not one of the handled stages.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok=True replaces the racy "exists() then makedirs()" pair.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Select the model, its training routine and the final checkpoint name.
    # The shared load / train / save sequence runs once after the selection.
    resume = True
    if opt.stage == 'GMM':
        if opt.model == 'RefinedGMM':
            model = RefinedGMM(opt)
        elif opt.model == 'OneRefinedGMM':
            model = OneRefinedGMM(opt)
        else:
            raise TypeError(
                'GMM model [%s] is not supported; expected RefinedGMM or '
                'OneRefinedGMM' % (opt.model,))
        train_fn = train_refined_gmm
        final_name = 'refined_gmm_final.pth'
    elif opt.stage == 'VariGMM':
        model = VariGMM(opt)
        train_fn = train_refined_gmm
        # NOTE(review): same file name as the 'GMM' and 'semanticGMM' stages;
        # kept as in the original, confirm this is intended.
        final_name = 'refined_gmm_final.pth'
    elif opt.stage == 'semanticGMM':
        model = RefinedGMM(opt)
        train_fn = train_semantic_parsing_refined_gmm
        final_name = 'refined_gmm_final.pth'
    elif opt.stage == 'no_background_GMM':
        model = RefinedGMM(opt)
        train_fn = train_no_background_refined_gmm
        final_name = 'no_background_refined_gmm_final.pth'
    elif opt.stage == 'TOM':
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        train_fn = train_tom
        final_name = 'tom_final.pth'
    elif opt.stage == 'DeepTom':
        # NOTE(review): earlier revisions assigned norm_layer='instance',
        # use_dropout=True and with_tanh=False here without ever reading
        # them; the literals below presumably carry those meanings - confirm
        # against the Define_G signature.
        model = Define_G(25,
                         4,
                         64,
                         'treeresnet',
                         'instance',
                         True,
                         'normal',
                         0.02,
                         opt.gpu_ids,
                         with_tanh=False)
        train_fn = train_deep_tom
        final_name = 'tom_final.pth'
        # This stage never restored opt.checkpoint; behavior is kept.
        resume = False
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    # Truthiness check also tolerates opt.checkpoint being None.
    if resume and opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    train_fn(opt, train_loader, model, board)
    save_checkpoint(
        model, os.path.join(opt.checkpoint_dir, opt.name, final_name))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #22
0
        test_sampler = sampler.RandomSampler(test_dataset)
    # NOTE(review): the enclosing function header and the branch that owns the
    # assignment above are outside this excerpt; comments below describe only
    # the visible statements.
    # NOTE(review): both shuffle= and sampler= are passed. torch's DataLoader
    # documents these as mutually exclusive, so this presumably relies on
    # opt.shuffle being False whenever a sampler is set - confirm.
    test_loader = DataLoader(test_dataset,
                             batch_size=opt.batch_size,
                             shuffle=opt.shuffle,
                             num_workers=opt.workers,
                             pin_memory=True,
                             sampler=test_sampler)

    # visualization: logs go to <tensorboard_dir>/<name>/<data_list up to its
    # first '.'>
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name,
                                               opt.data_list.split('.')[0]))

    # create the model for the requested stage, load opt.checkpoint
    # unconditionally and run the matching test routine with gradient
    # tracking disabled
    if opt.stage == 'GMM':
        print('Dataset size: %05d!' % (len(test_dataset)), flush=True)
        model = GMM(opt)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_gmm(opt, test_loader, model, board)
    elif opt.stage == 'TOM':
        print('Dataset size: %05d!' % (len(test_dataset)), flush=True)
        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(model, opt.checkpoint)
        with torch.no_grad():
            test_tom(opt, test_loader, model, board)
    else:
        # any other stage name is rejected
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
# Module-level setup: builds the dataset/loaders and loads the pretrained
# helper networks from fixed checkpoint paths, then creates the main model G.
# NOTE(review): n_gpu and opt are defined before this excerpt.
opt.distributed = n_gpu > 1  # distributed mode whenever more than one GPU
local_rank = opt.local_rank

# create dataset
train_dataset = CPDataset(opt)

# create dataloader
# Two loaders are built over the same dataset: the project's CPDataLoader and
# a plain, unshuffled torch DataLoader.
train_loader = CPDataLoader(opt, train_dataset)
data_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=opt.batch_size,
                                          shuffle=False,
                                          num_workers=opt.workers,
                                          pin_memory=True)

# GMM network, restored from a hard-coded checkpoint and moved to the GPU
gmm_model = GMM(opt)
load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
gmm_model.cuda()

# U-Net generator, restored from a hard-coded checkpoint and moved to the GPU
generator = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
load_checkpoint(generator, "checkpoints/tom_train_new_2/step_070000.pth")
generator.cuda()

# Embedder: the checkpoint is loaded into the full module, then its two
# sub-modules are moved to the GPU and kept under separate names.
# NOTE(review): the image/product naming suggests embedder_b handles person
# images and embedder_a handles products - confirm against Embedder.
embedder_model = Embedder()
load_checkpoint(embedder_model,
                "checkpoints/identity_embedding_for_test/step_045000.pth")
image_embedder = embedder_model.embedder_b.cuda()
prod_embedder = embedder_model.embedder_a.cuda()

# Main model; weights are restored only when opt.checkpoint is a non-empty
# path that exists on disk.
model = G()
if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
    load_checkpoint(model, opt.checkpoint)
Example #24
0
def main():
    """Train the residual model ``G`` (optionally with a GAN discriminator).

    Steps, in order:

    1. Read options via ``get_opt()``; distributed mode is enabled when the
       ``WORLD_SIZE`` environment variable is greater than 1, in which case
       the CUDA device is set to ``opt.local_rank`` and the NCCL process
       group is initialised.
    2. Build the ``CPDataset`` / ``CPDataLoader`` pair and, when
       ``single_gpu_flag(opt)`` is true, a ``SummaryWriter`` under
       ``<opt.tensorboard_dir>/<opt.name>`` (otherwise ``board`` is ``None``).
    3. Load the pretrained GMM, U-Net generator and embedder from fixed
       checkpoint paths and move them to the GPU.
    4. Create ``G`` (and ``Discriminator`` when ``opt.use_gan``), optionally
       restore ``opt.checkpoint`` into ``G``, wrap both in
       ``DistributedDataParallel`` when distributed, and run
       ``train_residual_old``.
    5. When ``single_gpu_flag(opt)`` is true, save the final checkpoint to
       ``<opt.checkpoint_dir>/<opt.name>/tom_final.pth``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # Every rank executes this line; exist_ok=True keeps concurrent creation
    # from raising FileExistsError (the exists()/makedirs() pair could race).
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    # NOTE(review): single_gpu_flag presumably selects the one process that
    # should log and save - confirm against its definition.
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Pretrained helper networks, restored from hard-coded checkpoints.
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    # Only the embedder_b sub-module is kept after loading.
    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    discriminator = None
    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    # Truthiness check also tolerates opt.checkpoint being None.
    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    # *_module always refers to the unwrapped network so that the saved
    # checkpoint does not contain the DistributedDataParallel wrapper.
    model_module = model
    discriminator_module = discriminator
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

    # The discriminator keywords are only passed in GAN mode, so the
    # non-GAN call keeps relying on train_residual_old's own defaults.
    if opt.use_gan:
        gan_kwargs = {
            "discriminator": discriminator,
            "discriminator_module": discriminator_module,
        }
        final_state = {
            "generator": model_module,
            "discriminator": discriminator_module
        }
    else:
        gan_kwargs = {}
        final_state = model_module

    train_residual_old(opt, train_loader, model, model_module, gmm_model,
                       generator_model, embedder_model, board, **gan_kwargs)

    if single_gpu_flag(opt):
        save_checkpoint(
            final_state,
            os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))