def train_wrapper(model):
    """Run the basic training loop for ``model``.

    If ``args.pretrained_model`` is set, the model is first restored from it.
    The loop then runs for ``args.max_iterations`` steps; each step fetches a
    batch, reshapes it into patches, refreshes the scheduled-sampling state
    and performs one ``trainer.train`` call.  A numbered snapshot is written
    every ``args.snapshot_interval`` steps and ``trainer.test`` is invoked
    every ``args.test_interval`` steps.

    NOTE(review): a second ``train_wrapper`` is defined further down in this
    file; at import time the later definition rebinds the name.

    Args:
        model: Object exposing ``load`` and ``save``; it is also handed to
            ``trainer.train`` / ``trainer.test``.
    """
    checkpoint = args.pretrained_model
    if checkpoint:
        model.load(checkpoint)

    # Build the input handles used for training and periodic evaluation.
    train_data, eval_data = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    sampling_eta = args.sampling_start_value

    for step in range(1, args.max_iterations + 1):
        print("Iter number:", step)

        # Restart the handle (with shuffling) once it reports no batch left.
        if train_data.no_batch_left():
            train_data.begin(do_shuffle=True)

        batch = train_data.get_batch()
        patches = preprocess.reshape_patch(batch, args.patch_size)

        sampling_eta, input_mask = schedule_sampling(sampling_eta, step)
        trainer.train(model, patches, input_mask, args, step)

        if not step % args.snapshot_interval:
            model.save(step)

        if not step % args.test_interval:
            trainer.test(model, eval_data, args, step)

        train_data.next()
def test_wrapper(model):
    """Evaluate a pretrained ``model`` on the test data.

    Loads weights from ``args.pretrained_model`` unconditionally, builds the
    input handle with ``is_training=False`` and passes it to
    ``trainer.test`` under the tag ``'test_result'``.

    Args:
        model: Object exposing ``load``; it is also handed to
            ``trainer.test``.
    """
    model.load(args.pretrained_model)

    # With is_training=False the provider's result is used as a single handle.
    eval_data = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=False,
    )

    trainer.test(model, eval_data, args, 'test_result')
def train_wrapper(model):
    """Train ``model`` and checkpoint the best weights per validation metric.

    If ``args.pretrained_model`` is set, the model is first restored from it.
    The loop runs for ``args.max_iterations`` steps.  Each step fetches a
    batch, reshapes it into patches, computes the real-input mask (reverse
    scheduled sampling when ``args.reverse_scheduled_sampling == 1``,
    otherwise regular scheduled sampling) and performs one ``trainer.train``
    call.

    Checkpointing:
        * every ``args.snapshot_interval`` steps a numbered snapshot is saved;
          on all other steps the ``"latest"`` checkpoint is overwritten;
        * every ``args.test_interval`` steps ``trainer.test`` is run and its
          ``(val_loss, ssim, psnr)`` result is compared against the running
          bests.  Each metric is tracked independently, so one evaluation
          may refresh any combination of ``"bestssim"``, ``"bestpsnr"`` and
          ``"bestvalloss"``.

    Args:
        model: Object exposing ``load`` and ``save``; it is also handed to
            ``trainer.train`` / ``trainer.test``.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    # Build the input handles used for training and periodic evaluation.
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    eta = args.sampling_start_value

    # Infinite sentinels guarantee the first evaluation always sets a best,
    # whatever the scale of the metrics.
    best_val_loss = float("inf")
    best_ssim = float("-inf")
    best_psnr = float("-inf")

    for itr in range(1, args.max_iterations + 1):
        # Restart the handle (with shuffling) once it reports no batch left.
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)

        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, args.patch_size)

        if args.reverse_scheduled_sampling == 1:
            real_input_flag = reserve_schedule_sampling_exp(itr)
        else:
            eta, real_input_flag = schedule_sampling(eta, itr)

        trainer.train(model, ims, real_input_flag, args, itr)

        if itr % args.snapshot_interval == 0:
            model.save(itr)
        else:
            model.save("latest")

        if itr % args.test_interval == 0:
            val_loss, ssim, psnr = trainer.test(model, test_input_handle, args, itr)

            # Independent checks (not an elif chain): an SSIM improvement
            # must not hide a simultaneous PSNR or loss improvement, or the
            # running bests for those metrics would go stale.
            if ssim > best_ssim:
                best_ssim = ssim
                model.save("bestssim")
                print("Best SSIM found: {}".format(best_ssim))
            if psnr > best_psnr:
                best_psnr = psnr
                model.save("bestpsnr")
                print("Best PSNR found: {}".format(best_psnr))
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model.save("bestvalloss")
                print("Best ValLossMSE found: {}".format(best_val_loss))

        train_input_handle.next()