def main():
    """Train the configured model on the training split.

    Reads the module-level ``args`` namespace (``scale``, ``train_set``,
    ``model``, ``batch_size``, ``num_epochs``, ``learning_rate``,
    ``fine_tune``, ``verbose``), builds the dataset and the network through
    their factories, and hands both to a ``Solver`` whose checkpoint
    directory is ``check_point/<model name>/<scale>x``.
    """
    display_config()

    data_dir = get_full_path(args.scale, args.train_set)

    print('Contructing dataset...')
    dataset = DatasetFactory().create_dataset(args.model, data_dir)

    net = ModelFactory().create_model(args.model)
    criterion = get_loss_fn(net.name)

    # One checkpoint directory per (model, upscale factor) pair.
    ckpt_dir = os.path.join('check_point', net.name, '%sx' % (args.scale,))

    solver_opts = dict(
        loss_fn=criterion,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        fine_tune=args.fine_tune,
        verbose=args.verbose,
    )
    trainer = Solver(net, ckpt_dir, **solver_opts)

    print('Training...')
    trainer.train(dataset)
def main():
    """Evaluate the configured model on the test split and export results.

    Reads the module-level ``args`` namespace (``scale``, ``test_set``,
    ``model``), restores a ``Solver`` pointed at
    ``check_point/<model name>/<scale>x``, runs ``Solver.test`` and passes
    the returned statistics and outputs to ``export``.
    """
    display_config()

    test_root = get_full_path(args.scale, args.test_set)

    print('Contructing dataset...')
    test_dataset = DatasetFactory().create_dataset(args.model, test_root)

    net = ModelFactory().create_model(args.model)

    ckpt_dir = os.path.join('check_point', net.name, '%sx' % (args.scale,))
    evaluator = Solver(net, ckpt_dir)

    print('Testing...')
    results = evaluator.test(test_dataset)
    # Solver.test is expected to yield exactly (stats, outputs).
    stats, outputs = results
    export(args.scale, net.name, stats, outputs)
Example #3
0
def main():
    """Train a VSR model on ``args.train_set`` and validate on ``args.test_set``.

    Reads the module-level ``args`` namespace (``gpu``, ``train_set``,
    ``test_set``, ``scale``, ``model``, ``checkpoint``, ``batch_size``,
    ``num_epochs``, ``learning_rate``, ``fine_tune``, ``verbose``).

    Side effects:
        * sets ``CUDA_VISIBLE_DEVICES`` to ``args.gpu``;
        * prints the trainable parameter count in millions;
        * creates ``<args.checkpoint>/<model name>/<scale>x`` if missing;
        * runs ``Solver.train`` with the training and validation datasets.
    """
    display_config()
    print('Contructing dataset...')
    # Restrict visible GPUs before the model is built. This only takes effect
    # if no CUDA context was created earlier in the process — TODO confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Training samples are randomly cropped (size 48 at the given scale) and
    # augmented; the validation set below is only converted to tensors.
    train_dataset = VSR_Dataset(dir=args.train_set,
                                trans=transforms.Compose([
                                    RandomCrop(48, args.scale),
                                    DataAug(),
                                    ToTensor()
                                ]))
    model_factory = ModelFactory()
    model = model_factory.create_model(args.model)

    # Number of trainable parameters, reported in millions.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(1.0 * params / (1000 * 1000))

    loss_fn = get_loss_fn(model.name)

    check_point = os.path.join(args.checkpoint, model.name,
                               str(args.scale) + 'x')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(check_point, exist_ok=True)

    solver = Solver(model,
                    check_point,
                    model.name,
                    loss_fn=loss_fn,
                    batch_size=args.batch_size,
                    num_epochs=args.num_epochs,
                    learning_rate=args.learning_rate,
                    fine_tune=args.fine_tune,
                    verbose=args.verbose)

    print('Training...')
    val_dataset = VSR_Dataset(dir=args.test_set,
                              trans=transforms.Compose([ToTensor()]))
    solver.train(train_dataset, val_dataset)
Example #4
0
# Command-line options for running a trained model over a test directory.
parser = argparse.ArgumentParser(description=description)

parser.add_argument('-m', '--model', metavar='M', type=str, default='TDAN',
                    help='network architecture.')
parser.add_argument('-s', '--scale', metavar='S', type=int, default=4,
                    help='interpolation scale. Default 4')
parser.add_argument('-t', '--test-set', metavar='NAME', type=str, default='/home/cxu-serve/u1/ytian21/project/video_retoration/TDAN-VSR/data/Vid4',
                    help='dataset for testing.')
parser.add_argument('-mp', '--model-path', metavar='MP', type=str, default='model',
                    help='model path.')
parser.add_argument('-sp', '--save-path', metavar='SP', type=str, default='res',
                    help='saving directory path.')
args = parser.parse_args()

# NOTE(review): this factory-built model is replaced by the torch.load()
# result below. It is kept only in case create_model() has side effects the
# checkpoint relies on — confirm before removing.
model_factory = ModelFactory()
model = model_factory.create_model(args.model)

# Each entry of the test directory is processed as one sequence by the loop
# that follows.
dir_LR = args.test_set
lis = sorted(os.listdir(dir_LR))

model_path = os.path.join(args.model_path, 'model.pt')
if not os.path.exists(model_path):
    # FileNotFoundError subclasses Exception, so existing handlers still match.
    raise FileNotFoundError('Cannot find %s.' % model_path)
# SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): this loads a fully pickled module; recent PyTorch releases
# default to weights_only=True, which may reject such checkpoints — verify.
model = torch.load(model_path)
model.eval()

path = args.save_path
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(path, exist_ok=True)

for i in range(len(lis)):
    print(lis[i])
    LR = os.path.join(dir_LR, lis[i], 'LR_bicubic')
    ims = sorted(os.listdir(LR))