def main():
    """Build the dataset, model and loss, then train through ``Solver``.

    All settings come from the module-level ``args`` object (``scale``,
    ``train_set``, ``model``, ``batch_size``, ``num_epochs``,
    ``learning_rate``, ``fine_tune`` and ``verbose``). The path
    ``check_point/<model.name>/<scale>x`` is passed to ``Solver`` as its
    second positional argument.
    """
    display_config()

    data_dir = get_full_path(args.scale, args.train_set)

    print('Contructing dataset...')
    training_data = DatasetFactory().create_dataset(args.model, data_dir)

    net = ModelFactory().create_model(args.model)
    criterion = get_loss_fn(net.name)

    # One directory per (model name, scale) pair.
    scale_tag = str(args.scale) + 'x'
    ckpt_dir = os.path.join('check_point', net.name, scale_tag)

    # Collect the keyword options first so the Solver call stays short.
    solver_options = {
        'loss_fn': criterion,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.learning_rate,
        'fine_tune': args.fine_tune,
        'verbose': args.verbose,
    }
    trainer = Solver(net, ckpt_dir, **solver_options)

    print('Training...')
    trainer.train(training_data)
def main():
    """Run the saved model over the test set, export outputs, report timing.

    Loads ``check_point/<args.model>/<args.scale>x/model.pt``, feeds every
    sample of the ``'VALID'`` dataset built from
    ``preprocessed_data/<args.test_set>`` through the model one at a time,
    passes each result to ``export`` and prints the mean forward-pass time.
    Settings are read from the module-level ``args`` object.

    Raises:
        Exception: if the checkpoint file does not exist.
    """
    display_config()
    # Use CUDA when it is present rather than assuming it unconditionally,
    # so the CPU code path is actually reachable on GPU-less machines.
    use_gpu = torch.cuda.is_available()

    print('Contructing dataset...')
    dataset_root = os.path.join('preprocessed_data', args.test_set)
    dataset_factory = DatasetFactory()
    val_dataset = dataset_factory.create_dataset('VALID', dataset_root)

    print('Loading model...')
    model_path = os.path.join('check_point', args.model,
                              str(args.scale) + 'x', 'model.pt')
    if not os.path.exists(model_path):
        raise Exception('Cannot find %s.' % model_path)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if use_gpu:
        model = torch.load(model_path).cuda()
    else:
        # Remap storages to the CPU so a checkpoint saved from a GPU can be
        # deserialized on a machine without CUDA.
        model = torch.load(
            model_path, map_location=lambda storage, loc: storage).cpu()

    # Inference only: turn off gradient tracking on the parameters to
    # save memory during the forward pass.
    for param in model.parameters():
        param.requires_grad = False

    dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False,
                            num_workers=6)

    print('Testing...')
    total_time = 0
    cnt = 0
    for input_batch, imgname in dataloader:
        if use_gpu:
            input_batch = input_batch.cuda()
        input_batch = Variable(input_batch, requires_grad=False)

        # CUDA kernels are launched asynchronously; synchronize before and
        # after the forward pass so the wall-clock interval covers the
        # actual GPU work and not only the kernel launches.
        if use_gpu:
            torch.cuda.synchronize()
        start = time.time()
        output_batch, _ = model(input_batch)
        if use_gpu:
            torch.cuda.synchronize()
        total_time += time.time() - start

        output_batch = (output_batch.data + 0.5) * 255  # change into pixel domain
        output_batch = output_batch.cpu().numpy()
        # batch_size is 1, so the first name is the only one in the batch.
        export(args.scale, model.name, output_batch, imgname[0])
        cnt += 1

    if cnt:
        print('Average time: %f\n' % (total_time / cnt))
    else:
        # Avoid dividing by zero when the dataset yields no samples.
        print('No samples found in %s.' % dataset_root)
    print('..Finish..')
# Example #3
# 0
def main():
    """Evaluate the model named by ``args.model`` via ``Solver.test``.

    Settings are read from the module-level ``args`` object (``scale``,
    ``test_set`` and ``model``). The path
    ``check_point/<model.name>/<scale>x`` is passed to ``Solver`` as its
    second argument, and the two values returned by ``solver.test`` are
    forwarded to ``export``.
    """
    display_config()

    data_dir = get_full_path(args.scale, args.test_set)

    print('Contructing dataset...')
    test_data = DatasetFactory().create_dataset(args.model, data_dir)

    net = ModelFactory().create_model(args.model)

    # One directory per (model name, scale) pair.
    scale_tag = str(args.scale) + 'x'
    ckpt_dir = os.path.join('check_point', net.name, scale_tag)
    evaluator = Solver(net, ckpt_dir)

    print('Testing...')
    metrics, predictions = evaluator.test(test_data)
    export(args.scale, net.name, metrics, predictions)