Example #1
import os
import random

import numpy as np

# NOTE: the project-level helpers used below (get_config, InpaintingModel,
# Dataset, Progbar, random_crop, compute_metrics) are assumed to be imported
# from the surrounding project, as in Example #2 below.


def main(config_path, experiment_path):
    # ARGS (hardcoded for this run)
    masks_path = None
    training = True

    # load config
    code_path = '/'
    config, pretty_config = get_config(os.path.join(code_path, config_path))
    config['path']['experiment'] = os.path.join(experiment_path, config['path']['experiment'])

    print('\nModel configurations:'
          '\n---------------------------------\n'
          + pretty_config +
          '\n---------------------------------\n')

    os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu']

    # import torch only after CUDA_VISIBLE_DEVICES has been set above
    import torch
    import torchvision
    from torch import nn
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.utils import save_image

    # init device
    if config['gpu'] and torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True   # cudnn auto-tuner
    else:
        device = torch.device("cpu")

    # initialize random seed
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])
    if not training:
        np.random.seed(config["seed"])
        random.seed(config["seed"])

    # parse args
    images_path = config['path']['train']
    checkpoint = config['path']['experiment']
    discriminator = config['training']['discriminator']

    # initialize log writer
    logger = SummaryWriter(log_dir=config['path']['experiment'])

    # build the model and initialize
    inpainting_model = InpaintingModel(config).to(device)
    if checkpoint:
        inpainting_model.load()
    
    pred_directory = os.path.join(checkpoint, 'predictions')
    if not os.path.exists(pred_directory):
        os.makedirs(pred_directory)

    # generator training
    if training:
        print('\nStart training...\n')
        batch_size = config['training']['batch_size']

        # create dataset
        dataset = Dataset(config, training=True)
        train_loader = dataset.create_iterator(batch_size)

        test_dataset = Dataset(config, training=False)

        # Train the generator
        total = len(dataset)
        if total == 0:
            raise Exception("Dataset is empty!")

        # Training loop
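        # iterate until `max_iteration` is reached (see the break below);
        # whenever the batch index is a multiple of the dataset length, a new
        # epoch is counted and a fresh Progbar is started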
        epoch = 0
        for i, items in enumerate(train_loader):
            inpainting_model.train()

            if i % total == 0:
                epoch += 1
                print('Epoch', epoch)
                progbar = Progbar(total, width=20, stateful_metrics=['iter'])
            
            images, masks, constant_mask = items['image'], items['mask'], items['constant_mask']

            del items
            if config['training']['random_crop']:
                images, masks, constant_mask = random_crop(images, masks, constant_mask, 
                                                           config['training']['strip_size'])
            images, masks, constant_mask = images.to(device), masks.to(device), constant_mask.to(device)

            if discriminator:
                # Forward pass
                outputs, residuals, gen_loss, dis_adv_loss, logs = inpainting_model.process(images, masks, constant_mask)    
                del masks, constant_mask, residuals
                loss = gen_loss + dis_adv_loss
                # Backward pass
                inpainting_model.backward(gen_loss, dis_adv_loss)
            else:
                # Forward pass
                outputs, residuals, loss, logs = inpainting_model.process(images, masks, constant_mask)    
                del masks, constant_mask, residuals
                # Backward pass
                inpainting_model.backward(loss)
            
            step = inpainting_model._iteration
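            # `step` is the model's internal iteration counter; it drives the
            # TensorBoard logging, checkpointing and stop condition below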
            
            # Adding losses to Tensorboard
            for log in logs:
                logger.add_scalar(log[0], log[1], global_step=step)

            if i % config['training']['tf_summary_iters'] == 0:
                grid = torchvision.utils.make_grid(outputs, nrow=4)
                logger.add_image('outputs', grid, step)

                grid = torchvision.utils.make_grid(images, nrow=4)
                logger.add_image('gt', grid, step)
            
            del outputs
            if step % config['training']['save_iters'] == 0:
                inpainting_model.save()
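                # stash `alpha` and zero it for the evaluation pass below;
                # the original value is restored after metrics are computed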
                
                alpha = inpainting_model.alpha
                inpainting_model.alpha = 0.0
                
                inpainting_model.generator.eval()

                print('Predicting...')
                test_loader = test_dataset.create_iterator(batch_size=1)    
                
                eval_directory = os.path.join(checkpoint, f'predictions/pred_{step}') 
                if not os.path.exists(eval_directory):
                    os.makedirs(eval_directory)
                
                # TODO batch size
                for items in test_loader:
                    images = items['image'].to(device)
                    masks = items['mask'].to(device)
                    constant_mask = items['constant_mask'].to(device)
                    outputs, _, _ = inpainting_model.forward(images, masks, constant_mask)

                    # Batch saving
                    filename = items['filename']
                    for f, result in zip(filename, outputs): 
                        result = result[:, :config['dataset']['image_height'], :config['dataset']['image_width']]
                        save_image(result, os.path.join(eval_directory, f))
                    del outputs, result, _

                mean_psnr, mean_l1, metrics = compute_metrics(eval_directory, config['path']['test']['labels'])
                logger.add_scalar('PSNR', mean_psnr, global_step=step)
                logger.add_scalar('L1', mean_l1, global_step=step)

                inpainting_model.alpha = alpha
            
            if step >= config['training']['max_iteration']:
                break

            progbar.add(len(images), values=[('iter', step), 
                                             ('loss', loss.cpu().detach().numpy())] + logs)
            del images

    # generator test
    else:
        print('\nStart testing...\n')
        #generator.test()

    logger.close()
    print('Done')
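
# A minimal, hypothetical way to run this entry point directly: the original
# project presumably wraps `main` in a CLI, and both paths below are
# placeholders rather than files from the source.
if __name__ == '__main__':
    main(config_path='config.yaml', experiment_path='./experiments')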
Example #2
import os
import random

import cv2
import numpy as np
from PIL import Image


def main(pred_path, config_path, images_path, masks_path, checkpoints_path,
         labels_path, blured, cuda, num_workers, batch_size):

    from model.net import InpaintingGenerator
    from utils.general import get_config
    from utils.progbar import Progbar
    from data.dataset import Dataset
    from scripts.metrics import compute_metrics

    # load config
    code_path = './'
    config, pretty_config = get_config(os.path.join(code_path, config_path))

    if images_path:
        config['path']['test']['images'] = images_path
    if masks_path:
        config['path']['test']['masks'] = masks_path
    if cuda:
        config['gpu'] = cuda
    config['dataset']['num_workers'] = num_workers

    print('\nModel configurations:'
          '\n---------------------------------\n'
          + pretty_config +
          '\n---------------------------------\n')

    os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu']

    # import torch only after CUDA_VISIBLE_DEVICES has been set above
    import torch
    import torchvision
    from torch import nn
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.utils import save_image, make_grid
    from torchvision.transforms import ToPILImage

    # init device
    if config['gpu'] and torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True  # cudnn auto-tuner
    else:
        device = torch.device("cpu")

    # initialize random seed
    torch.manual_seed(config["seed"])
    torch.cuda.manual_seed_all(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # dataset
    dataset = Dataset(config, training=False)
    test_loader = dataset.create_iterator(batch_size=batch_size)

    total = len(dataset)
    if total == 0:
        raise Exception("Dataset is empty!")

    if not os.path.exists(pred_path):
        os.makedirs(pred_path)

    # build the model and initialize
    generator = InpaintingGenerator(config).to(device)
    generator = nn.DataParallel(generator)
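    # a single checkpoint is loaded once here; if several checkpoints are
    # present, each is loaded per batch below and the predictions averaged
    # (simple checkpoint ensembling)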

    checkpoints = os.listdir(checkpoints_path)
    if len(checkpoints) == 1:
        checkpoint = os.path.join(checkpoints_path, checkpoints[0])
        if config['gpu'] and torch.cuda.is_available():
            data = torch.load(checkpoint)
        else:
            data = torch.load(checkpoint,
                              map_location=lambda storage, loc: storage)

        generator.load_state_dict(data['generator'], strict=False)

    print('Predicting...')
    generator.eval()

    progbar = Progbar(total, width=50)
    for items in test_loader:
        images = items['image'].to(device)
        masks = items['mask'].to(device)
        constant_mask = items['constant_mask'].to(device)

        bs, c, h, w = images.size()
        outputs = np.zeros((bs, h, w, c))

        # predict: accumulate one HWC prediction per image from every
        # checkpoint into `outputs`, then average over checkpoints
        if len(checkpoints) > 1:
            for ch in checkpoints:
                checkpoint = os.path.join(checkpoints_path, ch)
                if config['gpu'] and torch.cuda.is_available():
                    data = torch.load(checkpoint)
                else:
                    data = torch.load(
                        checkpoint, map_location=lambda storage, loc: storage)

                generator.load_state_dict(data['generator'], strict=False)
                generator.eval()

                for i, result in enumerate(
                        generator.module.predict(images, masks,
                                                 constant_mask)):
                    grid = make_grid(result,
                                     nrow=8,
                                     padding=2,
                                     pad_value=0,
                                     normalize=False,
                                     range=None,
                                     scale_each=False)
                    result = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                        1, 2, 0).to('cpu', torch.uint8).numpy()
                    outputs[i] += result
        else:
            for i, result in enumerate(
                    generator.module.predict(images, masks, constant_mask)):
                grid = make_grid(result,
                                 nrow=8,
                                 padding=2,
                                 pad_value=0,
                                 normalize=False,
                                 range=None,
                                 scale_each=False)
                result = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(
                    1, 2, 0).to('cpu', torch.uint8).numpy()
                outputs[i] += result

        outputs = outputs / len(checkpoints)
        outputs = np.array(outputs, dtype=np.uint8)

        # Batch saving
        filename = items['filename']
        for f, result in zip(filename, outputs):
            result = result[:config['dataset']['image_height'], :
                            config['dataset']['image_width']]
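            # when `blured` is set: keep the original pixels outside the mask,
            # blur the prediction, and paste it back only inside the masked
            # region before saving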

            if blured:
                test_img = np.array(Image.open(os.path.join(images_path, f)))

                mask_img = np.array(Image.open(os.path.join(masks_path, f)))
                mask_img = np.repeat(mask_img[:, :, np.newaxis], 3, axis=2)
                mask_img = (~np.array(mask_img, dtype=bool))

                test_img = test_img * mask_img
                for i in [3, 5]:
                    result = cv2.blur(result, (i, i))

                result = result * (~mask_img)

                result = test_img + result
                result = Image.fromarray(result)
                result.save(os.path.join(pred_path, f))
            else:
                result = Image.fromarray(result)
                result.save(os.path.join(pred_path, f))

        progbar.add(len(images))

    if labels_path:
        compute_metrics(pred_path, labels_path)
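

# A minimal, hypothetical invocation of this prediction entry point. Argument
# names mirror the signature above; every path is a placeholder, and the real
# project presumably exposes these values as CLI options.
if __name__ == '__main__':
    main(pred_path='./predictions',
         config_path='config.yaml',
         images_path='./test/images',
         masks_path='./test/masks',
         checkpoints_path='./checkpoints',
         labels_path=None,
         blured=False,
         cuda='0',
         num_workers=4,
         batch_size=1)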