Exemple #1
0
def print_evaluation(filename, psnr, ssim, iid=None, n_images=None, time=None):
    """Log a single line of PSNR/SSIM results via ``prosr.logger.info``.

    Args:
        filename: image name to report. When a counter is shown, the file
            extension is stripped and the name is padded to 10 characters.
        psnr: PSNR value, printed with two decimals.
        ssim: SSIM value, printed with two decimals.
        iid: index of the current image; enables the ``[iid/n_images]``
            prefix when it is not ``None`` and ``n_images`` is truthy.
        n_images: total number of images.
        time: optional elapsed time in seconds, appended when not ``None``.
    """
    from prosr.logger import info

    show_counter = iid is not None and bool(n_images)
    if show_counter:
        stem = osp.splitext(filename)[0]
        fields = ['[{:03d}/{:03d}] {:10s}'.format(iid, n_images, stem)]
    else:
        fields = ['{}'.format(filename)]

    fields.append('psnr: {:.2f}'.format(psnr))
    fields.append('ssim: {:.2f}'.format(ssim))
    if time is not None:
        fields.append('{:.2f} secs'.format(time))

    info(' | '.join(fields))
Exemple #2
0
from argparse import ArgumentParser
import torch
from pprint import pprint
from prosr.logger import info


def parse_args():
    """Parse the command line of the checkpoint-inspection script.

    Returns:
        The parsed namespace; its only attribute is ``input``, the path to
        the checkpoint file given as the positional argument.
    """
    cli = ArgumentParser(
        description='Print configuration file of the pretrained model.')
    cli.add_argument('input', type=str, help='path to checkpoint')
    return cli.parse_args()


if __name__ == '__main__':
    # Parse command-line arguments
    args = parse_args()

    # Only the stored metadata is printed, so map every tensor to the CPU:
    # without map_location a checkpoint saved from a GPU cannot be loaded on
    # a machine that has no CUDA device.
    params_dict = torch.load(args.input, map_location='cpu')
    info('Class Name: {}'.format(params_dict['class_name']))
    pprint(params_dict['params'])
Exemple #3
0
if __name__ == '__main__':
    # Parse command-line arguments
    args = parse_args()

    if args.cpu:
        checkpoint = torch.load(args.checkpoint,
                                map_location=lambda storage, loc: storage)
    else:
        checkpoint = torch.load(args.checkpoint)

    cls_model = getattr(prosr.models, checkpoint['class_name'])

    model = cls_model(**checkpoint['params']['G'])
    model.load_state_dict(checkpoint['state_dict'])

    info('phase: {}'.format(Phase.TEST))
    info('checkpoint: {}'.format(osp.basename(args.checkpoint)))

    params = checkpoint['params']
    pprint(params)

    model.eval()

    if torch.cuda.is_available() and not args.cpu:
        model = model.cuda()

    # TODO Change
    dataset = Dataset(Phase.TEST,
                      args.input,
                      args.target,
                      args.scale,
Exemple #4
0
def main(args):
    """Train a model and periodically evaluate it on the validation files.

    The function builds the training/validation data loaders, picks the
    trainer class (simultaneous multi-scale or curriculum learning), then
    runs the epoch loop: optimisation steps, periodic error reporting,
    checkpoint saving, learning-rate updates and validation.

    Validation runs whenever ``epoch`` is a multiple of a frequency that
    starts at 1 and doubles after each evaluation, capped at
    ``args.train.io.eval_epoch_freq``.

    Args:
        args: nested configuration object. This function reads
            ``args.cmd`` (seed, output, checkpoint, no_curriculum, visdom),
            ``args.data`` (scale, input_size), ``args.train`` (batch_size,
            epochs, dataset, lr_schedule_patience, io.*) and ``args.test``
            (fast_validation, dataset), and writes ``args.G.max_scale``.

    Note:
        When ``args.cmd.visdom`` is truthy a module-level ``visualizer``
        object must exist; it is not created here.
        NOTE(review): the bare ``print`` calls look like leftover debug
        tracing; they are kept so console output is unchanged.
    """
    print('Main functiona start')
    set_seed(args.cmd.seed)

    ############### loading datasets #################
    train_files, test_files = load_dataset(args)

    # reduce validation size for faster training cycles
    if args.test.fast_validation > -1:
        for ft in ['source', 'target']:
            test_files[ft] = test_files[ft][:args.test.fast_validation]

    info('training images = %d' % len(train_files['target']))
    info('validation images = %d' % len(test_files['target']))
    print('Training Images = %d' % len(train_files['target']))
    print('Testing Images = %d' % len(test_files['target']))
    training_dataset = Dataset(prosr.Phase.TRAIN,
                               **train_files,
                               scale=args.data.scale,
                               input_size=args.data.input_size,
                               **args.train.dataset)
    print('Reaching at 131')
    training_data_loader = DataLoader(training_dataset,
                                      batch_size=args.train.batch_size)
    print('Reaching at 134')

    # Validation is optional: without target files no loader is built.
    testing_data_loader = None
    if len(test_files['target']):
        testing_dataset = Dataset(prosr.Phase.VAL,
                                  **test_files,
                                  scale=args.data.scale,
                                  input_size=None,
                                  **args.test.dataset)
        testing_data_loader = DataLoader(testing_dataset, batch_size=1)

    print('Reaching at 148')

    if args.cmd.no_curriculum or len(args.data.scale) == 1:
        Trainer_cl = SimultaneousMultiscaleTrainer
    else:
        print('Curriculum Training')
        Trainer_cl = CurriculumLearningTrainer

    args.G.max_scale = np.max(args.data.scale)

    trainer = Trainer_cl(args,
                         training_data_loader,
                         save_dir=args.cmd.output,
                         resume_from=args.cmd.checkpoint)

    log_file = os.path.join(args.cmd.output, 'loss_log.txt')

    steps_per_epoch = len(trainer.training_dataset)
    total_steps = trainer.start_epoch * steps_per_epoch

    ############# start training ###############
    print('start training from epoch %d, learning rate %e' %
          (trainer.start_epoch, trainer.lr))
    info('start training from epoch %d, learning rate %e' %
         (trainer.start_epoch, trainer.lr))

    errors_accum = defaultdict(list)
    errors_accum_prev = defaultdict(lambda: 0)

    # eval epochs incrementally: this must live outside the epoch loop,
    # otherwise it is reset to 1 every epoch and the doubling below is lost.
    eval_epoch_freq = 1

    for epoch in range(trainer.start_epoch + 1, args.train.epochs + 1):
        iter_start_time = time()
        trainer.set_train()
        for i, data in enumerate(trainer.training_dataset):
            print('182 %d' % i)
            trainer.set_input(**data)
            trainer.forward()
            trainer.optimize_parameters()

            errors = trainer.get_current_errors()
            for key, item in errors.items():
                errors_accum[key].append(item)

            total_steps += 1
            if total_steps % args.train.io.print_errors_freq == 0:
                print('195 line')
                for key in errors:
                    if len(errors_accum[key]):
                        errors_accum[key] = np.nanmean(errors_accum[key])
                    # fall back to the previously reported value on NaN
                    if np.isnan(errors_accum[key]):
                        errors_accum[key] = errors_accum_prev[key]
                # Copy into a mapping that still defaults to 0; aliasing
                # errors_accum would make unseen keys default to [].
                errors_accum_prev = defaultdict(lambda: 0, errors_accum)
                t = time() - iter_start_time
                iter_start_time = time()
                print_current_errors(epoch,
                                     total_steps,
                                     errors_accum,
                                     t,
                                     log_name=log_file)
                print('206 line')

                if args.cmd.visdom:
                    lrs = {
                        'lr%d' % group_idx: param_group['lr']
                        for group_idx, param_group in enumerate(
                            trainer.optimizer_G.param_groups)
                    }
                    real_epoch = float(total_steps) / steps_per_epoch
                    visualizer.display_current_results(
                        trainer.get_current_visuals(), real_epoch)
                    visualizer.plot(errors_accum, real_epoch, 'loss')
                    visualizer.plot(lrs, real_epoch, 'lr rate', 'lr')
                print('219 line')

                # start a fresh accumulation window
                errors_accum = defaultdict(list)

        # Save model
        if epoch % args.train.io.save_model_freq == 0:
            print('save the model')
            info('saving the model at the end of epoch %d, iters %d' %
                 (epoch, total_steps),
                 bold=True)
            trainer.save(str(epoch), epoch, trainer.lr)
            print('saved model')

        ################# update learning rate  #################
        print('Update learning rate')
        if (epoch - trainer.best_epoch) > args.train.lr_schedule_patience:
            trainer.save('last_lr_%g' % trainer.lr, epoch, trainer.lr)
            trainer.update_learning_rate()

        ################# test with validation set ##############
        if testing_data_loader and epoch % eval_epoch_freq == 0:
            eval_epoch_freq = min(eval_epoch_freq * 2,
                                  args.train.io.eval_epoch_freq)
            with torch.no_grad():
                test_start_time = time()
                # use validation set
                trainer.set_eval()
                trainer.reset_eval_result()
                for data in testing_data_loader:
                    trainer.set_input(**data)
                    trainer.evaluate()

                t = time() - test_start_time
                test_result = trainer.get_current_eval_result()

                ################ visualize ###############
                if args.cmd.visdom:
                    visualizer.plot(test_result,
                                    float(total_steps) / steps_per_epoch,
                                    'eval', 'psnr')

                trainer.update_best_eval_result(epoch, test_result)
                info('eval at epoch %d : ' % epoch + ' | '.join([
                    '{}: {:.02f}'.format(k, v) for k, v in test_result.items()
                ]) + ' | time {:d} sec'.format(int(t)),
                     bold=True)
                print('eval at epoch 268')

                info('best so far %d : ' % trainer.best_epoch + ' | '.join([
                    '{}: {:.02f}'.format(k, v)
                    for k, v in trainer.best_eval.items()
                ]),
                     bold=True)
                print('best so far 276')
                # This epoch produced a new best: save it under a name that
                # lists the metrics which reached their best value.
                if trainer.best_epoch == epoch:
                    if len(trainer.best_eval) > 1:
                        if not isinstance(trainer, CurriculumLearningTrainer):
                            best_key = [
                                k for k in trainer.best_eval
                                if trainer.best_eval[k] == test_result[k]
                            ]
                        else:
                            # select only up to the current training scale
                            best_key = [
                                "psnr_x%d" % trainer.opt.data.scale[s_idx]
                                for s_idx in range(trainer.current_scale_idx +
                                                   1)
                            ]
                            best_key = [
                                k for k in best_key
                                if trainer.best_eval[k] == test_result[k]
                            ]

                    else:
                        best_key = list(trainer.best_eval.keys())
                    trainer.save(
                        str(epoch) + '_best_' + '_'.join(best_key), epoch,
                        trainer.lr)
Exemple #5
0
                sys.exit(0)
    elif args.model is not None:
        params = edict(getattr(prosr, args.model + '_params'))

    else:
        params = torch.load(args.checkpoint + '_net_G.pth')['params']

    # parameter overriding
    if args.fast_validation is not None:
        params.test.fast_validation = args.fast_validation
    del args.fast_validation

    # Add command line arguments
    params.cmd = edict(vars(args))

    pprint(params)

    if not osp.isdir(args.output):
        os.makedirs(args.output)
    np.save(osp.join(args.output, 'params'), params)

    experiment_id = osp.basename(args.output)

    info('experiment ID: {}'.format(experiment_id))

    if args.visdom:
        from prosr.visualizer import Visualizer
        visualizer = Visualizer(experiment_id, port=args.visdom_port)

    main(params)
Exemple #6
0
def _save_loss_plot(loss, loss_file='l1_loss_plot.png'):
    """Plot the per-epoch training loss and write the figure to ``loss_file``."""
    plt.plot(loss, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('L1 Loss')
    plt.legend()
    plt.savefig(loss_file)
    plt.close()


def main(args):
    """Train a model, track the L1 loss, and periodically save images/PSNR.

    Per epoch the function runs all training batches, records the
    accumulated loss, optionally dumps the network outputs as images, saves
    checkpoints, updates the learning rate, and evaluates on the validation
    set (on epoch 1, every ``args.train.io.eval_epoch_freq`` epochs and on
    the last epoch). The loss curve is written to ``l1_loss_plot.png`` every
    10 epochs and once more after training.

    Args:
        args: nested configuration object. This function reads
            ``args.cmd`` (seed, output, checkpoint, no_curriculum, visdom),
            ``args.data`` (scale, input_size), ``args.train`` (batch_size,
            epochs, dataset, lr_schedule_patience, io.*) and ``args.test``
            (fast_validation, dataset), and writes ``args.G.max_scale``.

    Note:
        When ``args.cmd.visdom`` is truthy a module-level ``visualizer``
        object must exist; it is not created here.
    """
    set_seed(args.cmd.seed)

    ############### loading datasets #################
    train_files, test_files = load_dataset(args)

    # reduce validation size for faster training cycles
    if args.test.fast_validation > -1:
        for ft in ['source', 'target']:
            test_files[ft] = test_files[ft][:args.test.fast_validation]

    info('training images = %d' % len(train_files['target']))
    info('validation images = %d' % len(test_files['target']))

    training_dataset = Dataset(prosr.Phase.TRAIN,
                               **train_files,
                               scale=args.data.scale,
                               input_size=args.data.input_size,
                               **args.train.dataset)

    training_data_loader = DataLoader(training_dataset,
                                      batch_size=args.train.batch_size)

    # Validation is optional: without target files no loader is built.
    testing_data_loader = None
    if len(test_files['target']):
        testing_dataset = Dataset(prosr.Phase.VAL,
                                  **test_files,
                                  scale=args.data.scale,
                                  input_size=None,
                                  **args.test.dataset)
        testing_data_loader = DataLoader(testing_dataset, batch_size=1)

    if args.cmd.no_curriculum or len(args.data.scale) == 1:
        Trainer_cl = SimultaneousMultiscaleTrainer
    else:
        Trainer_cl = CurriculumLearningTrainer

    args.G.max_scale = np.max(args.data.scale)

    trainer = Trainer_cl(args,
                         training_data_loader,
                         save_dir=args.cmd.output,
                         resume_from=args.cmd.checkpoint)

    steps_per_epoch = len(trainer.training_dataset)
    total_steps = trainer.start_epoch * steps_per_epoch

    ############# start training ###############
    info('start training from epoch %d, learning rate %e' %
         (trainer.start_epoch, trainer.lr))

    # NOTE(review): 800 looks like a hard-coded number of training images
    # and (3, 32, 32) a hard-coded output image shape -- confirm against the
    # dataset before reusing this with other data.
    batchsize = int(800 / steps_per_epoch)
    print("Batch size = ", batchsize)
    loss = []
    psnr_list = []
    output_imgs = torch.zeros((steps_per_epoch * batchsize, 3, 32, 32))

    #########################################################################
    for epoch in range(trainer.start_epoch + 1, args.train.epochs + 1):
        epoch_start_time = time()
        trainer.set_train()
        epoch_loss = 0
        print("Epoch: ", epoch)

        is_last_epoch = epoch == args.train.epochs
        save_images = (epoch % args.train.io.save_img_freq == 0
                       or is_last_epoch)

        for i, data in enumerate(trainer.training_dataset):
            # data is a dictionary. See trainer.set_input() function for info

            ##################################################################
            # Forward and backward pass
            trainer.set_input(**data)
            output_batch = trainer.forward()
            l1_loss = trainer.optimize_parameters()
            epoch_loss += l1_loss
            total_steps += 1

            ##################################################################
            # Save output images
            if save_images:
                # Detach first: copying a tensor that requires grad into the
                # buffer would attach the buffer to the autograd graph and
                # keep every batch's graph alive across iterations.
                first = i * batchsize
                output_imgs[first:first + batchsize] = output_batch.detach()

        ##################################################################
        # print loss and epoch time
        epoch_time = time() - epoch_start_time
        print("Epoch time = ", epoch_time)
        print("Epoch Loss per sample = ", epoch_loss / batchsize)
        loss.append(epoch_loss / batchsize)

        ##################################################################
        # Plot loss
        if epoch % 10 == 0:
            _save_loss_plot(loss)

        ##################################################################
        # Save intermediate and final SR images
        if save_images:
            image_dir = "./outputs/output_images/"
            trainer.make_dir(image_dir)
            save_images_from_tensors(args, output_imgs, image_dir)

        ##################################################################
        # Save model
        if epoch % args.train.io.save_model_freq == 0:
            info('saving the model at the end of epoch %d, iters %d' %
                 (epoch, total_steps),
                 bold=True)
            trainer.save(str(epoch), epoch, trainer.lr)

        ##################################################################
        # update learning rate
        if (epoch - trainer.best_epoch) > args.train.lr_schedule_patience:
            trainer.save('last_lr_%g' % trainer.lr, epoch, trainer.lr)
            trainer.update_learning_rate()

        ##################################################################
        # test with validation set, PSNR calculation
        if testing_data_loader and (
                epoch % args.train.io.eval_epoch_freq == 0
                or is_last_epoch or epoch == 1):
            with torch.no_grad():
                test_start_time = time()
                # use validation set
                trainer.set_eval()
                trainer.reset_eval_result()
                for data in testing_data_loader:
                    trainer.set_input(**data)
                    trainer.evaluate()

                t = time() - test_start_time
                test_result = trainer.get_current_eval_result()

                ################ visualize ###############
                if args.cmd.visdom:
                    visualizer.plot(test_result,
                                    float(total_steps) / steps_per_epoch,
                                    'eval', 'psnr')

                trainer.update_best_eval_result(epoch, test_result)
                info('eval at epoch %d : ' % epoch + ' | '.join([
                    '{}: {:.02f}'.format(k, v) for k, v in test_result.items()
                ]) + ' | time {:d} sec'.format(int(t)),
                     bold=True)
                psnr_list.extend(test_result.values())

                info('best so far %d : ' % trainer.best_epoch + ' | '.join([
                    '{}: {:.02f}'.format(k, v)
                    for k, v in trainer.best_eval.items()
                ]),
                     bold=True)

                # This epoch produced a new best: save it under a name that
                # lists the metrics which reached their best value.
                if trainer.best_epoch == epoch:
                    if len(trainer.best_eval) > 1:
                        if not isinstance(trainer, CurriculumLearningTrainer):
                            best_key = [
                                k for k in trainer.best_eval
                                if trainer.best_eval[k] == test_result[k]
                            ]
                        else:
                            # select only up to the current training scale
                            best_key = [
                                "psnr_x%d" % trainer.opt.data.scale[s_idx]
                                for s_idx in range(trainer.current_scale_idx +
                                                   1)
                            ]
                            best_key = [
                                k for k in best_key
                                if trainer.best_eval[k] == test_result[k]
                            ]

                    else:
                        best_key = list(trainer.best_eval.keys())
                    trainer.save(
                        str(epoch) + '_best_' + '_'.join(best_key), epoch,
                        trainer.lr)

            plot_psnr(args, psnr_list)

    ##################################################################
    # Plot final loss
    _save_loss_plot(loss)
Exemple #7
0
def change(img_name, img_path):
    """Super-resolve one image with the ``proSR_x2.pth`` model and save it.

    The checkpoint is read from ``MODEL_ROOT`` and the result is written to
    ``OUTPUT_ROOT`` as ``'result_' + <input file name>``. The model runs on
    ``cuda:0`` when CUDA is available and on the CPU otherwise.

    Args:
        img_name: file name of the input image.
        img_path: directory containing the image; joined to ``img_name``
            with ``'/'``.

    Returns:
        The file name (not the full path) of the saved result image.
    """
    model_path = os.path.join(MODEL_ROOT, "proSR_x2.pth")
    input_path = [img_path + '/' + img_name]
    target_path = []
    scale = [2, 4, 8]
    scale_idx = 0
    cur_scale = scale[scale_idx]
    downscale = False
    output_dir = OUTPUT_ROOT
    max_dimension = 0
    padding = 0

    # Choose the device once and use it for the checkpoint, the model and
    # the inputs, so the function also works on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(model_path, map_location=device)
    cls_model = getattr(prosr.models, checkpoint['class_name'])
    model = cls_model(**checkpoint['params']['G'])
    # load the model weights
    model.load_state_dict(checkpoint['state_dict'])

    info('phase: {}'.format(Phase.TEST))
    info('checkpoint: {}'.format(osp.basename(model_path)))
    params = checkpoint['params']
    pprint(params)

    model.eval()
    model.to(device)

    mean = params['train']['dataset']['mean']
    stddev = params['train']['dataset']['stddev']

    # from here on, set everything up with the data received as arguments
    dataset = Dataset(
        Phase.TEST,
        input_path,
        target_path,
        cur_scale,
        input_size=None,
        mean=mean,
        stddev=stddev,
        downscale=downscale)

    data_loader = DataLoader(dataset, batch_size=1)

    if not osp.isdir(output_dir):
        os.makedirs(output_dir)
    info('Saving images in: {}'.format(output_dir))

    with torch.no_grad():
        # Always initialised so the 'target' branch below can never hit an
        # unbound accumulator.
        psnr_mean = 0
        ssim_mean = 0

        for iid, data in enumerate(data_loader):
            tic = time.time()
            # split image in chunks of max-dimension
            if max_dimension:
                data_chunks = DataChunks({'input': data['input']},
                                         max_dimension, padding, cur_scale)

                for chunk in data_chunks.iter():
                    lr_chunk = chunk['input'].to(device)
                    output = model(lr_chunk, cur_scale)
                    data_chunks.gather(output)
                output = data_chunks.concatenate() + data['bicubic']
            else:
                print("input: ", data['input'])
                lr_input = data['input'].to(device)
                output = model(lr_input, cur_scale).cpu() + data['bicubic']
            sr_img = tensor2im(output, mean, stddev)
            toc = time.time()

            input_name = osp.basename(data['input_fn'][0])
            if 'target' in data:
                hr_img = tensor2im(data['target'], mean, stddev)
                psnr_val, ssim_val = eval_psnr_and_ssim(
                    sr_img, hr_img, cur_scale)
                print_evaluation(input_name, psnr_val, ssim_val, iid + 1,
                                 len(dataset), toc - tic)
                psnr_mean += psnr_val
                ssim_mean += ssim_val
            else:
                print_evaluation(input_name, np.nan, np.nan, iid + 1,
                                 len(dataset), toc - tic)

            # write the result image
            result_name = 'result_' + input_name
            io.imsave(osp.join(output_dir, result_name), sr_img)

        if len(target_path):
            psnr_mean /= len(dataset)
            ssim_mean /= len(dataset)
            print_evaluation("average", psnr_mean, ssim_mean)
        # name of the last (here: only) image written by the loop
        return result_name
# if __name__ == '__main__':
#     change('waterfall.jpg')
Exemple #8
0
    args = parse_args()
    params = config.defaults

    pprint(params)

    net_G = ProSR(params.G, params.scale).cuda()
    net_G.load_state_dict(torch.load(args.weights))

    net_G.eval()

    # DIV2K dataset statistics
    mean = params.mean_img
    std = ([1.0 / params.mul_img] * 3)

    # Set of image transformations applied to the input image
    info('Loading ProSR')
    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    # Prepare output folder
    if args.output_dir is not None and \
        not osp.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    info('Processing images:')
    for fn_lr in args.input[:3]:
        lr = io.imread(fn_lr)

        with torch.no_grad():
            lr_t = Variable(preprocess(lr)[None, ...])