예제 #1
0
def main():
    """Build the train/test data loaders and run the trainer chosen by ``args.model``.

    Reads the module-level ``args`` namespace (``upscale_factor``,
    ``batchSize``, ``testBatchSize``, ``model``), constructs the selected
    trainer with ``(args, training_data_loader, testing_data_loader)`` and
    calls its ``run()`` method.

    Raises:
        ValueError: if ``args.model`` is not a supported model key. The check
            is done before any dataset is loaded so a typo fails immediately
            instead of after the (potentially slow) data loading.
    """
    # Map each model key to its trainer class; every trainer is constructed
    # with (args, training_data_loader, testing_data_loader).
    trainers = {
        'sub': SubPixelTrainer,
        'srcnn': SRCNNTrainer,
        'vdsr': VDSRTrainer,
        'edsr': EDSRTrainer,
        'fsrcnn': FSRCNNTrainer,
        'drcn': DRCNTrainer,
        'srgan': SRGANTrainer,
        'dbpn': DBPNTrainer,
    }
    trainer_cls = trainers.get(args.model)
    if trainer_cls is None:
        raise ValueError(
            "the model does not exist: {!r} (choose from: {})".format(
                args.model, ', '.join(sorted(trainers))))

    # ===========================================================
    # Set train dataset & test dataset
    # ===========================================================
    print('===> Loading datasets')
    train_set = get_training_set(args.upscale_factor)
    test_set = get_test_set(args.upscale_factor)
    training_data_loader = DataLoader(dataset=train_set,
                                      batch_size=args.batchSize,
                                      shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set,
                                     batch_size=args.testBatchSize,
                                     shuffle=False)

    model = trainer_cls(args, training_data_loader, testing_data_loader)
    model.run()
예제 #2
0
def main():
    """Build the train/test data loaders and run the trainer chosen by ``args.model``.

    Reads the module-level ``args`` namespace (``patch_size``,
    ``upscale_factor``, ``bs``, ``model``). The test loader always uses a
    batch size of 1 and no shuffling. The selected trainer is constructed
    with ``(args, train_loader, test_loader)`` and its ``run()`` method is
    called.

    Raises:
        ValueError: if ``args.model`` is not a supported model key. The check
            is done before any dataset is loaded so a typo fails immediately.
    """
    # Map each model key to its trainer class; every trainer is constructed
    # with (args, train_loader, test_loader).
    trainers = {
        'srcnn': SRCNNTrainer,
        'srdcn': SRDCNTrainer,
    }
    trainer_cls = trainers.get(args.model)
    if trainer_cls is None:
        raise ValueError(
            "the model does not exist: {!r} (choose from: {})".format(
                args.model, ', '.join(sorted(trainers))))

    print('===> Loading datasets')
    train_set = get_dataset(args.patch_size,
                            args.upscale_factor,
                            phase='train')
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.bs,
                              shuffle=True)

    # The patch size argument is a placeholder for the test phase
    # (per the original author: "0 has no meaning").
    test_set = get_dataset(0, args.upscale_factor,
                           phase='test')
    test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

    model = trainer_cls(args, train_loader, test_loader)
    model.run()