def train_wrapper(model):
    """Train ``model`` for ``args.max_iterations`` iterations from the dataset provider.

    Optionally restores weights from ``args.pretrained_model`` first. A
    checkpoint is written every ``args.snapshot_interval`` iterations and
    ``trainer.test`` runs on the validation handle every
    ``args.test_interval`` iterations.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    # Build the training and validation input handles.
    train_handle, valid_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    eta = args.sampling_start_value
    for step in range(1, args.max_iterations + 1):
        print("Iter number:", step)

        # Restart (and reshuffle) the training set once it is exhausted.
        if train_handle.no_batch_left():
            train_handle.begin(do_shuffle=True)
        batch = preprocess.reshape_patch(train_handle.get_batch(), args.patch_size)

        eta, real_input_flag = schedule_sampling(eta, step)
        trainer.train(model, batch, real_input_flag, args, step)

        if not step % args.snapshot_interval:
            model.save(step)
        if not step % args.test_interval:
            trainer.test(model, valid_handle, args, step)

        train_handle.next()
def wrapper_test(model):
    """Evaluate ``model`` on the CIKM test split and save its predictions as PNGs.

    Batches are drawn with ``sample(batch_size, data_type='test', index=...)``
    and normalised with ``nor``. The model output is un-patched and un-padded,
    then compared against the last ``output_length`` frames of each sequence.
    Predictions are de-normalised with ``de_nor`` and written to
    ``<args.gen_frm_dir>sample_<k>/img_<t>.png``.

    Returns:
        The mean over batches of the per-batch MSE, computed on normalised values.
    """
    test_save_root = args.gen_frm_dir
    clean_fold(test_save_root)

    output_length = args.total_length - args.input_length
    real_input_flag = np.zeros(
        (args.batch_size,
         args.total_length - args.input_length - 1,
         args.img_width // args.patch_size,
         args.img_width // args.patch_size,
         args.patch_size ** 2 * args.img_channel))

    loss = 0
    count = 0
    index = 1
    while True:
        dat, (index, b_cup) = sample(batch_size, data_type='test', index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_CIKM_data(dat)

        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])

        # MSE is measured on normalised values, before de-normalising for saving.
        loss = loss + np.mean(np.square(tars - img_out))
        count = count + 1

        img_out = de_nor(img_out)
        # NOTE(review): file names 6..15 and img_out[:, 0..9] are hard-coded,
        # so this assumes output_length >= 10 — confirm against the CIKM config.
        for bat_ind, ind in enumerate(range(index - batch_size, index)):
            save_fold = test_save_root + 'sample_' + str(ind) + '/'
            clean_fold(save_fold)
            for t in range(6, 16):
                imsave(save_fold + 'img_' + str(t) + '.png', img_out[bat_ind, t - 6, :, :, 0])

        # Stop once sample() reports b_cup other than batch_size - 1
        # (presumably the final, short batch — TODO confirm against sample()).
        if b_cup != args.batch_size - 1:
            break

    return loss / count
def wrapper_train(model, patience=3):
    """Train on sampled CIKM batches with validation-based early stopping.

    Every ``args.test_interval`` iterations the model is validated with
    ``wrapper_valid``. An improvement saves a checkpoint via ``model.save()``.
    After ``patience`` consecutive validations without improvement, the
    checkpoint is reloaded with ``model.load()``, evaluated with
    ``wrapper_test``, and training stops.

    Args:
        model: model wrapper exposing ``load``/``save`` and used by ``trainer.train``.
        patience: validations without improvement tolerated before stopping.
            Defaults to 3, the value that was previously hard-coded.

    NOTE(review): if ``args.max_iterations`` is reached before early stopping
    triggers, no final test pass is run.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    eta = args.sampling_start_value
    best_mse = math.inf
    tolerate = 0
    for itr in range(1, args.max_iterations + 1):
        ims = sample(batch_size=batch_size)
        ims = padding_CIKM_data(ims)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        ims = nor(ims)
        eta, real_input_flag = schedule_sampling(eta, itr)
        cost = trainer.train(model, ims, real_input_flag, args, itr)
        if itr % args.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % args.test_interval == 0:
            print('validation one ')
            valid_mse = wrapper_valid(model)
            print('validation mse is:', str(valid_mse))

            if valid_mse < best_mse:
                best_mse = valid_mse
                tolerate = 0
                model.save()
            else:
                tolerate = tolerate + 1

            if tolerate >= patience:
                # Restore the best checkpoint before the final test evaluation.
                model.load()
                test_mse = wrapper_test(model)
                print('the best valid mse is:', str(best_mse))
                print('the test mse is ', str(test_mse))
                break
def train_wrapper(model):
    """Train ``model`` and keep separate "best" checkpoints for SSIM, PSNR and val loss.

    Every iteration either writes a numbered snapshot (every
    ``args.snapshot_interval`` iterations) or overwrites the ``"latest"``
    checkpoint. Every ``args.test_interval`` iterations ``trainer.test`` is run.
    Each metric that improved gets its own checkpoint: ``"bestssim"``,
    ``"bestpsnr"`` and ``"bestvalloss"``.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size,
        args.img_width,
        seq_length=args.total_length,
        is_training=True,
    )

    eta = args.sampling_start_value
    best_valLoss = float("inf")
    best_ssim = -1
    best_psnr = -1

    for itr in range(1, args.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, args.patch_size)

        if args.reverse_scheduled_sampling == 1:
            real_input_flag = reserve_schedule_sampling_exp(itr)
        else:
            eta, real_input_flag = schedule_sampling(eta, itr)

        trainer.train(model, ims, real_input_flag, args, itr)

        if itr % args.snapshot_interval == 0:
            model.save(itr)
        else:
            # NOTE(review): this writes a checkpoint on every non-snapshot
            # iteration, which may be I/O heavy — kept as-is.
            model.save("latest")

        if itr % args.test_interval == 0:
            # NOTE(review): assumes trainer.test returns scalar
            # (val_loss, ssim, psnr) — confirm against the trainer module.
            val_loss, ssim, psnr = trainer.test(model, test_input_handle, args, itr)
            # Track each metric independently. An if/elif chain would skip the
            # PSNR and val-loss updates whenever SSIM also improved.
            if best_ssim < ssim:
                best_ssim = ssim
                model.save("bestssim")
                print("Best SSIM found: {}".format(best_ssim))
            if best_psnr < psnr:
                best_psnr = psnr
                model.save("bestpsnr")
                print("Best PSNR found: {}".format(best_psnr))
            if best_valLoss > val_loss:
                best_valLoss = val_loss
                model.save("bestvalloss")
                print("Best ValLossMSE found: {}".format(best_valLoss))

        train_input_handle.next()
def wrapper_train(model, patience=2):
    """Train on sampled CIKM batches with validation-based early stopping.

    Every ``args.test_interval`` iterations the model is validated with
    ``wrapper_valid``. An improvement saves a checkpoint via ``model.save()``.
    After ``patience`` consecutive validations without improvement, the
    checkpoint is reloaded with ``model.load()``, evaluated with
    ``wrapper_test``, and training stops.

    Args:
        model: model wrapper exposing ``load``/``save`` and used by ``trainer.train``.
        patience: validations without improvement tolerated before stopping.
            Defaults to 2, the value that was previously hard-coded.

    NOTE(review): if ``args.max_iterations`` is reached before early stopping
    triggers, no final test pass is run.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    eta = args.sampling_start_value
    best_mse = math.inf
    tolerate = 0
    for itr in range(1, args.max_iterations + 1):
        ims = sample(batch_size=batch_size)
        ims = padding_CIKM_data(ims)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        ims = nor(ims)
        eta, real_input_flag = schedule_sampling(eta, itr)
        cost = trainer.train(model, ims, real_input_flag, args, itr)
        if itr % args.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % args.test_interval == 0:
            print('validation one ')
            valid_mse = wrapper_valid(model)
            print('validation mse is:', str(valid_mse))

            if valid_mse < best_mse:
                best_mse = valid_mse
                tolerate = 0
                model.save()
            else:
                tolerate = tolerate + 1

            if tolerate >= patience:
                # Restore the best checkpoint before the final test evaluation.
                model.load()
                test_mse = wrapper_test(model)
                print('the best valid mse is:', str(best_mse))
                print('the test mse is ', str(test_mse))
                break
def wrapper_valid(model):
    """Return the mean per-batch MSE of ``model`` on the CIKM validation split.

    Batches are drawn with ``sample(batch_size, data_type='validation', index=...)``
    and normalised with ``nor``. The model output is un-patched and un-padded,
    then compared against the last ``output_length`` frames of each sequence.
    """
    output_length = args.total_length - args.input_length
    real_input_flag = np.zeros(
        (args.batch_size,
         args.total_length - args.input_length - 1,
         args.img_width // args.patch_size,
         args.img_width // args.patch_size,
         args.patch_size ** 2 * args.img_channel))

    loss = 0
    count = 0
    index = 1
    while True:
        dat, (index, b_cup) = sample(batch_size, data_type='validation', index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_CIKM_data(dat)

        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])

        loss = loss + np.mean(np.square(tars - img_out))
        count = count + 1

        # Stop once sample() reports b_cup other than batch_size - 1
        # (presumably the final, short batch — TODO confirm against sample()).
        if b_cup != args.batch_size - 1:
            break

    return loss / count
def cloud_cast_wrapper(model):
    """Train ``model`` on the CloudCast training split for ``args.epochs`` epochs.

    After each epoch the summed training loss is printed, a checkpoint is
    saved with ``model.save(epoch)``, and ``test(model, args, epoch)`` is run.
    """
    from data.cloudcast import CloudCast
    import torch
    from tqdm import tqdm

    trainFolder = CloudCast(
        is_train=True,
        root="data/",
        n_frames_input=20,
        n_frames_output=1,
        batchsize=8,
    )
    # NOTE(review): batch size 8 is hard-coded here and in CloudCast, while
    # the sampling masks elsewhere are sized from args.batch_size — confirm they match.
    trainLoader = torch.utils.data.DataLoader(
        trainFolder, batch_size=8, num_workers=args.number_of_workers, shuffle=False)

    eta = args.sampling_start_value
    # Global iteration counter. The sampling schedule must keep advancing
    # across epochs instead of restarting with each epoch's batch index.
    itr = 0
    for epoch in range(0, int(args.epochs)):
        train_loss = 0
        # Wrap the loader anew each epoch: an exhausted tqdm bar is closed and
        # would not display progress on a second pass.
        for _, _, inputVar, _, _ in tqdm(trainLoader, leave=False):
            itr += 1
            # Swap axes 2 and 4 as before, then hand numpy to the numpy-based
            # patch preprocessing (train_wrapper also feeds it numpy batches).
            # This replaces the invalid torch.device("gpu:0") transfer.
            inputs = torch.swapaxes(inputVar, 2, 4).cpu().numpy()
            if args.reverse_scheduled_sampling == 1:
                real_input_flag = reserve_schedule_sampling_exp(itr)
            else:
                # Previously real_input_flag was undefined on this path.
                eta, real_input_flag = schedule_sampling(eta, itr)
            ims = preprocess.reshape_patch(inputs, args.patch_size)
            loss = model.train(ims, real_input_flag)
            train_loss += loss.item()
        print(train_loss)
        # TODO: log train_loss to comet once it is wired up.
        # Save and validate at the end of every epoch.
        model.save(epoch)
        test(model, args, epoch)
def test(model, test_input_handle, configs, itr):
    """Evaluate ``model`` on every batch of ``test_input_handle`` and print metrics.

    Prints per-sequence MSE, per-frame MSE and per-frame SSIM for the
    ``configs.total_length - configs.input_length`` predicted frames. For the
    first ``configs.num_save_samples`` batches, ground-truth (``gt<k>.png``)
    and predicted (``pd<k>.png``) images are saved under
    ``<configs.gen_frm_dir>/<itr>/<batch_id>/``.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)

    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length

    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - configs.input_length - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))

    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        # The last output_length generated frames line up with ground-truth
        # frames input_length .. total_length - 1 (0-based).
        img_out = img_gen[:, -output_length:]

        # MSE and SSIM per predicted frame, with predictions clipped to [0, 1].
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.minimum(np.maximum(img_out[:, i, :, :, :], 0), 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True, multichannel=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(os.path.join(path, name), img_gt)
            # Using the same tail alignment as img_out, generated frame j
            # corresponds to 1-based ground-truth frame
            # total_length - img_gen_length + j + 1. Numbering from
            # input_length + 1 mislabelled the files whenever the model also
            # returns frames for the input span.
            first_pd = configs.total_length - img_gen_length + 1
            for i in range(img_gen_length):
                name = 'pd' + str(first_pd + i) + '.png'
                img_pd = np.minimum(np.maximum(img_gen[0, i, :, :, :], 0), 1)
                cv2.imwrite(os.path.join(path, name), np.uint8(img_pd * 255))
        test_input_handle.next()

    n_seqs = batch_id * configs.batch_size
    avg_mse = avg_mse / n_seqs
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / n_seqs)

    ssim = np.asarray(ssim, dtype=np.float32) / n_seqs
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
def test(model, test_input_handle, configs, itr):
    """Evaluate ``model`` on ``test_input_handle`` and report MSE, SSIM, PSNR and LPIPS.

    Runs every batch of the handle, starting from ``begin(do_shuffle=False)``.
    Prints per-sequence MSE and per-frame MSE / SSIM / PSNR / LPIPS for the
    ``configs.total_length - configs.input_length`` predicted frames. For the
    first ``configs.num_save_samples`` batches, ground-truth (``gt<k>.png``)
    and predicted (``pd<k>.png``) images are saved under
    ``<configs.gen_frm_dir>/<itr>/<batch_id>/``.

    Returns:
        ``(avg_mse, mean_ssim, mean_psnr)``: the printed per-sequence MSE and
        the frame-averaged SSIM and PSNR. Callers that ignore the return
        value are unaffected.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)

    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    lp = [0] * output_length

    # The mask length depends on reverse scheduled sampling. In that mode the
    # first input_length - 1 mask steps are set to 1.0 (real input).
    if configs.reverse_scheduled_sampling == 1:
        mask_input = 1
    else:
        mask_input = configs.input_length
    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - mask_input - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))
    if configs.reverse_scheduled_sampling == 1:
        real_input_flag[:, :configs.input_length - 1, :, :] = 1.0

    def _lpips_batch(frames):
        """Pack (batch, H, W, C) frames into a (batch_size, 3, W, W) FloatTensor."""
        packed = np.zeros(
            [configs.batch_size, 3, configs.img_width, configs.img_width])
        # 3-channel frames keep their channels; otherwise channel 0 is replicated.
        sources = (0, 1, 2) if configs.img_channel == 3 else (0, 0, 0)
        for dst, src in enumerate(sources):
            packed[:, dst, :, :] = frames[:, :, :, src]
        return torch.FloatTensor(packed)

    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        # The last output_length generated frames line up with ground-truth
        # frames input_length .. total_length - 1 (0-based).
        img_out = img_gen[:, -output_length:]

        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.minimum(np.maximum(img_out[:, i, :, :, :], 0), 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            # NOTE(review): LPIPS documents inputs in [-1, 1], but gx is clipped
            # to [0, 1] above. Left unchanged so scores stay comparable — confirm.
            lp_loss = loss_fn_alex(_lpips_batch(x), _lpips_batch(gx))
            lp[i] += torch.mean(lp_loss).item()

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True, multichannel=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(os.path.join(path, name), img_gt)
            # Using the same tail alignment as img_out, generated frame j
            # corresponds to 1-based ground-truth frame
            # total_length - img_gen_length + j + 1. Numbering from
            # input_length + 1 mislabelled the files whenever the model also
            # returns frames for the input span.
            first_pd = configs.total_length - img_gen_length + 1
            for i in range(img_gen_length):
                name = 'pd' + str(first_pd + i) + '.png'
                img_pd = np.minimum(np.maximum(img_gen[0, i, :, :, :], 0), 1)
                cv2.imwrite(os.path.join(path, name), np.uint8(img_pd * 255))
        test_input_handle.next()

    n_seqs = batch_id * configs.batch_size
    avg_mse = avg_mse / n_seqs
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / n_seqs)

    ssim = np.asarray(ssim, dtype=np.float32) / n_seqs
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])

    # PSNR and LPIPS are accumulated once per batch, so average over batches.
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])

    lp = np.asarray(lp, dtype=np.float32) / batch_id
    print('lpips per frame: ' + str(np.mean(lp)))
    for i in range(output_length):
        print(lp[i])

    return avg_mse, float(np.mean(ssim)), float(np.mean(psnr))
def test(model, configs, itr):
    """Evaluate ``model`` on the CloudCast test split and report MSE, SSIM, PSNR and LPIPS.

    Prints per-sequence MSE and per-frame MSE / SSIM / PSNR / LPIPS for the
    ``configs.total_length - configs.input_length`` predicted frames. For the
    first ``configs.num_save_samples`` batches, ground-truth (``gt<k>.png``)
    and predicted (``pd<k>.png``) images are saved under
    ``<configs.gen_frm_dir>/<itr>/<batch_id>/``.

    Returns:
        ``(avg_mse, mean_ssim, mean_psnr)``: the printed per-sequence MSE and
        the frame-averaged SSIM and PSNR. Callers that ignore the return
        value are unaffected.
    """
    from data.cloudcast import CloudCast
    import torch
    import lpips
    from skimage.metrics import structural_similarity
    from core.utils import preprocess, metrics
    import cv2
    from tqdm import tqdm

    loss_fn_alex = lpips.LPIPS(net='alex')
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)

    testFolder = CloudCast(
        is_train=False,
        root="data/",
        n_frames_input=20,
        n_frames_output=1,
        batchsize=8,
    )
    # TODO: number of workers will need to be changed.
    # NOTE(review): loader batch size 8 is hard-coded, while the mask, LPIPS
    # buffers and SSIM loop use configs.batch_size — confirm they match.
    testLoader = torch.utils.data.DataLoader(
        testFolder, batch_size=8, num_workers=configs.number_of_workers, shuffle=False)

    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    lp = [0] * output_length

    # The mask length depends on reverse scheduled sampling. In that mode the
    # first input_length - 1 mask steps are set to 1.0 (real input).
    if configs.reverse_scheduled_sampling == 1:
        mask_input = 1
    else:
        mask_input = configs.input_length
    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - mask_input - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))
    if configs.reverse_scheduled_sampling == 1:
        real_input_flag[:, :configs.input_length - 1, :, :] = 1.0

    def _lpips_batch(frames):
        """Pack (batch, H, W, C) frames into a (batch_size, 3, W, W) FloatTensor."""
        packed = np.zeros(
            [configs.batch_size, 3, configs.img_width, configs.img_width])
        # 3-channel frames keep their channels; otherwise channel 0 is replicated.
        sources = (0, 1, 2) if configs.img_channel == 3 else (0, 0, 0)
        for dst, src in enumerate(sources):
            packed[:, dst, :, :] = frames[:, :, :, src]
        return torch.FloatTensor(packed)

    # tqdm infers the total from len(testLoader); the old total=2 was wrong.
    for _, _, inputVar, _, _ in tqdm(testLoader, leave=False):
        batch_id += 1
        # All metric code below is numpy-based, so convert once on CPU. This
        # replaces the invalid torch.device("gpu:0") transfer.
        # NOTE(review): swapaxes(2, 4) also swaps the two spatial axes if the
        # loader yields (B, T, C, H, W) — confirm the intended layout.
        test_ims = torch.swapaxes(inputVar, 2, 4).cpu().numpy()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        # The last output_length generated frames line up with ground-truth
        # frames input_length .. total_length - 1 (0-based).
        img_out = img_gen[:, -output_length:]

        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.minimum(np.maximum(img_out[:, i, :, :, :], 0), 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            # NOTE(review): LPIPS documents inputs in [-1, 1], but gx is clipped
            # to [0, 1] above. Left unchanged so scores stay comparable — confirm.
            lp_loss = loss_fn_alex(_lpips_batch(x), _lpips_batch(gx))
            lp[i] += torch.mean(lp_loss).item()

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                # skimage.measure.compare_ssim is deprecated; use structural_similarity.
                # NOTE(review): newer scikit-image replaces multichannel= with channel_axis=.
                score, _ = structural_similarity(pred_frm[b], real_frm[b], full=True, multichannel=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(os.path.join(path, name), img_gt)
            # Using the same tail alignment as img_out, generated frame j
            # corresponds to 1-based ground-truth frame
            # total_length - img_gen_length + j + 1.
            first_pd = configs.total_length - img_gen_length + 1
            for i in range(img_gen_length):
                name = 'pd' + str(first_pd + i) + '.png'
                img_pd = np.minimum(np.maximum(img_gen[0, i, :, :, :], 0), 1)
                cv2.imwrite(os.path.join(path, name), np.uint8(img_pd * 255))

    n_seqs = batch_id * configs.batch_size
    avg_mse = avg_mse / n_seqs
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / n_seqs)

    ssim = np.asarray(ssim, dtype=np.float32) / n_seqs
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])

    # PSNR and LPIPS are accumulated once per batch, so average over batches.
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])

    lp = np.asarray(lp, dtype=np.float32) / batch_id
    print('lpips per frame: ' + str(np.mean(lp)))
    for i in range(output_length):
        print(lp[i])

    return avg_mse, float(np.mean(ssim)), float(np.mean(psnr))