def main():
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():  # enter no grad context
            if is_image_file(args.image):
                if args.mask and is_image_file(args.mask):
                    # Test a single masked image with a given mask
                    x = default_loader(args.image)
                    mask = default_loader(args.mask)
                    x = transforms.Resize(config['image_shape'][:-1])(x)
                    x = transforms.CenterCrop(config['image_shape'][:-1])(x)
                    mask = transforms.Resize(config['image_shape'][:-1])(mask)
                    mask = transforms.CenterCrop(config['image_shape'][:-1])(mask)
                    x = transforms.ToTensor()(x)
                    mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                    x = normalize(x)
                    x = x * (1. - mask)
                    x = x.unsqueeze(dim=0)
                    mask = mask.unsqueeze(dim=0)
                elif args.mask:
                    raise TypeError("{} is not an image file.".format(args.mask))
                else:
                    # Test a single ground-truth image with a random mask
                    ground_truth = default_loader(args.image)
                    ground_truth = transforms.Resize(config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.CenterCrop(config['image_shape'][:-1])(ground_truth)
                    ground_truth = transforms.ToTensor()(ground_truth)
                    ground_truth = normalize(ground_truth)
                    ground_truth = ground_truth.unsqueeze(dim=0)
                    bboxes = random_bbox(config, batch_size=ground_truth.size(0))
                    x, mask = mask_image(ground_truth, bboxes, config)

                # Set checkpoint path
                if not args.checkpoint_path:
                    checkpoint_path = os.path.join('checkpoints',
                                                   config['dataset_name'],
                                                   config['mask_type'] + '_' + config['expname'])
                else:
                    checkpoint_path = args.checkpoint_path

                # Define the trainer
                netG = Generator(config['netG'], cuda, device_ids)
                # Resume weight
                last_model_name = get_model_list(checkpoint_path, "gen", iteration=args.iter)
                netG.load_state_dict(torch.load(last_model_name))
                model_iteration = int(last_model_name[-11:-3])
                print("Resume from {} at iteration {}".format(checkpoint_path, model_iteration))

                if cuda:
                    netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
                    x = x.cuda()
                    mask = mask.cuda()

                # Inference
                x1, x2, offset_flow = netG(x, mask)
                inpainted_result = x2 * mask + x * (1. - mask)

                vutils.save_image(inpainted_result, args.output, padding=0, normalize=True)
                print("Saved the inpainted result to {}".format(args.output))
                if args.flow:
                    vutils.save_image(offset_flow, args.flow, padding=0, normalize=True)
                    print("Saved offset flow to {}".format(args.flow))
            else:
                raise TypeError("{} is not an image file.".format(args.image))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise e
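# The module-level `parser` consumed by main() is not shown in this section.
# A minimal sketch of what it plausibly defines, inferred from the attributes
# read above (args.config, args.seed, args.image, args.mask, args.output,
# args.flow, args.checkpoint_path, args.iter); the defaults and help strings
# are assumptions, not the repo's actual values:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='configs/config.yaml',
                    help='Path to the YAML configuration file')
parser.add_argument('--seed', type=int, default=None, help='Random seed')
parser.add_argument('--image', type=str, default='', help='Input image to inpaint')
parser.add_argument('--mask', type=str, default='', help='Optional mask image')
parser.add_argument('--output', type=str, default='output.png', help='Where to save the result')
parser.add_argument('--flow', type=str, default='', help='Optional path to save the offset flow')
parser.add_argument('--checkpoint_path', type=str, default='', help='Checkpoint directory')
parser.add_argument('--iter', type=int, default=0, help='Checkpoint iteration to load (0 = latest)')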
def generate(img, img_mask_path, model_path):
    # Note: relies on the module-level config, cuda and device_ids being set
    # before this function is called.
    with torch.no_grad():  # enter no grad context
        if img_mask_path and is_image_file(img_mask_path):
            # Test a single masked image with a given mask
            x = Image.fromarray(img)
            mask = default_loader(img_mask_path)
            x = transforms.Resize(config['image_shape'][:-1])(x)
            x = transforms.CenterCrop(config['image_shape'][:-1])(x)
            mask = transforms.Resize(config['image_shape'][:-1])(mask)
            mask = transforms.CenterCrop(config['image_shape'][:-1])(mask)
            x = transforms.ToTensor()(x)
            mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
            x = normalize(x)
            x = x * (1. - mask)
            x = x.unsqueeze(dim=0)
            mask = mask.unsqueeze(dim=0)
        elif img_mask_path:
            raise TypeError("{} is not an image file.".format(img_mask_path))
        else:
            # Test a single ground-truth image with a random mask
            # ground_truth = default_loader(img_path)
            # Convert the incoming array to a PIL image so the torchvision
            # transforms accept it (mirrors the masked branch above).
            ground_truth = Image.fromarray(img)
            ground_truth = transforms.Resize(config['image_shape'][:-1])(ground_truth)
            ground_truth = transforms.CenterCrop(config['image_shape'][:-1])(ground_truth)
            ground_truth = transforms.ToTensor()(ground_truth)
            ground_truth = normalize(ground_truth)
            ground_truth = ground_truth.unsqueeze(dim=0)
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)

        # Set checkpoint path
        if not model_path:
            checkpoint_path = os.path.join('checkpoints',
                                           config['dataset_name'],
                                           config['mask_type'] + '_' + config['expname'])
        else:
            checkpoint_path = model_path

        # Define the trainer
        netG = Generator(config['netG'], cuda, device_ids)
        # Resume weight
        last_model_name = get_model_list(checkpoint_path, "gen", iteration=0)
        if cuda:
            netG.load_state_dict(torch.load(last_model_name))
        else:
            netG.load_state_dict(torch.load(last_model_name, map_location='cpu'))
        model_iteration = int(last_model_name[-11:-3])
        print("Resume from {} at iteration {}".format(checkpoint_path, model_iteration))

        if cuda:
            netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
            x = x.cuda()
            mask = mask.cuda()

        # Inference
        x1, x2, offset_flow = netG(x, mask)
        inpainted_result = x2 * mask + x * (1. - mask)
        inpainted_result = from_torch_img_to_numpy(inpainted_result, 'output.png',
                                                   padding=0, normalize=True)

        return inpainted_result
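# A minimal sketch of how generate() might be called, assuming the module-level
# config/cuda/device_ids globals are initialised, that 'photo.png' and
# 'mask.png' are hypothetical local files, and that from_torch_img_to_numpy
# returns an HWC uint8 array (an assumption; the helper is not shown here):
import numpy as np
from PIL import Image

img = np.array(Image.open('photo.png').convert('RGB'))  # HWC uint8 array
result = generate(img, 'mask.png', model_path='')  # '' falls back to the default checkpoint dir
Image.fromarray(result).save('inpainted.png')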
def main():
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Configure checkpoint path
    checkpoint_path = os.path.join('checkpoints',
                                   config['dataset_name'],
                                   config['mask_type'] + '_' + config['expname'])
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    shutil.copy(args.config, os.path.join(checkpoint_path, os.path.basename(args.config)))
    writer = SummaryWriter(logdir=checkpoint_path)
    logger = get_logger(checkpoint_path)  # get logger and configure it at the first call

    logger.info(f"Arguments: {args}")
    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    logger.info(f"Random seed: {args.seed}")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    # Log the configuration
    logger.info(f"Configuration: {config}")

    try:  # for unexpected error logging
        # Load the dataset
        logger.info(f"Training on dataset: {config['dataset_name']}")
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=config['batch_size'],
                                                   shuffle=True,
                                                   num_workers=config['num_workers'])

        # Define the trainer
        trainer = Trainer(config)
        # logger.info(f"\n{trainer.netG}")
        # logger.info(f"\n{trainer.localD}")
        # logger.info(f"\n{trainer.globalD}")

        # if cuda:
        #     trainer = nn.parallel.DataParallel(trainer, device_ids=device_ids)
        #     trainer_module = trainer.module
        # else:
        trainer_module = trainer

        # Get the resume iteration to restart training
        start_iteration = trainer_module.resume(config['resume']) if config['resume'] else 1

        iterable_train_loader = iter(train_loader)
        time_count = time.time()
        epoch = 1
        for iteration in range(start_iteration, config['niter'] + 1):
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                logger.info(f"Epoch: {epoch}")
                epoch += 1
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            if cuda:
                x = x.cuda()
                mask = mask.cuda()
                ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            losses, inpainted_result, offset_flow = trainer(x, bboxes, mask, ground_truth)

            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            # Update D
            trainer_module.optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            trainer_module.optimizer_d.step()

            # Update G
            trainer_module.optimizer_g.zero_grad()
            losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                + losses['ae'] * config['ae_loss_alpha'] \
                + losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            trainer_module.optimizer_g.step()

            # Log and visualization
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
            for k in log_losses:
                v = losses[k]
                message += '%s: %.6f ' % (k, v)
            print(f"\r{message}", end="")
            if iteration % config['print_iter'] == 0:
                print("")
                time_count = time.time() - time_count
                speed = config['print_iter'] / time_count
                speed_msg = 'speed: %.2f batches/s ' % speed
                time_count = time.time()
                message += speed_msg
                logger.info(message)

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out],
                                              inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)
    except Exception as e:  # for unexpected error logging
        logger.error(f"{e}")
        raise e
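# The config keys this training loop reads are scattered through the code
# above. Collected in one place as a sketch of the expected YAML config; the
# values are illustrative assumptions, not the repo's actual defaults:
example_config = {
    'cuda': True, 'gpu_ids': [0],
    'dataset_name': 'imagenet', 'mask_type': 'hole', 'expname': 'benchmark',
    'train_data_path': 'data/train', 'data_with_subfolder': False,
    'image_shape': [256, 256, 3], 'random_crop': True,
    'batch_size': 16, 'num_workers': 4,
    'resume': None, 'niter': 500000,
    'wgan_gp_lambda': 10, 'l1_loss_alpha': 1.2,
    'ae_loss_alpha': 1.2, 'gan_loss_alpha': 0.001,
    'print_iter': 100, 'viz_iter': 1000, 'viz_max_out': 16,
    'snapshot_save_iter': 5000,
    'netG': {}, 'netD': {},  # architecture sub-configs, not reproduced here
}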
def train(config, logger, writer, checkpoint_path):
    # `writer` is the SummaryWriter used for scalar logging below; it was
    # missing from the original signature, mirroring train_distributed().
    try:  # for unexpected error logging
        # Load the dataset
        logger.info("Training on dataset: {}".format(config['dataset_name']))
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])
        # val_dataset = Dataset(data_path=config['val_data_path'],
        #                       with_subfolder=config['data_with_subfolder'],
        #                       image_size=config['image_size'],
        #                       random_crop=config['random_crop'])
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=config['batch_size'],
                                                   shuffle=True,
                                                   num_workers=config['num_workers'])
        # val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
        #                                          batch_size=config['batch_size'],
        #                                          shuffle=False,
        #                                          num_workers=config['num_workers'])

        # Define the trainer
        trainer = Trainer(config)
        logger.info("\n{}".format(trainer.netG))
        logger.info("\n{}".format(trainer.localD))
        logger.info("\n{}".format(trainer.globalD))

        # if cuda:
        #     trainer = nn.parallel.DataParallel(trainer, device_ids=device_ids)
        #     trainer_module = trainer.module
        # else:
        #     trainer_module = trainer
        trainer_module = trainer

        # Get the resume iteration to restart training
        start_iteration = trainer_module.resume(config['resume']) if config['resume'] else 1
        print("\n\nStarting epoch: ", start_iteration)

        iterable_train_loader = iter(train_loader)
        time_count = time.time()

        epochs = config['niter'] + 1
        pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)
        # for iteration in range(start_iteration, epochs):
        for iteration in pbar:
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            x = x.cuda()
            mask = mask.cuda()
            ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            compute_g_loss = iteration % config['n_critic'] == 0
            losses, inpainted_result, offset_flow = trainer(x, bboxes, mask,
                                                            ground_truth, compute_g_loss)

            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            # Update D
            if not compute_g_loss:
                trainer_module.optimizer_d.zero_grad()
                losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
                losses['d'].backward()
                trainer_module.optimizer_d.step()

            # Update G
            if compute_g_loss:
                trainer_module.optimizer_g.zero_grad()
                losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                    + losses['ae'] * config['ae_loss_alpha'] \
                    + losses['wgan_g'] * config['gan_loss_alpha']
                losses['g'].backward()
                trainer_module.optimizer_g.step()

            ### TODO:
            ### - Why does this need to be moved from above to here?
            ###
            # losses['d'].backward()
            # trainer_module.optimizer_d.step()

            # Set tqdm description
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            # message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
            message = ' '
            for k in log_losses:
                v = losses.get(k, 0.)
                writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)
            pbar.set_description(f" {message}")

            # Log and visualization
            if iteration % config['print_iter'] == 0:
                time_count = time.time() - time_count
                speed = config['print_iter'] / time_count
                speed_msg = 'speed: %.2f batches/s ' % speed
                time_count = time.time()

                message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
                for k in log_losses:
                    v = losses.get(k, 0.)
                    writer.add_scalar(k, v, iteration)
                    message += '%s: %.6f ' % (k, v)
                message += speed_msg
                # logger.info(message)

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out],
                                              inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)
    except Exception as e:  # for unexpected error logging
        logger.error("{}".format(e))
        raise e
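# A minimal sketch of wiring train() up, reusing the same checkpoint layout
# and helpers (get_config, get_logger, SummaryWriter) as the main() functions
# in this section; this exact driver is illustrative, not the repo's own:
if __name__ == '__main__':
    args = parser.parse_args()
    config = get_config(args.config)
    checkpoint_path = os.path.join('checkpoints',
                                   config['dataset_name'],
                                   config['mask_type'] + '_' + config['expname'])
    os.makedirs(checkpoint_path, exist_ok=True)
    writer = SummaryWriter(logdir=checkpoint_path)
    logger = get_logger(checkpoint_path)
    train(config, logger, writer, checkpoint_path)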
def train_distributed(config, logger, writer, checkpoint_path):
    dist.init_process_group(
        backend='nccl',
        # backend='gloo',
        init_method='env://'
    )

    # Find out which GPU this process drives on this compute node.
    local_rank = torch.distributed.get_rank()
    # This is the total number of GPUs across all nodes;
    # if using 2 nodes with 4 GPUs each, world size is 8.
    world_size = torch.distributed.get_world_size()
    print("### global rank of curr node: {} of {}".format(local_rank, world_size))

    # For multiprocessing distributed, the DistributedDataParallel constructor
    # should always set the single device scope; otherwise,
    # DistributedDataParallel will use all available devices.
    # print("local_rank: ", local_rank)
    # dist.barrier()
    torch.cuda.set_device(local_rank)

    # Define the trainer
    print("Creating models on device: ", local_rank)
    input_dim = config['netG']['input_dim']
    cnum = config['netG']['ngf']
    use_cuda = True
    gated = config['netG']['gated']

    # Models #
    netG = Generator(config['netG'], use_cuda=True, device=local_rank).cuda()
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    localD = LocalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    localD = torch.nn.parallel.DistributedDataParallel(
        localD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
    globalD = GlobalDis(config['netD'], use_cuda=True, device_id=local_rank).cuda()
    globalD = torch.nn.parallel.DistributedDataParallel(
        globalD,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    if local_rank == 0:
        logger.info("\n{}".format(netG))
        logger.info("\n{}".format(localD))
        logger.info("\n{}".format(globalD))

    # Optimizers #
    optimizer_g = torch.optim.Adam(
        netG.parameters(),
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )
    d_params = list(localD.parameters()) + list(globalD.parameters())
    optimizer_d = torch.optim.Adam(
        d_params,
        lr=config['lr'],
        betas=(config['beta1'], config['beta2'])
    )

    # Data #
    sampler = None
    train_dataset = Dataset(
        data_path=config['train_data_path'],
        with_subfolder=config['data_with_subfolder'],
        image_shape=config['image_shape'],
        random_crop=config['random_crop']
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        # num_replicas=torch.cuda.device_count(),
        num_replicas=len(config['gpu_ids']),
        # rank=local_rank
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config['batch_size'],
        shuffle=(sampler is None),
        num_workers=config['num_workers'],
        pin_memory=True,
        sampler=sampler,
        drop_last=True
    )

    # Get the resume iteration to restart training #
    # start_iteration = trainer.resume(config['resume']) if config['resume'] else 1
    start_iteration = 1
    print("\n\nStarting epoch: ", start_iteration)

    iterable_train_loader = iter(train_loader)

    if local_rank == 0:
        time_count = time.time()

    epochs = config['niter'] + 1
    # Only rank 0 gets a progress bar; the other ranks iterate a plain range,
    # so the loop below works on every process.
    if local_rank == 0:
        pbar = tqdm(range(start_iteration, epochs), dynamic_ncols=True, smoothing=0.01)
    else:
        pbar = range(start_iteration, epochs)

    for iteration in pbar:
        sampler.set_epoch(iteration)
        try:
            ground_truth = next(iterable_train_loader)
        except StopIteration:
            iterable_train_loader = iter(train_loader)
            ground_truth = next(iterable_train_loader)

        # Prepare the inputs
        bboxes = random_bbox(config, batch_size=ground_truth.size(0))
        x, mask = mask_image(ground_truth, bboxes, config)

        # Move to the proper device.
        # bboxes = bboxes.cuda(local_rank)
        x = x.cuda(local_rank)
        mask = mask.cuda(local_rank)
        ground_truth = ground_truth.cuda(local_rank)

        ###### Forward pass ######
        compute_g_loss = iteration % config['n_critic'] == 0
        # losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
        #                                                 localD=localD, globalD=globalD,
        #                                                 coarse_gen=coarse_generator, fine_gen=fine_generator,
        #                                                 local_rank=local_rank, compute_loss_g=compute_g_loss)
        losses, inpainted_result, offset_flow = forward(config, x, bboxes, mask, ground_truth,
                                                        netG=netG, localD=localD, globalD=globalD,
                                                        local_rank=local_rank,
                                                        compute_loss_g=compute_g_loss)

        # Scalars from different devices are gathered into vectors #
        for k in losses.keys():
            if not losses[k].dim() == 0:
                losses[k] = torch.mean(losses[k])

        ###### Backward pass ######
        # Update D
        if not compute_g_loss:
            optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            optimizer_d.step()

        # Update G
        if compute_g_loss:
            optimizer_g.zero_grad()
            losses['g'] = losses['ae'] * config['ae_loss_alpha']
            losses['g'] += losses['l1'] * config['l1_loss_alpha']
            losses['g'] += losses['wgan_g'] * config['gan_loss_alpha']
            losses['g'].backward()
            optimizer_g.step()

        # Set tqdm description #
        if local_rank == 0:
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            message = ' '
            for k in log_losses:
                v = losses.get(k, 0.)
                writer.add_scalar(k, v, iteration)
                message += '%s: %.4f ' % (k, v)
            pbar.set_description(f" {message}")

        if local_rank == 0:
            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out],
                                              inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%08d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                save_model(netG, globalD, localD, optimizer_g, optimizer_d,
                           checkpoint_path, iteration)
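# save_model() is referenced above but not defined in this section. A minimal
# sketch of what such a helper could look like, assuming the checkpoint naming
# that get_model_list() and the int(last_model_name[-11:-3]) parsing elsewhere
# imply (an 8-digit iteration before a '.pt' extension); this is an
# illustration, not the repo's actual implementation:
def save_model(netG, globalD, localD, optimizer_g, optimizer_d, checkpoint_path, iteration):
    # Unwrap DistributedDataParallel so saved state_dict keys have no 'module.' prefix.
    unwrap = lambda m: m.module if hasattr(m, 'module') else m
    torch.save(unwrap(netG).state_dict(),
               os.path.join(checkpoint_path, 'gen_%08d.pt' % iteration))
    torch.save({
        'globalD': unwrap(globalD).state_dict(),
        'localD': unwrap(localD).state_dict(),
        'optimizer_g': optimizer_g.state_dict(),
        'optimizer_d': optimizer_d.state_dict(),
    }, os.path.join(checkpoint_path, 'dis_%08d.pt' % iteration))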
def main():
    args = parser.parse_args()
    config = get_config(args.config)  # way to use config

    # CUDA configuration
    cuda = config['cuda']  # specify cuda in config.yaml
    if torch.cuda.device_count() > 0:
        cuda = True
        config['cuda'] = True
    device_ids = config['gpu_ids']
    if cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    # Configure checkpoint path.
    # Say you use an ImageNet-pretrained model: since the dataset name is altered to "dtd",
    # the trained checkpoint of this model will be stored in checkpoints/dtd/hole_benchmark.
    checkpoint_path = os.path.join('checkpoints',
                                   config['dataset_name'],
                                   config['mask_type'] + '_' + config['expname'])
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    shutil.copy(args.config, os.path.join(checkpoint_path, os.path.basename(args.config)))
    writer = SummaryWriter(logdir=checkpoint_path)
    logger = get_logger(checkpoint_path)  # get logger and configure it at the first call

    logger.info("Arguments: {}".format(args))
    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    logger.info("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    # Log the configuration
    logger.info("Configuration: {}".format(config))

    try:  # for unexpected error logging
        # Load the dataset
        logger.info("Training on dataset: {}".format(config['dataset_name']))
        train_dataset = Dataset(data_path=config['train_data_path'],
                                with_subfolder=config['data_with_subfolder'],
                                image_shape=config['image_shape'],
                                random_crop=config['random_crop'])
        # val_dataset = Dataset(data_path=config['val_data_path'],
        #                       with_subfolder=config['data_with_subfolder'],
        #                       image_size=config['image_size'],
        #                       random_crop=config['random_crop'])
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=config['batch_size'],
                                                   shuffle=True,
                                                   num_workers=config['num_workers'])
        # val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
        #                                          batch_size=config['batch_size'],
        #                                          shuffle=False,
        #                                          num_workers=config['num_workers'])

        # Define the trainer
        trainer = Trainer(config)
        logger.info("\n{}".format(trainer.netG))
        logger.info("\n{}".format(trainer.localD))
        logger.info("\n{}".format(trainer.globalD))

        if cuda:
            trainer = nn.parallel.DataParallel(trainer, device_ids=device_ids)
            trainer_module = trainer.module
        else:
            trainer_module = trainer

        # Get the resume iteration to restart training.
        # config['resume'] is the directory of checkpoints, e.g. checkpoints/imagenet/hole_benchmark/
        start_iteration = trainer_module.resume(config['resume']) if config['resume'] else 1
        print("start at iteration {}".format(start_iteration))

        iterable_train_loader = iter(train_loader)
        time_count = time.time()

        for iteration in range(start_iteration, config['niter'] + 1):
            try:
                ground_truth = next(iterable_train_loader)
            except StopIteration:
                iterable_train_loader = iter(train_loader)
                ground_truth = next(iterable_train_loader)

            # Prepare the inputs
            bboxes = random_bbox(config, batch_size=ground_truth.size(0))
            x, mask = mask_image(ground_truth, bboxes, config)
            if cuda:
                x = x.cuda()
                mask = mask.cuda()
                ground_truth = ground_truth.cuda()

            ###### Forward pass ######
            compute_g_loss = iteration % config['n_critic'] == 0
            losses, inpainted_result, offset_flow = trainer(x, bboxes, mask,
                                                            ground_truth, compute_g_loss)

            # Scalars from different devices are gathered into vectors
            for k in losses.keys():
                if not losses[k].dim() == 0:
                    losses[k] = torch.mean(losses[k])

            ###### Backward pass ######
            # Update D
            trainer_module.optimizer_d.zero_grad()
            losses['d'] = losses['wgan_d'] + losses['wgan_gp'] * config['wgan_gp_lambda']
            losses['d'].backward()
            trainer_module.optimizer_d.step()

            # Update G
            if compute_g_loss:
                trainer_module.optimizer_g.zero_grad()
                losses['g'] = losses['l1'] * config['l1_loss_alpha'] \
                    + losses['ae'] * config['ae_loss_alpha'] \
                    + losses['wgan_g'] * config['gan_loss_alpha']
                losses['g'].backward()
                trainer_module.optimizer_g.step()

            # Log and visualization
            log_losses = ['l1', 'ae', 'wgan_g', 'wgan_d', 'wgan_gp', 'g', 'd']
            if iteration % config['print_iter'] == 0:
                time_count = time.time() - time_count
                speed = config['print_iter'] / time_count
                speed_msg = 'speed: %.2f batches/s ' % speed
                time_count = time.time()

                message = 'Iter: [%d/%d] ' % (iteration, config['niter'])
                for k in log_losses:
                    v = losses.get(k, 0.)
                    writer.add_scalar(k, v, iteration)
                    message += '%s: %.6f ' % (k, v)
                message += speed_msg
                logger.info(message)

            if iteration % (config['viz_iter']) == 0:
                viz_max_out = config['viz_max_out']
                if x.size(0) > viz_max_out:
                    viz_images = torch.stack([x[:viz_max_out],
                                              inpainted_result[:viz_max_out],
                                              offset_flow[:viz_max_out]], dim=1)
                else:
                    viz_images = torch.stack([x, inpainted_result, offset_flow], dim=1)
                viz_images = viz_images.view(-1, *list(x.size())[1:])
                vutils.save_image(viz_images,
                                  '%s/niter_%03d.png' % (checkpoint_path, iteration),
                                  nrow=3 * 4,
                                  normalize=True)

            # Save the model
            if iteration % config['snapshot_save_iter'] == 0:
                trainer_module.save_model(checkpoint_path, iteration)
                # checkpoint_path is checkpoints/dtd/hole_benchmark;
                # you could then use this checkpoint to test the raven score and store it in a file:
                # if iteration % (config['snapshot_save_iter'] // 10) == 0:
                #     os.system("python test_raven.py --checkpoint_path {} >> tmp.out".format(checkpoint_path))
                #     # this could cause trouble if the interpreter is python3 rather than python, etc.
    except Exception as e:  # for unexpected error logging
        logger.error("{}".format(e))
        raise e
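# get_logger() is called above but not defined in this section. A plausible
# minimal sketch, matching the "configure it at the first call" comment:
# attach stream and file handlers inside checkpoint_path on the first call and
# return the same logger afterwards. Illustrative only, not the repo's helper:
import logging

_logger = None

def get_logger(checkpoint_path=None):
    global _logger
    if _logger is None:
        _logger = logging.getLogger('trainer')
        _logger.setLevel(logging.INFO)
        _logger.addHandler(logging.StreamHandler())
        if checkpoint_path is not None:
            _logger.addHandler(logging.FileHandler(os.path.join(checkpoint_path, 'train.log')))
    return _logger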
            elif args.mask:
                raise TypeError("{} is not an image file.".format(args.mask))
            else:
                # Test a single ground-truth image with a random mask
                ground_truth = tif_loader(args.image)
                # ground_truth = transforms.Resize(config['image_shape'][:-1])(ground_truth)
                # ground_truth = transforms.CenterCrop(config['image_shape'][:-1])(ground_truth)
                # ground_truth = transfer2tensor(ground_truth)
                # ground_truth = transforms.ToTensor()(ground_truth)
                x = ground_truth[110:366, 110:366, :]  # fixed 256x256 crop
                x = np.transpose(x, [2, 0, 1])  # HWC -> CHW
                # print('Output min, max', x.min(), x.max())
                ground_truth = torch.from_numpy(x)
                ground_truth = normalize(ground_truth)
                ground_truth = ground_truth.unsqueeze(dim=0)
                bboxes = random_bbox(config, batch_size=ground_truth.size(0))
                x, mask = mask_image(ground_truth, bboxes, config)

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join('checkpoints',
                                               config['dataset_name'],
                                               config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # Define the trainer
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight
            last_model_name = get_model_list(checkpoint_path, "gen", iteration=args.iter)