def main():
    """Run the residual-refinement test pass.

    Options come from ``get_opt()``. When ``WORLD_SIZE`` reports more than
    one process, the NCCL process group is initialised before anything else.
    The GMM, generator and embedder networks are restored from fixed
    checkpoint paths, the residual ``UNet`` is optionally restored from
    ``opt.checkpoint``, and the models are handed to ``test_residual``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = n_gpu > 1

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    dataset = CPDataset(opt)

    # create dataloader
    # NOTE(review): `loader` is never read below; only `data_loader` reaches
    # test_residual. Kept in case constructing CPDataLoader is relied upon.
    loader = CPDataLoader(opt, dataset)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.workers,
                                              pin_memory=True,
                                              sampler=None)

    # visualization
    # exist_ok avoids the check-then-create race when several ranks start
    # at the same time and all find the directory missing.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    # NOTE(review): the embedder is loaded but not passed to test_residual.
    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = UNet(n_channels=4, n_classes=3)
    model.cuda()

    # Restore the model under test only when a usable path was supplied.
    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    test_residual(opt, data_loader, model, gmm_model, generator_model)

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #2
0
    def set_vars(self):
        """Parse arguments, then derive seed and distributed settings.

        After ``self.parse_args()`` the parsed namespace receives
        ``manualSeed`` (always 1), ``n_gpu`` (from ``WORLD_SIZE``, default 1)
        and ``distributed`` (``n_gpu > 1``). In distributed mode the CUDA
        device is selected and the NCCL process group is initialised.
        """
        self.parse_args()

        parsed = self.args
        parsed.manualSeed = 1

        # A missing WORLD_SIZE is treated as a single-process run.
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        parsed.n_gpu = world_size
        parsed.distributed = world_size > 1

        if parsed.distributed:
            torch.cuda.set_device(parsed.local_rank)
            torch.distributed.init_process_group(init_method="env://",
                                                 backend="nccl")
            synchronize()
Example #3
0
def distributed_worker(
    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
    """Initialise one distributed worker process and run ``fn(*args)`` in it.

    Args:
        local_rank: Index of this process on its machine; also used as the
            CUDA device index.
        fn: Callable executed once setup has finished.
        world_size: Total number of processes in the job.
        n_gpu_per_machine: Number of processes (GPUs) per machine.
        machine_rank: Index of this machine.
        dist_url: ``init_method`` passed to ``init_process_group``.
        args: Positional arguments forwarded to ``fn``.

    Raises:
        OSError: CUDA is unavailable or the process group could not be
            initialised (the underlying error is attached as ``__cause__``).
        ValueError: ``n_gpu_per_machine`` exceeds the visible device count,
            or a local process group has already been registered.
    """
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )

    except Exception as err:
        # Chain the original failure so the real reason (bad URL, timeout,
        # NCCL error, ...) stays visible in the traceback.
        raise OSError("failed to initialize NCCL groups") from err

    dist_fn.synchronize()

    device_count = torch.cuda.device_count()
    if n_gpu_per_machine > device_count:
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({device_count})"
        )

    torch.cuda.set_device(local_rank)

    # NOTE(review): the check reads dist_fn.LOCAL_PROCESS_GROUP while the
    # assignment below writes dist_fn.distributed.LOCAL_PROCESS_GROUP;
    # confirm both refer to the value the rest of the code consults.
    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")

    n_machine = world_size // n_gpu_per_machine

    # Every process calls new_group for every machine's rank list; only the
    # group matching this machine is stored.
    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.distributed.LOCAL_PROCESS_GROUP = pg

    fn(*args)
Example #4
0
def main():
    """Run the identity-embedding test pass.

    Options come from ``get_opt()``. When ``WORLD_SIZE`` reports more than
    one process, the NCCL process group is initialised first. An
    ``Embedder`` is optionally restored from ``opt.checkpoint`` and passed,
    with the data loader and an optional ``SummaryWriter``, to
    ``test_identity_embedding``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = n_gpu > 1

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=opt.workers,
                                               pin_memory=True)

    # visualization
    # exist_ok avoids the check-then-create race when several ranks start
    # at the same time and all find the directory missing.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    # Only the process selected by single_gpu_flag gets a summary writer;
    # the others pass board=None.
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    model = Embedder()
    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)
    test_identity_embedding(opt, train_loader, model, board)
    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #5
0
    # Optimiser / architecture / logging flags; --local_rank is the device
    # index used in the distributed branch below.
    parser.add_argument('--lr', type=float, default=0.002)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()

    print(args)

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        # Select this process's GPU, join the NCCL group, then synchronize.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # Fixed values written onto args; not exposed as CLI options here.
    args.latent = 512
    args.n_mlp = 8

    # generator = Generator(
    #     args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    # ).to(device)
    # The generator is built from a hard-coded weights file name instead of
    # the commented-out Generator(...) construction above.
    generator = ProjectionGenerator("stylegan2-ffhq-config-f.pt").to(device)

    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    # g_ema = Generator(
    #     args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    # ).to(device)
Example #6
0
    # Input pipeline: resize, random 224 crop, random horizontal flip, then
    # tensor conversion and per-channel normalisation with mean/std 0.5.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dset = Places365(args.path, transform=transform)
    # The class count is taken from the dataset and shared through args.
    args.n_class = dset.n_class

    if args.distributed:
        # Select this process's GPU, join the NCCL group, then synchronize.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        dist.synchronize()

    gen = Generator(args.n_class, args.dim_z, args.dim_class).to(device)
    g_ema = Generator(args.n_class, args.dim_z, args.dim_class).to(device)
    # NOTE(review): presumably a decay of 0 makes g_ema an exact copy of
    # gen -- confirm against accumulate().
    accumulate(g_ema, gen, 0)
    dis = Discriminator(args.n_class).to(device)

    if args.ckpt is not None:
        # The identity map_location returns each storage as deserialised
        # instead of moving it to the device recorded in the checkpoint.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        # The checkpoint is read with the keys "g", "g_ema" and "d".
        gen.load_state_dict(ckpt["g"])
        g_ema.load_state_dict(ckpt["g_ema"])
        dis.load_state_dict(ckpt["d"])

    if args.distributed:
        gen = nn.parallel.DistributedDataParallel(
Example #7
0
def setup_and_run(device, args):
    """Build models, optimisers and the data loader, then call ``train``.

    Args:
        device: Device the generator, discriminator and EMA generator are
            moved to.
        args: Option namespace. Read here: ``name``, ``local_rank``,
            ``size``, ``channel_multiplier``, ``g_reg_every``,
            ``d_reg_every``, ``lr``, ``ckpt``, ``path``, ``batch``,
            ``wandb``. Written here: ``distributed``, ``latent``, ``n_mlp``,
            ``start_iter``.
    """
    os.makedirs(f'sample_{args.name}', exist_ok=True)
    os.makedirs(f'checkpoint_{args.name}', exist_ok=True)

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    # NOTE(review): presumably a decay of 0 makes g_ema an exact copy of
    # the generator -- confirm against accumulate().
    accumulate(g_ema, generator, 0)

    # Learning rate and Adam betas are rescaled by every / (every + 1).
    # NOTE(review): presumably compensates for regularisation that runs only
    # every N steps -- confirm against train().
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)

        # Deserialise onto CPU. Without map_location every rank restores the
        # tensors onto the device recorded in the file, which piles extra
        # allocations onto that one GPU. load_state_dict then copies the
        # values into the already-placed parameters.
        ckpt = torch.load(args.ckpt,
                          map_location=lambda storage, loc: storage)

        # A purely numeric checkpoint file name (e.g. "010000.pt") sets the
        # start iteration; any other name leaves start_iter at 0.
        try:
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt['g'], strict=False)
        discriminator.load_state_dict(ckpt['d'], strict=False)
        g_ema.load_state_dict(ckpt['g_ema'], strict=False)

        g_optim.load_state_dict(ckpt['g_optim'])
        d_optim.load_state_dict(ckpt['d_optim'])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
    )

    # Only rank 0 initialises wandb, and only when the package was imported
    # successfully and --wandb was requested.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device)
Example #8
0
def main():
    """Train the residual model with ``train_residual_old``.

    Options come from ``get_opt()``. The GMM, generator and embedder
    networks are restored from fixed checkpoint paths and passed to the
    trainer as-is. The trained model is ``G()``; a ``Discriminator`` is added
    when ``opt.use_gan`` is set. Both are wrapped in
    ``DistributedDataParallel`` in distributed mode. The final checkpoint is
    written only by the process selected by ``single_gpu_flag``.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok avoids the check-then-create race when several ranks start
    # at the same time and all find the directory missing.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    # Always bind the discriminator names so later code never depends on a
    # variable that exists only when use_gan is set.
    discriminator = None
    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    if opt.checkpoint and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    # *_module always refers to the unwrapped network, which is what gets
    # saved; the (possibly DDP-wrapped) object is what gets trained.
    model_module = model
    discriminator_module = discriminator
    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

    # The GAN and non-GAN runs differ only in the extra keyword arguments
    # and in what is stored in the final checkpoint.
    gan_kwargs = {}
    final_state = model_module
    if opt.use_gan:
        gan_kwargs = {
            "discriminator": discriminator,
            "discriminator_module": discriminator_module,
        }
        final_state = {
            "generator": model_module,
            "discriminator": discriminator_module
        }

    train_residual_old(opt, train_loader, model, model_module, gmm_model,
                       generator_model, embedder_model, board, **gan_kwargs)
    if single_gpu_flag(opt):
        save_checkpoint(
            final_state,
            os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #9
0
def _checkpoint_given(opt):
    """Return True when ``opt.checkpoint`` is non-empty and exists on disk."""
    return bool(opt.checkpoint) and os.path.exists(opt.checkpoint)


def _final_checkpoint_path(opt, filename):
    """Return ``<checkpoint_dir>/<name>/<filename>`` for the final save."""
    return os.path.join(opt.checkpoint_dir, opt.name, filename)


def _to_ddp(module, local_rank):
    """Wrap ``module`` in DistributedDataParallel bound to ``local_rank``."""
    return torch.nn.parallel.DistributedDataParallel(
        module,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True)


def _init_discriminator(factory):
    """Instantiate ``factory()``, apply 'gaussian' init and move it to CUDA."""
    net = factory()
    net.apply(utils.weights_init('gaussian'))
    net.cuda()
    return net


def _load_residual_dependencies(opt, generator_checkpoint):
    """Load the pretrained GMM, generator and embedder used by residual stages."""
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model, generator_checkpoint)
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    # Only the embedder_b sub-module is used downstream.
    embedder_model = embedder_model.embedder_b.cuda()

    return gmm_model, generator_model, embedder_model


def _build_residual_unet(opt):
    """Create the residual UNet (SyncBatchNorm when distributed) on CUDA."""
    model = UNet(n_channels=4, n_classes=3)
    if opt.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.apply(utils.weights_init('kaiming'))
    model.cuda()
    return model


def _run_gmm_stage(opt, train_loader, board, local_rank):
    """Train the GMM and save it as 'gmm_final.pth'."""
    model = GMM(opt)
    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)
    train_gmm(opt, train_loader, model, board)
    # NOTE(review): unlike the other stages this save is not guarded by
    # single_gpu_flag(opt); kept as-is.
    save_checkpoint(model, _final_checkpoint_path(opt, 'gmm_final.pth'))


def _run_tom_stage(opt, train_loader, board, local_rank):
    """Train the try-on generator against a pretrained GMM."""
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    model.cuda()

    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)

    # model_module is the unwrapped network that gets saved.
    model_module = model
    if opt.distributed:
        model = _to_ddp(model, local_rank)
        model_module = model.module

    train_tom(opt, train_loader, model, model_module, gmm_model, board)
    if single_gpu_flag(opt):
        save_checkpoint(model_module,
                        _final_checkpoint_path(opt, 'tom_final.pth'))


def _run_tom_warp_stage(opt, train_loader, board, local_rank):
    """Train the try-on generator together with a freshly created GMM."""
    # No checkpoint is loaded into the GMM in this stage.
    gmm_model = GMM(opt)
    gmm_model.cuda()

    model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    model.cuda()

    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)

    model_module = model
    gmm_model_module = gmm_model
    if opt.distributed:
        model = _to_ddp(model, local_rank)
        model_module = model.module
        gmm_model = _to_ddp(gmm_model, local_rank)
        gmm_model_module = gmm_model.module

    train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                  gmm_model_module, board)
    if single_gpu_flag(opt):
        save_checkpoint(model_module,
                        _final_checkpoint_path(opt, 'tom_final.pth'))


def _run_identity_stage(opt, train_loader, board, local_rank):
    """Train the identity Embedder and save the result."""
    model = Embedder()
    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)
    train_identity_embedding(opt, train_loader, model, board)
    # NOTE(review): the file name 'gmm_final.pth' and the missing
    # single_gpu_flag guard look copied from the GMM stage; kept so that
    # existing consumers of this path are unaffected.
    save_checkpoint(model, _final_checkpoint_path(opt, 'gmm_final.pth'))


def _run_residual_stage(opt, train_loader, board, local_rank):
    """Train the residual UNet with ``train_residual`` (optional GAN losses)."""
    gmm_model, generator_model, embedder_model = _load_residual_dependencies(
        opt, "checkpoints/tom_train_new/step_038000.pth")

    model = _build_residual_unet(opt)

    discriminator = None
    acc_discriminator = None
    if opt.use_gan:
        discriminator = _init_discriminator(Discriminator)
        acc_discriminator = _init_discriminator(AccDiscriminator)

    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)
        if opt.use_gan:
            # The discriminator checkpoint sits next to the generator one,
            # with "step_" replaced by "step_disc_" in the path.
            load_checkpoint(discriminator,
                            opt.checkpoint.replace("step_", "step_disc_"))

    model_module = model
    discriminator_module = discriminator
    acc_discriminator_module = acc_discriminator
    if opt.distributed:
        model = _to_ddp(model, local_rank)
        model_module = model.module
        if opt.use_gan:
            discriminator = _to_ddp(discriminator, local_rank)
            discriminator_module = discriminator.module
            acc_discriminator = _to_ddp(acc_discriminator, local_rank)
            acc_discriminator_module = acc_discriminator.module

    gan_kwargs = {}
    final_state = model_module
    if opt.use_gan:
        gan_kwargs = {
            "discriminator": discriminator,
            "discriminator_module": discriminator_module,
            "acc_discriminator": acc_discriminator,
            "acc_discriminator_module": acc_discriminator_module,
        }
        # The accuracy discriminator is not part of the final checkpoint.
        final_state = {
            "generator": model_module,
            "discriminator": discriminator_module
        }

    train_residual(opt, train_loader, model, model_module, gmm_model,
                   generator_model, embedder_model, board, **gan_kwargs)
    if single_gpu_flag(opt):
        save_checkpoint(final_state,
                        _final_checkpoint_path(opt, 'tom_final.pth'))


def _run_residual_old_stage(opt, train_loader, board, local_rank):
    """Train the residual UNet with ``train_residual_old``."""
    gmm_model, generator_model, embedder_model = _load_residual_dependencies(
        opt, "checkpoints/tom_train_new_2/step_070000.pth")

    model = _build_residual_unet(opt)

    discriminator = None
    if opt.use_gan:
        discriminator = _init_discriminator(Discriminator)

    if _checkpoint_given(opt):
        load_checkpoint(model, opt.checkpoint)

    model_module = model
    discriminator_module = discriminator
    if opt.distributed:
        model = _to_ddp(model, local_rank)
        model_module = model.module
        if opt.use_gan:
            discriminator = _to_ddp(discriminator, local_rank)
            discriminator_module = discriminator.module

    gan_kwargs = {}
    final_state = model_module
    if opt.use_gan:
        gan_kwargs = {
            "discriminator": discriminator,
            "discriminator_module": discriminator_module,
        }
        final_state = {
            "generator": model_module,
            "discriminator": discriminator_module
        }

    train_residual_old(opt, train_loader, model, model_module, gmm_model,
                       generator_model, embedder_model, board, **gan_kwargs)
    if single_gpu_flag(opt):
        save_checkpoint(final_state,
                        _final_checkpoint_path(opt, 'tom_final.pth'))


def main():
    """Train the stage selected by ``opt.stage``.

    Supported stages: 'GMM', 'TOM', 'TOM+WARP', 'identity', 'residual' and
    'residual_old'. Distributed state, dataset, data loader and the optional
    ``SummaryWriter`` are set up first; the stage is dispatched afterwards.

    Raises:
        NotImplementedError: ``opt.stage`` is not one of the stages above.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # A missing WORLD_SIZE is treated as a single-process run.
    n_gpu = int(os.environ.get('WORLD_SIZE', 1))
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    # exist_ok avoids the check-then-create race when several ranks start
    # at the same time and all find the directory missing.
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    stage_runners = {
        'GMM': _run_gmm_stage,
        'TOM': _run_tom_stage,
        'TOM+WARP': _run_tom_warp_stage,
        'identity': _run_identity_stage,
        'residual': _run_residual_stage,
        'residual_old': _run_residual_old_stage,
    }
    runner = stage_runners.get(opt.stage)
    if runner is None:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
    runner(opt, train_loader, board, local_rank)

    print('Finished training %s, nameed: %s!' % (opt.stage, opt.name))
Example #10
0
def _ensure_directory(path, label):
    """Create ``path`` if it does not exist and report the creation.

    ``label`` is the word inserted into the "Created ... Directory" message.
    """
    if not os.path.exists(path):
        # exist_ok=True: when several distributed ranks start together,
        # another rank may create the directory between the check above and
        # this call; that must not crash the run.
        os.makedirs(path, exist_ok=True)
        print("Created %s Directory: " % label, path)


def _checkpoint_iteration(path):
    """Return the integer encoded in a checkpoint file name, or None.

    The iteration is the file name without its extension (e.g. ``010000.pt``
    gives 10000). None is returned when that stem is not an integer.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(stem)
    except ValueError:
        return None


def _find_latest_checkpoint(directory):
    """Return ``(path, iteration)`` for the ``*.pt`` file with the largest
    numeric name in ``directory``, or None when there is no such file.

    Files are compared by their integer iteration rather than by name, so
    ``10000.pt`` correctly beats ``9000.pt``. Files whose stem is not an
    integer are ignored.
    """
    from glob import glob

    latest = None
    # Sorted traversal plus ">=" keeps the choice deterministic when two
    # names encode the same iteration (the lexicographically last one wins).
    for path in sorted(glob(os.path.join(directory, '*.pt'))):
        iteration = _checkpoint_iteration(path)
        if iteration is None:
            continue
        if latest is None or iteration >= latest[1]:
            latest = (path, iteration)
    return latest


def _restore_training_state(path, generator, encoder, g_ema, g_optim,
                            d_optim):
    """Load model and optimizer state dicts from the checkpoint at ``path``.

    The checkpoint is expected to hold the keys 'g', 'd', 'g_ema', 'g_optim'
    and 'd_optim'.
    """
    # map_location='cpu' keeps every distributed rank from deserialising the
    # tensors onto the GPU they were saved from; load_state_dict then copies
    # the values onto each module's / optimizer's own device.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    ckpt = torch.load(path, map_location='cpu')

    generator.load_state_dict(ckpt['g'])
    encoder.load_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])

    g_optim.load_state_dict(ckpt['g_optim'])
    d_optim.load_state_dict(ckpt['d_optim'])


def main():
    """Parse command-line options, build the models and data loaders, then
    call ``train``.

    Steps, in order:
      1. Parse arguments and initialise distributed training when the
         ``WORLD_SIZE`` environment variable is greater than 1.
      2. Create the checkpoint / sample directories under ``exp_name``.
      3. Derive the encoder input and generator output channel counts from
         ``--bins``, ``--load_sil``, ``--loader_type`` and ``--input_quant``.
      4. Build the dataset, generator, encoder, EMA generator and optimizers.
      5. Optionally restore state from ``--ckpt`` or, with
         ``--use_pretrained_if_available``, from the highest-iteration
         ``*.pt`` file in the checkpoint directory.
      6. Wrap the models for distributed training, set up wandb if requested,
         build the train/validation loaders and hand everything to ``train``.
    """
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('path', type=str)
    parser.add_argument('--decay', type=float, default=0.5**(32 / (10 * 1000)))
    parser.add_argument('--view_id', type=int, default=0)
    parser.add_argument('--iter', type=int, default=10000)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--val_batch', type=int, default=None)
    parser.add_argument('--latent', type=int, default=512)
    parser.add_argument('--n_mlp', type=int, default=8)
    parser.add_argument('--size', type=int, default=64)
    parser.add_argument('--initial_size', type=int, default=4)
    parser.add_argument('--r1', type=float, default=10)
    parser.add_argument('--d_reg_every', type=int, default=16)
    parser.add_argument('--g_reg_every', type=int, default=4)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lr', type=float, default=0.002)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--wandb',
                        action='store_true')  # flag; default value is false
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ckpt_save_directory',
                        type=str,
                        default='checkpoint')
    parser.add_argument('--sample_save_directory', type=str, default='sample')
    parser.add_argument('--load_sil',
                        action='store_true')  # flag; default value is false
    parser.add_argument('--input_type', type=str,
                        default='silhouette')  # default is 'silhouette'
    parser.add_argument('--input_quant',
                        action='store_true',
                        help='make input quantized')  # default value is false
    parser.add_argument('--exp_name', type=str, default='0')
    parser.add_argument('--loader_type', type=str, default='gp')
    parser.add_argument('--categories',
                        default=[],
                        nargs='+',
                        help="what all categories from shapenet to use")
    parser.add_argument('--no_noise',
                        action='store_true',
                        help="remove noise in StyledConv layer")
    parser.add_argument('--max_vps',
                        type=int,
                        default=20,
                        help="maximum viewpoints")
    parser.add_argument('--num_vps',
                        type=int,
                        default=2,
                        help="sample num_vps at a time")
    parser.add_argument('--soft_l1',
                        action='store_true',
                        help="use soft l1 loss")
    parser.add_argument(
        '--random_avg',
        action='store_true',
        help="use uniformly sampled noise for averaging latent")
    parser.add_argument(
        '--use_pretrained_if_available',
        action='store_true',
        help="Use pretrained model with largest iter if available")
    parser.add_argument('--uncond_cat_enc',
                        action='store_true',
                        help="make encoder unconditional on category")
    parser.add_argument('--uncond_vp_enc',
                        action='store_true',
                        help="make encoder unconditional on viewpoint")
    parser.add_argument('--merge_conditions',
                        action='store_true',
                        help="merge viewpoint and category embedding")
    parser.add_argument('--bins', type=int, default=0)  # default value is 0
    parser.add_argument('--smoothing', type=float,
                        default=0.2)  # default value is 0.2
    parser.add_argument('--seed', type=int, default=0)  # Seed for dataloader

    args = parser.parse_args()

    # More than one process in the job means distributed training.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # Each experiment gets its own sub-directory for checkpoints and samples.
    args.ckpt_save_directory = os.path.join(args.ckpt_save_directory,
                                            args.exp_name)
    _ensure_directory(args.ckpt_save_directory, "Ckpt")

    args.sample_save_directory = os.path.join(args.sample_save_directory,
                                              args.exp_name)
    _ensure_directory(args.sample_save_directory, "Sample")

    args.start_iter = 0

    # Output channels: one per bin when quantising into bins (which also
    # forces load_sil on), 2 when load_sil is set, else 1.
    if args.bins > 1:
        args.output_channels = args.bins
        args.load_sil = True
    elif args.load_sil:
        args.output_channels = 2
    else:
        args.output_channels = 1

    # The encoder input mirrors the output unless the 'merged' loader is used.
    args.input_channels = args.output_channels
    if args.loader_type == 'merged':
        args.input_channels = args.bins if args.input_quant else 1

    # Dataset
    dataset = DepthMapDataset(args.path,
                              categories=args.categories,
                              loader_type=args.loader_type,
                              load_sil=args.load_sil,
                              view_id=args.view_id,
                              input_type=args.input_type)

    # Generator call
    generator = Generator(args.size,
                          args.latent,
                          args.n_mlp,
                          channel_multiplier=args.channel_multiplier,
                          output_channels=args.output_channels,
                          initial_size=args.initial_size,
                          num_viewpoints=args.max_vps,
                          no_noise=args.no_noise,
                          num_categories=len(dataset.classes),
                          merge_conditions=args.merge_conditions).to(device)

    print("Generator:", generator)

    # Encoder call
    encoder = Encoder(
        args.size,
        args.latent,
        channel_multiplier=args.channel_multiplier,
        input_channels=args.input_channels,
        num_viewpoints=1 if args.uncond_vp_enc else args.max_vps,
        initial_size=args.initial_size,
        num_categories=1 if args.uncond_cat_enc else len(dataset.classes),
    ).to(device)

    print("Encoder: ", encoder)

    # Second generator with the same configuration, kept in eval mode and
    # initialised from `generator` via accumulate(..., 0).
    g_ema = Generator(args.size,
                      args.latent,
                      args.n_mlp,
                      channel_multiplier=args.channel_multiplier,
                      output_channels=args.output_channels,
                      initial_size=args.initial_size,
                      num_viewpoints=args.max_vps,
                      no_noise=args.no_noise,
                      num_categories=len(dataset.classes),
                      merge_conditions=args.merge_conditions).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    # Learning rate and Adam betas are rescaled by every/(every + 1); a
    # non-positive *_reg_every leaves them unscaled.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every +
                                      1) if args.g_reg_every > 0 else 1
    d_reg_ratio = args.d_reg_every / (args.d_reg_every +
                                      1) if args.d_reg_every > 0 else 1

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        encoder.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)

        # Best effort: resume the iteration counter from the file name when
        # it is numeric, otherwise keep start_iter at 0.
        iteration = _checkpoint_iteration(args.ckpt)
        if iteration is not None:
            args.start_iter = iteration

        _restore_training_state(args.ckpt, generator, encoder, g_ema,
                                g_optim, d_optim)

    elif args.use_pretrained_if_available:
        latest = _find_latest_checkpoint(args.ckpt_save_directory)
        if latest is not None:
            ckpt_path, args.start_iter = latest
            print('found model:', ckpt_path)
            _restore_training_state(ckpt_path, generator, encoder, g_ema,
                                    g_optim, d_optim)

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        encoder = nn.parallel.DistributedDataParallel(
            encoder,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Only rank 0 reports to wandb, and only when it is importable and
    # requested on the command line.
    if get_rank() == 0 and wandb is not None and args.wandb:
        print("Using wandb")
        wandb.init(project='vp2vp', name=str(args.exp_name))
        wandb.config.update(args)

    print("# params in encoder:", count_parameters(encoder))
    print("# params in decoder:", count_parameters(generator))

    train_dataset, val_dataset, test_dataset = shapenet_splits(dataset)
    print(
        f"length of train/valid/test: {len(train_dataset)}, {len(val_dataset)}, {len(test_dataset)}"
    )
    seed_torch(args.seed)
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch,
        sampler=data_sampler(train_dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=False,
        num_workers=4,
    )
    # NOTE(review): the validation sampler also shuffles; kept as in the
    # original - confirm this is intended.
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch if args.val_batch is None else args.val_batch,
        sampler=data_sampler(val_dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=False,
        num_workers=4,
    )

    print("Training.....")
    train(args, train_loader, val_loader, generator, encoder, g_ema, g_optim,
          d_optim, device)