Example #1
0
def test(args):
    """Evaluate the Unet low-light model on the Sony test set.

    For every test batch, runs the model on the packed raw input and writes
    three JPEGs under ``<args.result_dir>/eval``: the mean-scaled low-light
    input (*_scale.jpg), the network output (*_out.jpg), and the ground
    truth (*_gt.jpg).

    Args:
        args: parsed CLI namespace; reads gpu, input_dir, gt_dir, batch_size,
            num_workers, model (checkpoint path) and result_dir.
    """
    # device
    device = torch.device("cuda:%d" %
                          args.gpu if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # data
    testset = SonyTestDataset(args.input_dir, args.gt_dir)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    # model
    model = Unet()
    model.load_state_dict(torch.load(args.model))
    model.to(device)
    model.eval()

    # output directory: hoisted out of the loop; exist_ok avoids the
    # check-then-create race of the original isdir/makedirs pair
    out_dir = os.path.join(args.result_dir, 'eval')
    os.makedirs(out_dir, exist_ok=True)

    # testing
    for i, databatch in tqdm(enumerate(test_loader), total=len(test_loader)):
        input_full, scale_full, gt_full, test_id, ratio = databatch
        scale_full, gt_full = torch.squeeze(scale_full), torch.squeeze(gt_full)

        # processing -- no_grad: inference only, skip autograd bookkeeping
        with torch.no_grad():
            inputs = input_full.to(device)
            outputs = model(inputs)
        outputs = outputs.cpu()
        outputs = torch.squeeze(outputs)
        outputs = outputs.permute(1, 2, 0)  # CHW -> HWC for image saving

        # scaling and clipping
        outputs, scale_full, gt_full = outputs.numpy(), scale_full.numpy(
        ), gt_full.numpy()
        # scale the low-light image to the same mean as the ground truth
        scale_full = scale_full * np.mean(gt_full) / np.mean(scale_full)
        outputs = np.clip(outputs, 0, 1)

        # saving
        # NOTE(review): scipy.misc.toimage was removed in SciPy >= 1.2; this
        # code needs an old SciPy pin or a port to PIL.Image -- confirm.
        scipy.misc.toimage(
            scale_full * 255, high=255, low=0, cmin=0, cmax=255).save(
                os.path.join(
                    out_dir,
                    '%05d_00_train_%d_scale.jpg' % (test_id[0], ratio[0])))
        scipy.misc.toimage(
            outputs * 255, high=255, low=0, cmin=0, cmax=255).save(
                os.path.join(
                    out_dir,
                    '%05d_00_train_%d_out.jpg' % (test_id[0], ratio[0])))
        scipy.misc.toimage(
            gt_full * 255, high=255, low=0, cmin=0, cmax=255).save(
                os.path.join(
                    out_dir,
                    '%05d_00_train_%d_gt.jpg' % (test_id[0], ratio[0])))
Example #2
0
def test(args):
    """Run the Unet low-light model on every ``*.DNG`` raw file in a folder.

    Each raw image is packed, amplified by a fixed ratio, center-cropped to
    1024x1024, pushed through the network, and the result is saved next to
    the input as ``<filename>_out.jpg``.

    Args:
        args: parsed CLI namespace; reads gpu, imgdir (input folder) and
            model (checkpoint path).
    """
    # device
    device = torch.device("cuda:%d" %
                          args.gpu if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # images path
    fns = glob.glob(path.join(args.imgdir, '*.DNG'))

    # model
    model = Unet()
    model.load_state_dict(torch.load(args.model))
    model.to(device)
    model.eval()

    # fixed amplification ratio applied to the packed raw data
    ratio = 200

    # iterate files directly instead of range(len(...)); 'inputs' avoids
    # shadowing the builtin 'input'; the original's 'scale_full' was computed
    # but never used, so it is dropped
    for fn in fns:
        print(fn)
        raw = rawpy.imread(fn)

        packed = np.expand_dims(pack_raw(raw), axis=0) * ratio
        packed = crop_center(packed, 1024, 1024)
        inputs = torch.from_numpy(packed)
        inputs = torch.squeeze(inputs)
        inputs = inputs.permute(2, 0, 1)  # HWC -> CHW
        inputs = torch.unsqueeze(inputs, dim=0)
        inputs = inputs.to(device)

        # no_grad: inference only, skip autograd bookkeeping
        with torch.no_grad():
            outputs = model(inputs)
        outputs = outputs.cpu()
        outputs = torch.squeeze(outputs)
        outputs = outputs.permute(1, 2, 0)  # CHW -> HWC for image saving

        outputs = outputs.numpy()
        outputs = np.clip(outputs, 0, 1)

        # NOTE(review): scipy.misc.toimage was removed in SciPy >= 1.2; this
        # code needs an old SciPy pin or a port to PIL.Image -- confirm.
        scipy.misc.toimage(outputs * 255, high=255, low=0, cmin=0,
                           cmax=255).save(
                               path.join(args.imgdir,
                                         path.basename(fn) + '_out.jpg'))
def main(args):
    """Meta-train a Unet with MAML-style training on the requested dataset.

    Trains for ``args.num_epochs`` epochs, evaluating every
    ``args.val_step_size`` epochs. When ``args.output_folder`` is given, the
    run configuration, the best model checkpoint, a ``train_results.json``
    and loss/accuracy/IoU plots are written to a timestamped subfolder.

    Fixes over the previous revision:
      * ``save_model`` is reset every evaluation (it was left unassigned when
        ``'accuracies_after'`` was present but not an improvement, raising
        UnboundLocalError on the first evaluation and reusing a stale value
        afterwards).
      * ``output_folder`` is initialised to None so the result-saving and
        plotting code no longer raises NameError when ``args.output_folder``
        is None.
    """
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    device = torch.device(
        'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu')

    # Create output folder; stays None when saving is disabled
    output_folder = None
    if args.output_folder is not None:
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
            logging.debug('Creating folder `{0}`'.format(args.output_folder))

        output_folder = os.path.join(args.output_folder,
                                     time.strftime('%Y-%m-%d_%H%M%S'))
        os.makedirs(output_folder)
        logging.debug('Creating folder `{0}`'.format(output_folder))

        args.datafolder = os.path.abspath(args.datafolder)
        args.model_path = os.path.abspath(
            os.path.join(output_folder, 'model.th'))

        # Save the configuration in a config.json file
        with open(os.path.join(output_folder, 'config.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)
        logging.info('Saving configuration file in `{0}`'.format(
            os.path.abspath(os.path.join(output_folder, 'config.json'))))

    # Get datasets and load into meta learning format
    # NOTE(review): `augment` and `download_data` are not defined in this
    # function -- presumably module-level globals; confirm.
    meta_train_dataset, meta_val_dataset, _ = get_datasets(
        args.dataset,
        args.datafolder,
        args.num_ways,
        args.num_shots,
        args.num_shots_test,
        augment=augment,
        fold=args.fold,
        download=download_data)

    meta_train_dataloader = BatchMetaDataLoader(meta_train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                pin_memory=True)

    meta_val_dataloader = BatchMetaDataLoader(meta_val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # Define model
    model = Unet(device=device, feature_scale=args.feature_scale)
    model = model.to(device)
    print(f'Using device: {device}')

    # Define optimizer
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args.meta_lr)

    # Define meta learner
    # NOTE(review): `loss_function` is not defined here -- presumably a
    # module-level global; confirm.
    metalearner = ModelAgnosticMetaLearning(
        model,
        meta_optimizer,
        first_order=args.first_order,
        num_adaptation_steps=args.num_adaption_steps,
        step_size=args.step_size,
        learn_step_size=False,
        loss_function=loss_function,
        device=device)

    best_value = None

    # Training loop
    epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 +
                                             int(math.log10(args.num_epochs)))
    train_losses = []
    val_losses = []
    train_ious = []
    train_accuracies = []
    val_accuracies = []
    val_ious = []

    start_time = time.time()

    for epoch in range(args.num_epochs):
        print('start epoch ', epoch + 1)
        print('start train---------------------------------------------------')
        train_loss, train_accuracy, train_iou = metalearner.train(
            meta_train_dataloader,
            max_batches=args.num_batches,
            verbose=args.verbose,
            desc='Training',
            leave=False)
        print(f'\n train accuracy: {train_accuracy}, train loss: {train_loss}')
        print('end train---------------------------------------------------')
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        train_ious.append(train_iou)

        # Evaluate in given intervals
        if epoch % args.val_step_size == 0:
            print(
                'start evaluate-------------------------------------------------'
            )
            results = metalearner.evaluate(meta_val_dataloader,
                                           max_batches=args.num_batches,
                                           verbose=args.verbose,
                                           desc=epoch_desc.format(epoch + 1),
                                           is_test=False)
            val_acc = results['accuracy']
            val_loss = results['mean_outer_loss']
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            val_ious.append(results['iou'])
            print(
                f'\n validation accuracy: {val_acc}, validation loss: {val_loss}'
            )
            print(
                'end evaluate-------------------------------------------------'
            )

            # Save best model: higher accuracy wins when reported, otherwise
            # lower mean outer loss. save_model MUST be reset each evaluation.
            save_model = False
            if 'accuracies_after' in results:
                if (best_value is None) or (best_value <
                                            results['accuracies_after']):
                    best_value = results['accuracies_after']
                    save_model = True
            elif (best_value is None) or (best_value >
                                          results['mean_outer_loss']):
                best_value = results['mean_outer_loss']
                save_model = True

            if save_model and (args.output_folder is not None):
                with open(args.model_path, 'wb') as f:
                    torch.save(model.state_dict(), f)

        print('end epoch ', epoch + 1)

    elapsed_time = time.time() - start_time
    print('Finished after ',
          time.strftime('%H:%M:%S', time.gmtime(elapsed_time)))

    r = {}
    r['train_losses'] = train_losses
    r['train_accuracies'] = train_accuracies
    r['train_ious'] = train_ious
    r['val_losses'] = val_losses
    r['val_accuracies'] = val_accuracies
    r['val_ious'] = val_ious
    r['time'] = time.strftime('%H:%M:%S', time.gmtime(elapsed_time))

    # Results and plots are only written when an output folder exists; the
    # previous revision crashed here with NameError when it did not.
    if output_folder is not None:
        with open(os.path.join(output_folder, 'train_results.json'),
                  'w') as g:
            json.dump(r, g)
            logging.info('Saving results dict in `{0}`'.format(
                os.path.abspath(
                    os.path.join(output_folder, 'train_results.json'))))

        # Plot results
        # NOTE(review): `bce_dice_focal` is not defined in this function --
        # presumably a module-level global; confirm.
        plot_errors(args.num_epochs,
                    train_losses,
                    val_losses,
                    val_step_size=args.val_step_size,
                    output_folder=output_folder,
                    save=True,
                    bce_dice_focal=bce_dice_focal)
        plot_accuracy(args.num_epochs,
                      train_accuracies,
                      val_accuracies,
                      val_step_size=args.val_step_size,
                      output_folder=output_folder,
                      save=True)
        plot_iou(args.num_epochs,
                 train_ious,
                 val_ious,
                 val_step_size=args.val_step_size,
                 output_folder=output_folder,
                 save=True)

    if hasattr(meta_train_dataset, 'close'):
        meta_train_dataset.close()
        meta_val_dataset.close()