def main():
    """Load the frozen GMM / TOM / identity-embedder checkpoints plus the
    residual UNet under test, then run `test_residual` over the dataset.

    Fix: final status message said "nameed" — corrected to "named".
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # WORLD_SIZE is exported by the distributed launcher; absent => single process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # create dataset
    dataset = CPDataset(opt)

    # create dataloader
    # NOTE(review): `loader` is never used below; kept because CPDataLoader's
    # constructor may have side effects (e.g. spawning workers) — confirm and drop.
    loader = CPDataLoader(opt, dataset)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.workers, pin_memory=True, sampler=None)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

    # Frozen geometric-matching module.
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    # Frozen try-on generator.
    generator_model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model, "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    # Identity embedder; only its `embedder_b` branch is kept.
    # NOTE(review): `embedder_model` is not passed to test_residual below —
    # apparently loaded for nothing; verify before removing.
    embedder_model = Embedder()
    load_checkpoint(embedder_model, "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    # Residual refinement network under test.
    model = UNet(n_channels=4, n_classes=3)
    model.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    test_residual(opt, data_loader, model, gmm_model, generator_model)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def set_vars(self):
    """Parse CLI arguments and derive the runtime fields used elsewhere
    (fixed seed, GPU count, distributed flag), initializing the NCCL
    process group when more than one process is launched."""
    self.parse_args()
    args = self.args
    args.manualSeed = 1  # fixed seed
    # WORLD_SIZE is exported by the distributed launcher; absent => 1 process.
    args.n_gpu = int(os.environ.get("WORLD_SIZE", "1"))
    args.distributed = args.n_gpu > 1
    if not args.distributed:
        return
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    synchronize()
def distributed_worker(
    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
    """Per-process entry point for multi-machine training: join the NCCL
    process group, pin this process to its local GPU, build one process
    subgroup per machine, then run ``fn(*args)``.

    Fixes:
      * the guard reads ``dist_fn.LOCAL_PROCESS_GROUP`` but the loop stored
        the group in ``dist_fn.distributed.LOCAL_PROCESS_GROUP`` — now it is
        stored where it is read from;
      * the init failure is re-raised with ``from`` so the underlying NCCL
        error is not lost.
    """
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    # Rank of this process across all machines.
    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )
    except Exception as err:
        # Chain the cause so the real initialization failure stays visible.
        raise OSError("failed to initialize NCCL groups") from err

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )
    torch.cuda.set_device(local_rank)

    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")

    # Build one subgroup per machine; remember the one this process belongs to.
    n_machine = world_size // n_gpu_per_machine
    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            # Was: dist_fn.distributed.LOCAL_PROCESS_GROUP = pg (never read back).
            dist_fn.LOCAL_PROCESS_GROUP = pg

    fn(*args)
def main():
    """Build the identity Embedder (optionally restoring a checkpoint) and
    run `test_identity_embedding` over the dataset.

    Fix: final status message said "nameed" — corrected to "named".
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # WORLD_SIZE is exported by the distributed launcher; absent => single process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader (no shuffling: evaluation pass)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.workers, pin_memory=True)

    # visualization — only the single-GPU / main process writes TensorBoard logs.
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    model = Embedder()
    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    test_identity_embedding(opt, train_loader, model, board)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
# --- tail of the CLI definition (parser is created earlier, outside this view) ---
parser.add_argument('--lr', type=float, default=0.002)
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('--wandb', action='store_true')
# --local_rank is supplied by the torch.distributed launcher.
parser.add_argument('--local_rank', type=int, default=0)

args = parser.parse_args()
print(args)

# WORLD_SIZE is exported by the distributed launcher; absent => single process.
n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
args.distributed = n_gpu > 1

if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    synchronize()

# StyleGAN2 latent / mapping-network sizes — fixed, not CLI-exposed.
args.latent = 512
args.n_mlp = 8

# The from-scratch generator was replaced by a generator loaded from a
# pretrained FFHQ checkpoint; the originals are kept commented for reference.
# generator = Generator(
#     args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
# ).to(device)
generator = ProjectionGenerator("stylegan2-ffhq-config-f.pt").to(device)
discriminator = Discriminator(
    args.size, channel_multiplier=args.channel_multiplier
).to(device)
# NOTE(review): no EMA generator is created here (g_ema is commented out) —
# confirm downstream code does not expect one.
# g_ema = Generator(
#     args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
# ).to(device)
# Augmentation pipeline: resize shorter side to 224, random 224x224 crop,
# random horizontal flip, then normalize each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dset = Places365(args.path, transform=transform)
# The class count is taken from the dataset itself.
args.n_class = dset.n_class

if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    dist.synchronize()

# Class-conditional generator, its EMA shadow copy, and the discriminator.
gen = Generator(args.n_class, args.dim_z, args.dim_class).to(device)
g_ema = Generator(args.n_class, args.dim_z, args.dim_class).to(device)
# decay=0 — presumably copies gen's weights into g_ema; confirm against accumulate().
accumulate(g_ema, gen, 0)
dis = Discriminator(args.n_class).to(device)

if args.ckpt is not None:
    # map_location keeps loaded tensors on CPU (the modules were already
    # moved to `device` above), avoiding a GPU-0 memory spike.
    ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
    gen.load_state_dict(ckpt["g"])
    g_ema.load_state_dict(ckpt["g_ema"])
    dis.load_state_dict(ckpt["d"])

if args.distributed:
    # NOTE(review): the source chunk is truncated mid-statement below.
    gen = nn.parallel.DistributedDataParallel(
def setup_and_run(device, args):
    """Prepare a StyleGAN2 training run — output dirs, (optional) DDP,
    generator/discriminator/EMA models, optimizers, checkpoint restore,
    data pipeline — then hand off to `train`.

    Args:
        device: torch device string/object the models are moved to.
        args: parsed CLI namespace; mutated in place (distributed, latent,
            n_mlp, start_iter are set here).
    """
    os.makedirs(f'sample_{args.name}', exist_ok=True)
    os.makedirs(f'checkpoint_{args.name}', exist_ok=True)

    # WORLD_SIZE is exported by the distributed launcher; absent => 1 process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # StyleGAN2 latent / mapping-network sizes — fixed, not CLI-exposed.
    args.latent = 512
    args.n_mlp = 8
    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp,
        channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier).to(device)
    # EMA copy of the generator, used for sampling; never trained directly.
    g_ema = Generator(args.size, args.latent, args.n_mlp,
                      channel_multiplier=args.channel_multiplier).to(device)
    g_ema.eval()
    # decay=0 — presumably initializes g_ema to generator's weights; confirm
    # against accumulate()'s definition.
    accumulate(g_ema, generator, 0)

    # Lazy-regularization schedule: lr and betas are rescaled because the
    # R1 / path-length penalties are only applied every *_reg_every steps.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)
        ckpt = torch.load(args.ckpt)
        try:
            # Checkpoints are named "<iteration>.pt"; resume from that step.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
        except ValueError:
            # Filename is not an iteration number — start from 0.
            pass
        generator.load_state_dict(ckpt['g'], strict=False)
        discriminator.load_state_dict(ckpt['d'], strict=False)
        g_ema.load_state_dict(ckpt['g_ema'], strict=False)
        g_optim.load_state_dict(ckpt['g_optim'])
        d_optim.load_state_dict(ckpt['d_optim'])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Light augmentation only; normalization maps channels to [-1, 1].
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    # Only the main process logs to wandb, and only when requested.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project='stylegan 2')

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
def main():
    """Train the residual refinement model `G` (stage "residual_old") on top
    of the frozen GMM / TOM / identity-embedder checkpoints, optionally with
    a GAN discriminator, and save the final weights on the main process.

    Fix: final status message said "nameed" — corrected to "named".
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # WORLD_SIZE is exported by the distributed launcher; absent => 1 process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization — only the single-GPU / main process writes TensorBoard logs.
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # Frozen geometric-matching module.
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    # Frozen try-on generator.
    generator_model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model, "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    # Identity embedder; only its `embedder_b` branch is used downstream.
    embedder_model = Embedder()
    load_checkpoint(embedder_model, "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    # Model being trained.
    model = G()
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    if opt.use_gan:
        discriminator = Discriminator()
        discriminator.apply(utils.weights_init('gaussian'))
        discriminator.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    # Keep handles to the bare modules; DDP wraps them below, and training /
    # checkpointing code needs the unwrapped versions.
    model_module = model
    if opt.use_gan:
        discriminator_module = discriminator

    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True)
        model_module = model.module
        if opt.use_gan:
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            discriminator_module = discriminator.module

    if opt.use_gan:
        train_residual_old(opt, train_loader, model,
                           model_module, gmm_model, generator_model,
                           embedder_model, board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module)
        if single_gpu_flag(opt):
            save_checkpoint(
                {"generator": model_module,
                 "discriminator": discriminator_module},
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
    else:
        train_residual_old(opt, train_loader, model,
                           model_module, gmm_model, generator_model,
                           embedder_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def main():
    """Multi-stage training dispatcher: builds dataset/loader/TensorBoard,
    then trains the model selected by `opt.stage` — 'GMM', 'TOM', 'TOM+WARP',
    'identity', 'residual', or 'residual_old' — saving the final checkpoint
    on the main process.

    Fix: final status message said "nameed" — corrected to "named".
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # WORLD_SIZE is exported by the distributed launcher; absent => 1 process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization — only the single-GPU / main process writes TensorBoard logs.
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))

    elif opt.stage == 'TOM':
        # Frozen warping module feeding the try-on generator.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        # Keep a handle to the bare module; DDP wraps it below.
        model_module = model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == 'TOM+WARP':
        # Joint training: GMM is trained here too (no pretrained checkpoint).
        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        # if opt.distributed:
        #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            gmm_model = torch.nn.parallel.DistributedDataParallel(
                gmm_model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        if single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "identity":
        model = Embedder()
        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        # NOTE(review): the identity embedder is saved as 'gmm_final.pth' —
        # looks like a copy-paste of the GMM branch; confirm intended filename.
        save_checkpoint(
            model, os.path.join(opt.checkpoint_dir, opt.name, 'gmm_final.pth'))

    elif opt.stage == 'residual':
        # Frozen pretrained stages the residual model refines.
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25, 4, 6, ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new/step_038000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()
            acc_discriminator = AccDiscriminator()
            acc_discriminator.apply(utils.weights_init('gaussian'))
            acc_discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)
            if opt.use_gan:
                # Discriminator checkpoints are stored alongside the model with
                # a "step_disc_" prefix.
                load_checkpoint(discriminator,
                                opt.checkpoint.replace("step_", "step_disc_"))

        # Keep handles to the bare modules; DDP wraps them below.
        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator, device_ids=[local_rank],
                    output_device=local_rank, find_unused_parameters=True)
                discriminator_module = discriminator.module
                acc_discriminator = torch.nn.parallel.DistributedDataParallel(
                    acc_discriminator, device_ids=[local_rank],
                    output_device=local_rank, find_unused_parameters=True)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {"generator": model_module,
                     "discriminator": discriminator_module},
                    os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    elif opt.stage == "residual_old":
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        generator_model = UnetGenerator(25, 4, 6, ngf=64,
                                        norm_layer=nn.InstanceNorm2d)
        load_checkpoint(generator_model,
                        "checkpoints/tom_train_new_2/step_070000.pth")
        generator_model.cuda()

        embedder_model = Embedder()
        load_checkpoint(embedder_model,
                        "checkpoints/identity_train_64_dim/step_020000.pth")
        embedder_model = embedder_model.embedder_b.cuda()

        model = UNet(n_channels=4, n_classes=3)
        if opt.distributed:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.apply(utils.weights_init('kaiming'))
        model.cuda()

        if opt.use_gan:
            discriminator = Discriminator()
            discriminator.apply(utils.weights_init('gaussian'))
            discriminator.cuda()

        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
            load_checkpoint(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator

        if opt.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)
            model_module = model.module
            if opt.use_gan:
                discriminator = torch.nn.parallel.DistributedDataParallel(
                    discriminator, device_ids=[local_rank],
                    output_device=local_rank, find_unused_parameters=True)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board, discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {"generator": model_module,
                     "discriminator": discriminator_module},
                    os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(
                    model_module,
                    os.path.join(opt.checkpoint_dir, opt.name, 'tom_final.pth'))

    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
def main():
    """Train a viewpoint-conditional depth-map GAN: parse CLI options,
    optionally join a distributed group, build dataset/generator/encoder/EMA,
    restore checkpoints, then hand off to `train`."""
    device = 'cuda'
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    # EMA decay: 0.5 ** (32 / 10k) per step, the StyleGAN2 convention.
    parser.add_argument('--decay', type=float, default=0.5**(32 / (10 * 1000)))
    parser.add_argument('--view_id', type=int, default=0)
    parser.add_argument('--iter', type=int, default=10000)
    parser.add_argument('--batch', type=int, default=16)
    # Validation batch size; falls back to --batch when None (see loaders below).
    parser.add_argument('--val_batch', type=int, default=None)
    parser.add_argument('--latent', type=int, default=512)
    parser.add_argument('--n_mlp', type=int, default=8)
    # parser.add_argument('--n_sample', type=int, default=16)
    parser.add_argument('--size', type=int, default=64)
    parser.add_argument('--initial_size', type=int, default=4)
    parser.add_argument('--r1', type=float, default=10)
    # Lazy-regularization intervals; <= 0 disables the lr/beta rescale below.
    parser.add_argument('--d_reg_every', type=int, default=16)
    parser.add_argument('--g_reg_every', type=int, default=4)
    # parser.add_argument('--mixing', type=float, default=0.9)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lr', type=float, default=0.002)
    # NOTE(review): type=float here, while sibling scripts use type=int —
    # confirm Generator/Encoder accept a float multiplier.
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--wandb', action='store_true')  # default False; enables wandb logging
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ckpt_save_directory', type=str, default='checkpoint')
    parser.add_argument('--sample_save_directory', type=str, default='sample')
    parser.add_argument('--load_sil', action='store_true')  # default value is false
    parser.add_argument('--input_type', type=str, default='silhouette')  # default 'silhouette'
    parser.add_argument('--input_quant', action='store_true', help='make input quantized')  # default value is false
    parser.add_argument('--exp_name', type=str, default='0')
    parser.add_argument('--loader_type', type=str, default='gp')
    parser.add_argument('--categories', default=[], nargs='+', help="what all categories from shapenet to use")
    parser.add_argument('--no_noise', action='store_true', help="remove noise in StyledConv layer")
    parser.add_argument('--max_vps', type=int, default=20, help="maximum viewpoints")
    parser.add_argument('--num_vps', type=int, default=2, help="sample num_vps at a time")
    parser.add_argument('--soft_l1', action='store_true', help="use soft l1 loss")
    parser.add_argument(
        '--random_avg', action='store_true',
        help="use uniformly sampled noise for averaging latent")
    parser.add_argument(
        '--use_pretrained_if_available', action='store_true',
        help="Use pretrained model with largest iter if available")
    parser.add_argument('--uncond_cat_enc', action='store_true', help="make encoder unconditional on category")
    parser.add_argument('--uncond_vp_enc', action='store_true', help="make encoder unconditional on viewpoint")
    parser.add_argument('--merge_conditions', action='store_true', help="merge viewpoint and category embedding")
    parser.add_argument('--bins', type=int, default=0)  # default value is 0
    parser.add_argument('--smoothing', type=float, default=0.2)  # default value is 0.2
    parser.add_argument('--seed', type=int, default=0)  # Seed for dataloader

    args = parser.parse_args()

    # WORLD_SIZE is exported by the distributed launcher; absent => 1 process.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # Per-experiment output directories: <base>/<exp_name>.
    # args.ckpt_save_direcory = '/scratch1/jsreddy/stylegan_depthmaps_data/'+args.ckpt_save_directory + "_" + str(args.exp_num)
    args.ckpt_save_directory = os.path.join(args.ckpt_save_directory, args.exp_name)
    if not os.path.exists(args.ckpt_save_directory):
        os.makedirs(args.ckpt_save_directory)
        print("Created Ckpt Directory: ", args.ckpt_save_directory)
    # args.sample_save_directory = '/scratch1/jsreddy/stylegan_depthmaps_data/'+args.sample_save_directory + "_" + str(args.exp_num)
    args.sample_save_directory = os.path.join(args.sample_save_directory, args.exp_name)
    if not os.path.exists(args.sample_save_directory):
        os.makedirs(args.sample_save_directory)
        print("Created Sample Directory: ", args.sample_save_directory)

    # args.latent = 512
    # args.n_mlp = 8
    args.start_iter = 0

    # Output channels: one per depth bin when quantized (>1 bin forces the
    # silhouette on), depth+silhouette when load_sil, else depth only.
    if args.bins > 1:
        args.output_channels = args.bins
        args.load_sil = True
    elif args.load_sil:
        args.output_channels = 2
    else:
        args.output_channels = 1
    args.input_channels = args.output_channels
    if args.loader_type == 'merged':
        args.input_channels = args.bins if args.input_quant else 1

    # 3D stuff
    # reproj_consist = ReprojectionConsistency(
    #     include_self=args.include_self, use_sil=args.load_sil, device=device, data_type=args.loader_type)

    # Dataset
    # dataset = MultiResolutionDataset(args.path, transform, args.size)
    dataset = DepthMapDataset(args.path, categories=args.categories,
                              loader_type=args.loader_type,
                              load_sil=args.load_sil, view_id=args.view_id,
                              input_type=args.input_type)

    # Generator call
    generator = Generator(args.size, args.latent, args.n_mlp,
                          channel_multiplier=args.channel_multiplier,
                          output_channels=args.output_channels,
                          initial_size=args.initial_size,
                          num_viewpoints=args.max_vps,
                          no_noise=args.no_noise,
                          num_categories=len(dataset.classes),
                          merge_conditions=args.merge_conditions).to(device)
    print("Generator:", generator)

    # Encoder call — conditioning is dropped (set to 1) when the
    # uncond_* flags are given.
    encoder = Encoder(
        args.size, args.latent,
        channel_multiplier=args.channel_multiplier,
        input_channels=args.input_channels,
        num_viewpoints=1 if args.uncond_vp_enc else args.max_vps,
        initial_size=args.initial_size,
        num_categories=1 if args.uncond_cat_enc else len(dataset.classes),
    ).to(device)
    print("Encoder: ", encoder)

    # EMA copy of the generator, used for evaluation; never trained directly.
    g_ema = Generator(args.size, args.latent, args.n_mlp,
                      channel_multiplier=args.channel_multiplier,
                      output_channels=args.output_channels,
                      initial_size=args.initial_size,
                      num_viewpoints=args.max_vps,
                      no_noise=args.no_noise,
                      num_categories=len(dataset.classes),
                      merge_conditions=args.merge_conditions).to(device)
    g_ema.eval()
    # decay=0 — presumably copies generator's weights into g_ema; confirm
    # against accumulate()'s definition.
    accumulate(g_ema, generator, 0)

    # Lazy-regularization rescale of lr/betas; disabled when *_reg_every <= 0.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) if args.g_reg_every > 0 else 1
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) if args.d_reg_every > 0 else 1

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    # "d_optim" optimizes the encoder, which plays the discriminator's role here.
    d_optim = optim.Adam(
        encoder.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    if args.ckpt is not None:
        print('load model:', args.ckpt)
        ckpt = torch.load(args.ckpt)
        try:
            # Checkpoints are named "<iteration>.pt"; resume from that step.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
        except ValueError:
            pass
        generator.load_state_dict(ckpt['g'])
        encoder.load_state_dict(ckpt['d'])
        g_ema.load_state_dict(ckpt['g_ema'])
        g_optim.load_state_dict(ckpt['g_optim'])
        d_optim.load_state_dict(ckpt['d_optim'])
    elif args.use_pretrained_if_available:
        # Resume from the newest checkpoint in the save directory, if any
        # (lexicographic sort — assumes zero-padded iteration filenames).
        from glob import glob
        files = sorted(glob(os.path.join(args.ckpt_save_directory, '*.pt')))
        if len(files) > 0:
            print('found model:', files[-1])
            ckpt = torch.load(files[-1])
            ckpt_name = os.path.basename(files[-1])
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
            generator.load_state_dict(ckpt['g'])
            encoder.load_state_dict(ckpt['d'])
            g_ema.load_state_dict(ckpt['g_ema'])
            g_optim.load_state_dict(ckpt['g_optim'])
            d_optim.load_state_dict(ckpt['d_optim'])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        encoder = nn.parallel.DistributedDataParallel(
            encoder,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Only the main process logs to wandb, and only when requested.
    if get_rank() == 0 and wandb is not None and args.wandb:
        print("Using wandb")
        wandb.init(project='vp2vp', name=str(args.exp_name))
        wandb.config.update(args)

    print("# params in encoder:", count_parameters(encoder))
    print("# params in decoder:", count_parameters(generator))

    train_dataset, val_dataset, test_dataset = shapenet_splits(dataset)
    print(
        f"length of train/valid/test: {len(train_dataset)}, {len(val_dataset)}, {len(test_dataset)}"
    )

    seed_torch(args.seed)

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=args.distributed),
        drop_last=False,
        num_workers=4,
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch if args.val_batch is None else args.val_batch,
        sampler=data_sampler(val_dataset, shuffle=True, distributed=args.distributed),
        drop_last=False,
        num_workers=4,
    )

    print("Training.....")
    train(args, train_loader, val_loader, generator, encoder, g_ema, g_optim,
          d_optim, device)