def main():
    display_config()

    dataset_root = get_full_path(args.scale, args.train_set)

    print('Constructing dataset...')
    dataset_factory = DatasetFactory()
    train_dataset = dataset_factory.create_dataset(args.model, dataset_root)

    # Build the model and its matching loss function.
    model_factory = ModelFactory()
    model = model_factory.create_model(args.model)
    loss_fn = get_loss_fn(model.name)

    # Checkpoints are stored per model and per upscaling factor.
    check_point = os.path.join('check_point', model.name, str(args.scale) + 'x')
    solver = Solver(model, check_point,
                    loss_fn=loss_fn,
                    batch_size=args.batch_size,
                    num_epochs=args.num_epochs,
                    learning_rate=args.learning_rate,
                    fine_tune=args.fine_tune,
                    verbose=args.verbose)

    print('Training...')
    solver.train(train_dataset)
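
# The training entry point above reads its settings from a module-level `args`
# object that is defined elsewhere in the repository. The sketch below shows one
# way such an object could be produced with argparse; the flag names mirror the
# attributes used in main() (model, scale, train_set, batch_size, num_epochs,
# learning_rate, fine_tune, verbose), but the defaults, help strings, and the
# example model name are assumptions, not the project's actual CLI.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train a super-resolution model.')
    parser.add_argument('--model', type=str, default='VDSR',
                        help='model name understood by ModelFactory (assumed default)')
    parser.add_argument('--scale', type=int, default=2,
                        help='upscaling factor, e.g. 2, 3 or 4')
    parser.add_argument('--train_set', type=str, default='train',
                        help='training split name passed to get_full_path')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--fine_tune', action='store_true',
                        help='resume training from an existing checkpoint')
    parser.add_argument('--verbose', action='store_true')
    return parser.parse_args()

# Assumed invocation:
#   python train.py --model VDSR --scale 2 --batch_size 16 --num_epochs 100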
def main():
    display_config()
    use_gpu = True

    print('Constructing dataset...')
    dataset_root = os.path.join('preprocessed_data', args.test_set)
    dataset_factory = DatasetFactory()
    val_dataset = dataset_factory.create_dataset('VALID', dataset_root)

    print('Loading model...')
    model_path = os.path.join('check_point', args.model, str(args.scale) + 'x', 'model.pt')
    if not os.path.exists(model_path):
        raise Exception('Cannot find %s.' % model_path)
    model = torch.load(model_path)
    model = model.cuda() if use_gpu else model.cpu()

    # Inference only: gradients are not needed for the forward pass,
    # so freezing the parameters saves memory.
    for param in model.parameters():
        param.requires_grad = False

    dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=6)

    print('Testing...')
    total_time = 0
    cnt = 0
    for i, (input_batch, imgname) in enumerate(dataloader):
        if use_gpu:
            input_batch = Variable(input_batch.cuda(), requires_grad=False)
        else:
            input_batch = Variable(input_batch, requires_grad=False)

        start = time.time()
        output_batch, _ = model(input_batch)
        total_time += time.time() - start

        # Map the network output back to the [0, 255] pixel domain.
        output_batch = (output_batch.data + 0.5) * 255
        output_batch = output_batch.cpu().numpy()
        export(args.scale, model.name, output_batch, imgname[0])
        cnt += 1

    print('Average time: %f\n' % (total_time / cnt))
    print('..Finish..')
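
# The evaluation loop above hands each super-resolved batch to `export`, which
# is defined elsewhere in the repository. The sketch below is a hypothetical
# reimplementation, assuming it converts the (1, C, H, W) float array (already
# in the pixel domain) to a uint8 image and writes it under a per-model,
# per-scale results directory; the directory layout and the use of Pillow are
# assumptions, not the project's actual helper.
import os
import numpy as np
from PIL import Image

def export(scale, model_name, output_batch, imgname):
    out_dir = os.path.join('results', model_name, str(scale) + 'x')  # assumed layout
    os.makedirs(out_dir, exist_ok=True)
    img = np.clip(output_batch[0], 0, 255).astype(np.uint8)
    img = np.transpose(img, (1, 2, 0))   # CHW -> HWC
    if img.shape[2] == 1:
        img = img[:, :, 0]               # single-channel output (e.g. Y channel only)
    Image.fromarray(img).save(os.path.join(out_dir, imgname))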
def main():
    display_config()

    dataset_root = get_full_path(args.scale, args.test_set)

    print('Constructing dataset...')
    dataset_factory = DatasetFactory()
    test_dataset = dataset_factory.create_dataset(args.model, dataset_root)

    model_factory = ModelFactory()
    model = model_factory.create_model(args.model)

    # Restore the trained weights from the per-model, per-scale checkpoint.
    check_point = os.path.join('check_point', model.name, str(args.scale) + 'x')
    solver = Solver(model, check_point)

    print('Testing...')
    stats, outputs = solver.test(test_dataset)
    export(args.scale, model.name, stats, outputs)
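
# Both the training and testing entry points resolve their dataset directory
# through get_full_path(scale, set_name), which is imported from elsewhere.
# A minimal sketch is given below, assuming the preprocessed data is organised
# per split and per scale under the same 'preprocessed_data' root used by the
# evaluation script above; the exact subdirectory scheme is an assumption.
import os

def get_full_path(scale, set_name):
    # e.g. preprocessed_data/train/2x  (assumed layout)
    return os.path.join('preprocessed_data', set_name, str(scale) + 'x')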