def wrapper_test(model):
    """Run the model over the CIKM test split and report mean per-batch MSE.

    De-normalized predicted frames (time steps 6..15) are written as PNGs
    under ``args.gen_frm_dir/sample_<idx>/``.

    Args:
        model: wrapper exposing ``test(ims, real_input_flag) -> (img_gen, _)``.

    Returns:
        float: average MSE (on normalized data) over all processed batches.
    """
    test_save_root = args.gen_frm_dir
    clean_fold(test_save_root)
    loss = 0
    count = 0
    index = 1
    flag = True
    # No ground-truth frames are fed back during prediction: all-zero mask.
    real_input_flag = np.zeros(
        (args.batch_size,
         args.total_length - args.input_length - 1,
         args.img_width // args.patch_size,
         args.img_width // args.patch_size,
         args.patch_size ** 2 * args.img_channel))
    output_length = args.total_length - args.input_length
    while flag:
        # NOTE(review): `batch_size` is a module-level global here while the
        # mask above uses `args.batch_size` — confirm they always agree.
        dat, (index, b_cup) = sample(batch_size, data_type='test', index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_CIKM_data(dat)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
        # MSE is computed on normalized data, before de-normalization.
        mse = np.mean(np.square(tars - img_out))
        img_out = de_nor(img_out)
        loss = loss + mse
        count = count + 1
        bat_ind = 0
        for ind in range(index - batch_size, index, 1):
            save_fold = test_save_root + 'sample_' + str(ind) + '/'
            clean_fold(save_fold)
            # Predicted frames correspond to time steps 6..15 of the sequence.
            for t in range(6, 16, 1):
                imsave(save_fold + 'img_' + str(t) + '.png',
                       img_out[bat_ind, t - 6, :, :, 0])
            bat_ind = bat_ind + 1
        # `b_cup` signals whether another full batch remains; stop otherwise.
        if b_cup != args.batch_size - 1:
            flag = False
    return loss / count
def wrapper_valid(model):
    """Run the model over the CIKM validation split and return the mean MSE.

    Unlike ``wrapper_test`` this saves no images; it only accumulates the
    per-batch MSE on normalized data.

    Args:
        model: wrapper exposing ``test(ims, real_input_flag) -> (img_gen, _)``.

    Returns:
        float: average MSE over all processed validation batches.
    """
    loss = 0
    count = 0
    index = 1
    flag = True
    # No ground-truth frames are fed back during prediction: all-zero mask.
    real_input_flag = np.zeros(
        (args.batch_size,
         args.total_length - args.input_length - 1,
         args.img_width // args.patch_size,
         args.img_width // args.patch_size,
         args.patch_size ** 2 * args.img_channel))
    output_length = args.total_length - args.input_length
    while flag:
        # NOTE(review): `batch_size` is a module-level global here while the
        # mask above uses `args.batch_size` — confirm they always agree.
        dat, (index, b_cup) = sample(batch_size, data_type='validation',
                                     index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_CIKM_data(dat)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
        mse = np.mean(np.square(tars - img_out))
        loss = loss + mse
        count = count + 1
        # `b_cup` signals whether another full batch remains; stop otherwise.
        if b_cup != args.batch_size - 1:
            flag = False
    return loss / count
def test(model, test_input_handle, configs, itr):
    """Evaluate the model on the test set, reporting per-frame MSE and SSIM.

    The first ``configs.num_save_samples`` batches also have their
    ground-truth and predicted frames written as PNGs under
    ``configs.gen_frm_dir/<itr>/<batch_id>/``.

    Args:
        model: wrapper exposing ``test(test_dat, real_input_flag) -> img_gen``.
        test_input_handle: batch iterator with ``begin``/``get_batch``/
            ``next``/``no_batch_left``.
        configs: experiment configuration namespace.
        itr: training iteration, used to name the output directory.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    # os.mkdir raised FileExistsError on re-runs; makedirs is idempotent.
    os.makedirs(res_path, exist_ok=True)
    avg_mse = 0
    batch_id = 0
    output_length = configs.total_length - configs.input_length
    img_mse = [0 for _ in range(output_length)]
    ssim = [0 for _ in range(output_length)]
    # No scheduled sampling at test time: all-zero mask.
    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - configs.input_length - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))
    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        img_out = img_gen[:, -output_length:]

        # Per-frame MSE / SSIM over the predicted horizon.
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.clip(img_out[:, i, :, :, :], 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse
            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b],
                                        full=True, multichannel=True)
                ssim[i] += score

        # Save a few example sequences for visual inspection.
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.makedirs(path, exist_ok=True)
            for i in range(configs.total_length):
                file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(img_gen_length):
                file_name = os.path.join(
                    path, 'pd' + str(i + 1 + configs.input_length) + '.png')
                img_pd = np.uint8(np.clip(img_gen[0, i, :, :, :], 0, 1) * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    # Per-sequence MSE, then per-frame breakdowns.
    avg_mse = avg_mse / (batch_id * configs.batch_size)
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / (batch_id * configs.batch_size))
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
def _to_rgb_tensor(frames, configs):
    """Convert (B, H, W, C) frames to the (B, 3, H, W) FloatTensor LPIPS expects.

    Grayscale input (``configs.img_channel != 3``) is replicated across the
    three output channels.
    """
    rgb = np.zeros(
        [configs.batch_size, 3, configs.img_width, configs.img_width])
    if configs.img_channel == 3:
        for c in range(3):
            rgb[:, c, :, :] = frames[:, :, :, c]
    else:
        for c in range(3):
            rgb[:, c, :, :] = frames[:, :, :, 0]
    return torch.FloatTensor(rgb)


def test(model, test_input_handle, configs, itr):
    """Evaluate the model, reporting per-frame MSE, SSIM, PSNR and LPIPS.

    Supports reverse scheduled sampling (mask partially set to 1 over the
    input frames). The first ``configs.num_save_samples`` batches also have
    their ground-truth and predicted frames written as PNGs under
    ``configs.gen_frm_dir/<itr>/<batch_id>/``.

    Args:
        model: wrapper exposing ``test(test_dat, real_input_flag) -> img_gen``.
        test_input_handle: batch iterator with ``begin``/``get_batch``/
            ``next``/``no_batch_left``.
        configs: experiment configuration namespace.
        itr: training iteration, used to name the output directory.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    # os.mkdir raised FileExistsError on re-runs; makedirs is idempotent.
    os.makedirs(res_path, exist_ok=True)
    avg_mse = 0
    batch_id = 0
    output_length = configs.total_length - configs.input_length
    img_mse = [0 for _ in range(output_length)]
    ssim = [0 for _ in range(output_length)]
    psnr = [0 for _ in range(output_length)]
    lp = [0 for _ in range(output_length)]

    # Reverse scheduled sampling changes how much of the mask covers real
    # input frames.
    if configs.reverse_scheduled_sampling == 1:
        mask_input = 1
    else:
        mask_input = configs.input_length
    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - mask_input - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))
    if configs.reverse_scheduled_sampling == 1:
        real_input_flag[:, :configs.input_length - 1, :, :] = 1.0

    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        img_out = img_gen[:, -output_length:]

        # Per-frame metrics over the predicted horizon.
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.clip(img_out[:, i, :, :, :], 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            # LPIPS on (B, 3, H, W) tensors.
            lp_loss = loss_fn_alex(_to_rgb_tensor(x, configs),
                                   _to_rgb_tensor(gx, configs))
            lp[i] += torch.mean(lp_loss).item()

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b],
                                        full=True, multichannel=True)
                ssim[i] += score

        # Save a few example sequences for visual inspection.
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.makedirs(path, exist_ok=True)
            for i in range(configs.total_length):
                file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(img_gen_length):
                file_name = os.path.join(
                    path, 'pd' + str(i + 1 + configs.input_length) + '.png')
                img_pd = np.uint8(np.clip(img_gen[0, i, :, :, :], 0, 1) * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    # Per-sequence MSE, then per-frame breakdowns of every metric.
    avg_mse = avg_mse / (batch_id * configs.batch_size)
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / (batch_id * configs.batch_size))
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])
    lp = np.asarray(lp, dtype=np.float32) / batch_id
    print('lpips per frame: ' + str(np.mean(lp)))
    for i in range(output_length):
        print(lp[i])
def test(model, configs, itr):
    """Evaluate the model on the CloudCast test set.

    Reports per-frame MSE, SSIM, PSNR and LPIPS over the predicted horizon
    and saves example ground-truth/predicted frames as PNGs under
    ``configs.gen_frm_dir/<itr>/<batch_id>/``.

    Args:
        model: wrapper exposing ``test(test_dat, real_input_flag) -> img_gen``.
        configs: experiment configuration namespace.
        itr: training iteration, used to name the output directory.
    """
    from data.cloudcast import CloudCast
    import torch
    import lpips
    from skimage.metrics import structural_similarity
    from core.utils import preprocess, metrics
    import cv2
    from tqdm import tqdm

    loss_fn_alex = lpips.LPIPS(net='alex')
    # BUG FIX: the original used torch.device("gpu:0"), which is not a valid
    # device string and raises RuntimeError when CUDA is available; PyTorch
    # names CUDA devices "cuda:N".
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    # os.mkdir raised FileExistsError on re-runs; makedirs is idempotent.
    os.makedirs(res_path, exist_ok=True)
    avg_mse = 0
    batch_id = 0
    output_length = configs.total_length - configs.input_length
    img_mse = [0 for _ in range(output_length)]
    ssim = [0 for _ in range(output_length)]
    psnr = [0 for _ in range(output_length)]
    lp = [0 for _ in range(output_length)]

    testFolder = CloudCast(
        is_train=False,
        root="data/",
        n_frames_input=20,
        n_frames_output=1,
        batchsize=8,
    )
    # TODO: number of workers will need to be changed.
    testLoader = torch.utils.data.DataLoader(
        testFolder,
        batch_size=8,
        num_workers=configs.number_of_workers,
        shuffle=False)
    t_test = tqdm(testLoader, leave=False, total=2)

    # Reverse scheduled sampling changes how much of the mask covers real
    # input frames.
    if configs.reverse_scheduled_sampling == 1:
        mask_input = 1
    else:
        mask_input = configs.input_length
    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - mask_input - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))
    if configs.reverse_scheduled_sampling == 1:
        real_input_flag[:, :configs.input_length - 1, :, :] = 1.0

    # Renamed the outer loop index from `i` (the original shadowed it with
    # the per-frame metric loops below).
    for step, (idx, targetVar, inputVar, _, _) in enumerate(t_test):
        batch_id = batch_id + 1
        inputs = inputVar.to(device)
        # Reorder (B, T, C, H, W) -> (B, T, W, H, C) layout for patching.
        test_ims = torch.swapaxes(inputs, 2, 4)
        # NOTE(review): test_ims stays a torch tensor but is fed to numpy ops
        # below; that only works for CPU tensors — confirm, or convert
        # explicitly with .cpu().numpy().
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_gen_length = img_gen.shape[1]
        img_out = img_gen[:, -output_length:]

        # Per-frame metrics over the predicted horizon.
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = np.minimum(np.maximum(img_out[:, i, :, :, :], 0), 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            # LPIPS expects (B, 3, H, W); grayscale input is replicated
            # across the three channels.
            img_x = np.zeros(
                [configs.batch_size, 3, configs.img_width, configs.img_width])
            img_gx = np.zeros(
                [configs.batch_size, 3, configs.img_width, configs.img_width])
            if configs.img_channel == 3:
                for c in range(3):
                    img_x[:, c, :, :] = x[:, :, :, c]
                    img_gx[:, c, :, :] = gx[:, :, :, c]
            else:
                for c in range(3):
                    img_x[:, c, :, :] = x[:, :, :, 0]
                    img_gx[:, c, :, :] = gx[:, :, :, 0]
            lp_loss = loss_fn_alex(torch.FloatTensor(img_x),
                                   torch.FloatTensor(img_gx))
            lp[i] += torch.mean(lp_loss).item()

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                score, _ = structural_similarity(pred_frm[b], real_frm[b],
                                                 full=True, multichannel=True)
                ssim[i] += score

        # Save a few example sequences for visual inspection.
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.makedirs(path, exist_ok=True)
            for i in range(configs.total_length):
                file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(img_gen_length):
                file_name = os.path.join(
                    path, 'pd' + str(i + 1 + configs.input_length) + '.png')
                img_pd = np.uint8(
                    np.minimum(np.maximum(img_gen[0, i, :, :, :], 0), 1) * 255)
                cv2.imwrite(file_name, img_pd)

    # Per-sequence MSE, then per-frame breakdowns of every metric.
    avg_mse = avg_mse / (batch_id * configs.batch_size)
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / (batch_id * configs.batch_size))
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])
    lp = np.asarray(lp, dtype=np.float32) / batch_id
    print('lpips per frame: ' + str(np.mean(lp)))
    for i in range(output_length):
        print(lp[i])