def main():
    """Load the datasets, build the trainer named by ``args.m`` and validate it.

    ``args`` is not a parameter; it is resolved from the enclosing module
    scope. The chosen trainer is constructed with ``args`` plus the training
    and testing data loaders, after which its ``validate()`` method is called.

    Raises:
        Exception: if ``args.m`` matches none of the model keys handled below.
    """
    print('===> Loading datasets')

    # Both datasets are built for the same upscale factor.
    scale = args.upscale_factor
    train_set = get_training_set(scale)
    test_set = get_test_set(scale)

    # Only the training loader shuffles its samples.
    training_data_loader = DataLoader(
        dataset=train_set,
        batch_size=args.batchSize,
        shuffle=True,
    )
    testing_data_loader = DataLoader(
        dataset=test_set,
        batch_size=args.testBatchSize,
        shuffle=False,
    )

    # Pick the trainer class first; every trainer shares one constructor
    # signature, so the instantiation happens once after the chain.
    choice = args.m
    if choice == 'sub':
        trainer_cls = SubPixelTrainer
    elif choice == 'srcnn':
        trainer_cls = SRCNNTrainer
    elif choice == 'vdsr':
        trainer_cls = VDSRTrainer
    elif choice == 'edsr':
        trainer_cls = EDSRTrainer
    elif choice == 'fsrcnn':
        trainer_cls = FSRCNNTrainer
    elif choice == 'drcn':
        trainer_cls = DRCNTrainer
    elif choice == 'srgan':
        trainer_cls = SRGANTrainer
    else:
        raise Exception("the model does not exist")

    model = trainer_cls(args, training_data_loader, testing_data_loader)
    model.validate()
def main():
    """Build the Landsat8 HDF5 loaders, create the trainer named by ``args.model`` and run it.

    ``args`` is not a parameter; it is resolved from the enclosing module
    scope. Train, validation and test loaders are created from three CSV
    paths; the trainer receives ``args`` plus the train and validation
    loaders, and its ``run()`` method is then called.

    Raises:
        Exception: if ``args.model`` matches none of the model keys handled
            below.
    """
    train_csv = "../dataset/l8s2-train.csv"
    val_csv = "../dataset/l8s2-val.csv"
    test_csv = "../dataset/l8s2-test.csv"

    # ------------------------------------------------------------------
    # Transforms
    # ------------------------------------------------------------------
    def _as_float32(bands):
        # Cast the first 13 entries of the target to float32.
        return [bands[i].astype('float32') for i in range(13)]

    def _as_tensors(bands):
        # Convert the first 13 entries of the target with ToTensor.
        return [transforms.ToTensor()(bands[i]) for i in range(13)]

    input_transform = transforms.Compose([transforms.ToTensor()])
    target_transform = transforms.Compose([
        transforms.Lambda(_as_float32),
        transforms.Lambda(_as_tensors),
    ])

    # ------------------------------------------------------------------
    # Data loaders backed by HDF5
    # ------------------------------------------------------------------
    def _make_loader(csv_path, batch_size, shuffle):
        # One dataset + loader pair per CSV, all sharing the same transforms.
        dataset = Landsat8DatasetHDF5(
            csv_path,
            input_transform=input_transform,
            target_transform=target_transform,
        )
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    # Alternative kept for reference: sampler=LocalRandomSampler(train_set)
    # in place of shuffle=True for the training loader.
    train_data_loader = _make_loader(train_csv, args.batchSize, True)
    val_data_loader = _make_loader(val_csv, args.testBatchSize, False)
    # The test loader is still constructed, but no trainer below receives it.
    test_data_loader = _make_loader(test_csv, args.testBatchSize, False)

    # ------------------------------------------------------------------
    # Trainer selection: choose the class, then instantiate it once.
    # ------------------------------------------------------------------
    choice = args.model
    if choice == 'sub':
        trainer_cls = SubPixelTrainer
    elif choice == 'trans':
        trainer_cls = TransConvTrainer
    elif choice == 'submax':
        trainer_cls = SubPixelMaxPoolTrainer
    elif choice == 'transmax':
        trainer_cls = TransConvMaxPoolTrainer
    else:
        raise Exception("the model does not exist")

    model = trainer_cls(args, train_data_loader, val_data_loader)
    model.run()
def main():
    """Load the datasets, build the trainer named by ``args.model`` and run it.

    ``args``, ``allColors``, ``allLayers`` and ``predictColors`` are not
    parameters; they are resolved from the enclosing module scope. The
    chosen trainer is constructed with ``args`` plus the training and testing
    data loaders, after which its ``run()`` method is called.

    Raises:
        Exception: if ``args.model`` matches none of the model keys handled
            below.
    """
    print('===> Loading datasets')
    print("allColors is " + str(allColors))

    # Both dataset getters take the same three arguments: the data folder,
    # the upscale factor, and the combined colour/layer flag.
    dataset_args = (
        args.trainTestFolder,
        args.upscale_factor,
        allColors or allLayers or predictColors,
    )
    train_set = get_training_set(*dataset_args)
    test_set = get_test_set(*dataset_args)

    # Only the training loader shuffles its samples.
    training_data_loader = DataLoader(
        dataset=train_set,
        batch_size=args.batchSize,
        shuffle=True,
    )
    testing_data_loader = DataLoader(
        dataset=test_set,
        batch_size=args.testBatchSize,
        shuffle=False,
    )

    # Pick the trainer class first; every trainer shares one constructor
    # signature, so the instantiation happens once after the chain.
    choice = args.model
    if choice == 'sub':
        trainer_cls = SubPixelTrainer
    elif choice == 'srcnn':
        trainer_cls = SRCNNTrainer
    elif choice == 'vdsr':
        trainer_cls = VDSRTrainer
    elif choice == 'edsr':
        trainer_cls = EDSRTrainer
    elif choice == 'fsrcnn':
        trainer_cls = FSRCNNTrainer
    elif choice == 'drcn':
        trainer_cls = DRCNTrainer
    elif choice == 'srgan':
        trainer_cls = SRGANTrainer
    elif choice == 'dbpn':
        trainer_cls = DBPNTrainer
    else:
        raise Exception("the model does not exist")

    model = trainer_cls(args, training_data_loader, testing_data_loader)
    model.run()