def main(config_path, experiment_path):
    """Train the inpainting model described by a config file.

    Every ``config['training']['save_iters']`` steps the model is saved, run
    over the test dataset (predictions written under
    ``<experiment>/predictions/pred_<step>``) and PSNR / L1 against
    ``config['path']['test']['labels']`` are logged to TensorBoard. Training
    stops once the model's iteration counter reaches
    ``config['training']['max_iteration']``.

    NOTE(review): if this function lives in the same module as the later
    prediction ``main``, that later definition shadows this one at import
    time -- confirm which entry point is actually reachable.

    Args:
        config_path: config file path, joined onto ``code_path`` (``'/'``).
        experiment_path: directory prefix joined onto
            ``config['path']['experiment']``; the result is the TensorBoard
            log dir and the root of the ``predictions`` directory, and when it
            is non-empty ``inpainting_model.load()`` is called.

    Raises:
        Exception: if the training dataset is empty.
    """
    # ARGS
    training = True

    # load config
    # NOTE(review): '/' resolves config_path from the filesystem root (the
    # prediction entry point uses './'); presumably intended for a container
    # layout -- confirm.
    code_path = '/'
    config, pretty_config = get_config(os.path.join(code_path, config_path))
    config['path']['experiment'] = os.path.join(experiment_path, config['path']['experiment'])

    print('\nModel configurations:'
          '\n---------------------------------\n'
          + pretty_config +
          '\n---------------------------------\n')

    # Must be set before torch initializes CUDA, hence the late torch import.
    os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu']

    # Import Torch after os env
    import torch
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.utils import save_image

    # init device
    if config['gpu'] and torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True  # cudnn auto-tuner
    else:
        device = torch.device("cpu")

    # initialize random seed
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])
    if not training:
        np.random.seed(config["seed"])
        random.seed(config["seed"])

    checkpoint = config['path']['experiment']
    discriminator = config['training']['discriminator']

    # initialize log writer; closed in `finally` even if training fails.
    logger = SummaryWriter(log_dir=config['path']['experiment'])
    try:
        # build the model and initialize
        inpainting_model = InpaintingModel(config).to(device)
        if checkpoint:
            inpainting_model.load()

        os.makedirs(os.path.join(checkpoint, 'predictions'), exist_ok=True)

        if training:
            print('\nStart training...\n')
            batch_size = config['training']['batch_size']

            # create datasets
            dataset = Dataset(config, training=True)
            train_loader = dataset.create_iterator(batch_size)
            test_dataset = Dataset(config, training=False)

            total = len(dataset)
            if total == 0:
                raise Exception("Dataset is empty!")

            # `i` below counts batches while `total` counts samples, so one
            # epoch is ceil(total / batch_size) batches.
            # NOTE(review): assumes the iterator keeps the final partial batch
            # (no drop_last) -- confirm against Dataset.create_iterator.
            batches_per_epoch = -(-total // batch_size)

            def evaluate(step):
                """Predict the test set, save the images and log PSNR / L1 for ``step``.

                ``inpainting_model.alpha`` is set to 0.0 for the duration of
                the evaluation and restored afterwards, even on error.
                """
                alpha = inpainting_model.alpha
                inpainting_model.alpha = 0.0
                inpainting_model.generator.eval()
                try:
                    print('Predicting...')
                    test_loader = test_dataset.create_iterator(batch_size=1)

                    eval_directory = os.path.join(checkpoint, f'predictions/pred_{step}')
                    os.makedirs(eval_directory, exist_ok=True)

                    height = config['dataset']['image_height']
                    width = config['dataset']['image_width']
                    # TODO batch size
                    # Inference only: no autograd graph needed.
                    with torch.no_grad():
                        for test_items in test_loader:
                            test_outputs, _, _ = inpainting_model.forward(
                                test_items['image'].to(device),
                                test_items['mask'].to(device),
                                test_items['constant_mask'].to(device))
                            # Batch saving, cropped to the configured image size.
                            for f, result in zip(test_items['filename'], test_outputs):
                                save_image(result[:, :height, :width],
                                           os.path.join(eval_directory, f))
                            del test_outputs

                    mean_psnr, mean_l1, _ = compute_metrics(
                        eval_directory, config['path']['test']['labels'])
                    logger.add_scalar('PSNR', mean_psnr, global_step=step)
                    logger.add_scalar('L1', mean_l1, global_step=step)
                finally:
                    inpainting_model.alpha = alpha

            # Training loop
            epoch = 0
            for i, items in enumerate(train_loader):
                inpainting_model.train()

                if i % batches_per_epoch == 0:
                    epoch += 1
                    print('Epoch', epoch)
                    progbar = Progbar(total, width=20, stateful_metrics=['iter'])

                images, masks, constant_mask = items['image'], items['mask'], items['constant_mask']
                del items
                if config['training']['random_crop']:
                    images, masks, constant_mask = random_crop(
                        images, masks, constant_mask, config['training']['strip_size'])
                images, masks, constant_mask = images.to(device), masks.to(device), constant_mask.to(device)

                if discriminator:
                    # Forward pass
                    outputs, residuals, gen_loss, dis_adv_loss, logs = inpainting_model.process(
                        images, masks, constant_mask)
                    # Drop references before backward to lower peak memory.
                    del masks, constant_mask, residuals
                    loss = gen_loss + dis_adv_loss
                    # Backward pass
                    inpainting_model.backward(gen_loss, dis_adv_loss)
                else:
                    # Forward pass
                    outputs, residuals, loss, logs = inpainting_model.process(
                        images, masks, constant_mask)
                    del masks, constant_mask, residuals
                    # Backward pass
                    inpainting_model.backward(loss)
                step = inpainting_model._iteration

                # Adding losses to Tensorboard
                for log in logs:
                    logger.add_scalar(log[0], log[1], global_step=step)

                if i % config['training']['tf_summary_iters'] == 0:
                    logger.add_image('outputs', torchvision.utils.make_grid(outputs, nrow=4), step)
                    logger.add_image('gt', torchvision.utils.make_grid(images, nrow=4), step)

                del outputs

                if step % config['training']['save_iters'] == 0:
                    inpainting_model.save()
                    # Uses its own tensor names, so the training `images`
                    # below still refers to this training batch.
                    evaluate(step)

                if step >= config['training']['max_iteration']:
                    break

                progbar.add(len(images),
                            values=[('iter', step), ('loss', loss.cpu().detach().numpy())] + logs)
                del images

        # generator test
        else:
            print('\nStart testing...\n')
            # generator.test()
    finally:
        logger.close()
    print('Done')
def main(pred_path, config_path, images_path, masks_path, checkpoints_path, labels_path, blured, cuda, num_workers, batch_size):
    """Run inference with one or more generator checkpoints and save predictions.

    Every entry of ``checkpoints_path`` is loaded as a checkpoint (its
    ``'generator'`` state dict). With several checkpoints, each batch is
    predicted by every one of them. Each output is converted to uint8,
    averaged, and the mean is truncated back to uint8.

    Args:
        pred_path: directory the predicted images are written to (created if missing).
        config_path: config file path, resolved relative to the working directory.
        images_path: optional override for ``config['path']['test']['images']``.
        masks_path: optional override for ``config['path']['test']['masks']``.
        checkpoints_path: directory whose entries are all loaded as checkpoints.
        labels_path: if truthy, ``compute_metrics(pred_path, labels_path)`` runs at the end.
        blured: if truthy, the prediction is box-blurred (3x3 then 5x5) and
            composited. Pixels where the mask is zero come from the input
            image, and pixels where the mask is non-zero come from the blurred
            prediction.
        cuda: optional override for ``config['gpu']`` (written to CUDA_VISIBLE_DEVICES).
        num_workers: stored into ``config['dataset']['num_workers']``.
        batch_size: batch size of the test iterator.

    Raises:
        Exception: if the test dataset is empty.
        FileNotFoundError: if ``checkpoints_path`` has no entries.
    """
    from model.net import InpaintingGenerator
    from utils.general import get_config
    from utils.progbar import Progbar
    from data.dataset import Dataset
    from scripts.metrics import compute_metrics

    # load config; explicit arguments override the config file.
    code_path = './'
    config, pretty_config = get_config(os.path.join(code_path, config_path))
    if images_path:
        config['path']['test']['images'] = images_path
    if masks_path:
        config['path']['test']['masks'] = masks_path
    if cuda:
        config['gpu'] = cuda
    config['dataset']['num_workers'] = num_workers

    print('\nModel configurations:'
          '\n---------------------------------\n'
          + pretty_config +
          '\n---------------------------------\n')

    # Must be set before torch initializes CUDA, hence the late torch import.
    os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu']

    # Import Torch after os env
    import torch
    from torch import nn
    from torchvision.utils import make_grid

    # init device
    if config['gpu'] and torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True  # cudnn auto-tuner
    else:
        device = torch.device("cpu")

    # initialize random seed
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # dataset
    dataset = Dataset(config, training=False)
    test_loader = dataset.create_iterator(batch_size=batch_size)
    total = len(dataset)
    if total == 0:
        raise Exception("Dataset is empty!")

    os.makedirs(pred_path, exist_ok=True)

    # Sorted so the ensemble order is deterministic.
    checkpoints = sorted(os.listdir(checkpoints_path))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_path}")

    # Read each checkpoint from disk once, onto the CPU. Previously every file
    # was re-read for every batch. load_state_dict copies the weights onto the
    # model's device, so results are unchanged.
    state_dicts = [
        torch.load(os.path.join(checkpoints_path, name),
                   map_location=lambda storage, loc: storage)['generator']
        for name in checkpoints
    ]

    # build the model and initialize
    generator = nn.DataParallel(InpaintingGenerator(config).to(device))
    if len(state_dicts) == 1:
        generator.load_state_dict(state_dicts[0], strict=False)
    generator.eval()

    image_height = config['dataset']['image_height']
    image_width = config['dataset']['image_width']

    def to_hwc_uint8(result):
        # make_grid, then scale by 255, round, clamp and convert to an HWC
        # uint8 array. `range=` is not passed: newer torchvision removed it
        # (renamed `value_range`), and None was already the default.
        grid = make_grid(result, nrow=8, padding=2, pad_value=0,
                         normalize=False, scale_each=False)
        return grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to('cpu', torch.uint8).numpy()

    def blend_blurred(result, f):
        # Blur the prediction, then keep input pixels where the mask is zero
        # and blurred-prediction pixels where the mask is non-zero. Paths come
        # from the config, so they are set even when images_path/masks_path
        # were not passed.
        test_img = np.array(Image.open(
            os.path.join(config['path']['test']['images'], f)).convert('RGB'))
        mask_img = np.array(Image.open(os.path.join(config['path']['test']['masks'], f)))
        if mask_img.ndim == 3:
            # Multi-channel mask: a pixel is masked if any channel is non-zero.
            mask_img = mask_img.any(axis=2)
        keep = ~np.repeat(np.asarray(mask_img, dtype=bool)[:, :, np.newaxis], 3, axis=2)
        for k in (3, 5):
            result = cv2.blur(result, (k, k))
        return test_img * keep + result * (~keep)

    print('Predicting...')
    progbar = Progbar(total, width=50)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for items in test_loader:
            images = items['image'].to(device)
            masks = items['mask'].to(device)
            constant_mask = items['constant_mask'].to(device)
            bs, c, h, w = images.size()
            outputs = np.zeros((bs, h, w, c))

            # predict (sum over ensemble members)
            for state_dict in state_dicts:
                if len(state_dicts) > 1:
                    # Swap in this member's weights (already in memory).
                    generator.load_state_dict(state_dict, strict=False)
                for i, result in enumerate(
                        generator.module.predict(images, masks, constant_mask)):
                    outputs[i] += to_hwc_uint8(result)
            outputs = np.array(outputs / len(state_dicts), dtype=np.uint8)

            # Batch saving, cropped to the configured image size.
            for f, result in zip(items['filename'], outputs):
                result = result[:image_height, :image_width]
                if blured:
                    result = blend_blurred(result, f)
                Image.fromarray(result).save(os.path.join(pred_path, f))
            progbar.add(len(images))

    if labels_path:
        compute_metrics(pred_path, labels_path)