def main(argv=None):
    """Train the model and periodically evaluate it on the test set.

    Deletes and recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``, then
    runs ``FLAGS.max_iterations`` training steps. Each step draws a random
    per-frame mask whose probability of being all ones (``eta``) decays
    linearly from 1 over the first 50000 iterations and is 0 afterwards.
    Every ``FLAGS.test_interval`` iterations the whole test set is evaluated:
    per-frame MSE, PSNR, frame MAE, SSIM and sharpness are printed, and
    ground-truth / predicted PNGs are written for the first 10 test batches.

    Args:
        argv: Unused (presumably present so this can be passed to
            ``tf.app.run``).
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    out_len = FLAGS.seq_length - FLAGS.input_length
    # Spatial size after patching. Floor division keeps the shape integral:
    # with ``/`` Python 3 yields a float and np.zeros/np.ones raise TypeError.
    # Assumes square frames (img_width is used for both spatial axes).
    patch_hw = FLAGS.img_width // FLAGS.patch_size
    patch_ch = FLAGS.patch_size ** 2 * FLAGS.img_channel
    mask_shape = (FLAGS.batch_size, out_len - 1, patch_hw, patch_hw, patch_ch)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        # eta decays linearly, so fewer mask entries are all ones as training
        # goes on (presumably 1 = feed the ground-truth frame; see Model).
        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((FLAGS.batch_size, out_len - 1))
        true_token = random_flip < eta
        # mask_true[i, j] is all ones where true_token[i, j] is set, else zeros.
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0

        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            ssim = [0] * out_len
            psnr = [0] * out_len
            fmae = [0] * out_len
            sharp = [0] * out_len
            # All-zero mask for evaluation.
            mask_true = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

                # MSE per frame (channel 0 only)
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                        ssim[i] += score

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(out_len):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = np.clip(img_gen[0, i, :, :, :], 0, 1)
                        img_pd = np.uint8(img_pd * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()

            num_seqs = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / num_seqs
            print('mse per seq: ' + str(avg_mse))
            for frame_mse in img_mse:
                print(frame_mse / num_seqs)
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
            sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
            print('psnr per frame: ' + str(np.mean(psnr)))
            for value in psnr:
                print(value)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for value in fmae:
                print(value)
            print('ssim per frame: ' + str(np.mean(ssim)))
            for value in ssim:
                print(value)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for value in sharp:
                print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
def main(argv=None):
    """Train the bidirectional odd-frame model and evaluate it periodically.

    During training, even-indexed frames always get an all-ones mask. Odd
    frames get one with probability ``eta``, which decays from 1 to 0 over
    the first 80000 iterations. At test time odd frames are masked out, so
    the model presumably reconstructs them from the even ones. Evaluation
    runs on at most 11 test batches and saves images for the first 10.

    Args:
        argv: Unused (presumably present for ``tf.app.run``).

    Raises:
        ValueError: if ``FLAGS.gen_num`` is not 5 or 11. Evaluation only knows
            how to align predictions for those layouts; any other value would
            leave the prediction unbound at the first test.
    """
    if FLAGS.gen_num not in (5, 11):
        # Checked before any directory is wiped.
        raise ValueError('FLAGS.gen_num must be 5 or 11, got %r' % (FLAGS.gen_num,))

    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.0000125
    eta = 1

    seq_len = FLAGS.seq_length
    # Floor division keeps the patched shape integral (``/`` gives floats on
    # Python 3, which np.zeros rejects).
    patch_h = FLAGS.img_height // FLAGS.patch_size
    patch_w = FLAGS.img_width // FLAGS.patch_size
    patch_ch = FLAGS.patch_size ** 2 * FLAGS.img_channel

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        print('train get_batch:')
        ims, filename = train_input_handle.get_batch(False)
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 80000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((FLAGS.batch_size, seq_len))
        true_token = random_flip < eta
        # Frames 0, 2, 4, ... always get an all-ones mask; frames 1, 3, 5, ...
        # only when their random token is set.
        true_token[:, 0::2] = True
        mask_true = np.zeros((FLAGS.batch_size, seq_len, patch_h, patch_w, patch_ch))
        mask_true[true_token] = 1.0

        # model.train always takes the time-reversed sequence. Previously
        # ims_rev was only bound when FLAGS.reverse_input was set, so turning
        # the flag off raised NameError here.
        ims_rev = ims[:, ::-1]
        cost = model.train(ims, ims_rev, lr, mask_true, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * seq_len
            ssim = [0] * seq_len
            psnr = [0] * seq_len
            fmae = [0] * seq_len
            sharp = [0] * seq_len

            # Test mask: even frames all ones, odd frames all zeros.
            mask_true = np.ones((FLAGS.batch_size, seq_len, FLAGS.img_height,
                                 FLAGS.img_width, FLAGS.img_channel))
            mask_true[:, 1::2] = 0.0
            mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

            # Evaluate up to 11 batches, stopping early if the test set runs out
            # rather than asking the handle for a batch it does not have.
            while batch_id <= 10 and not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                print('test get_batch:')
                test_ims, filename = test_input_handle.get_batch(False)
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                test_ims_rev = test_dat[:, ::-1]
                img_gen, ims_watch, ims_rev_watch = model.test(
                    test_dat, test_ims_rev, mask_true, itr)

                # concat outputs of different gpus along batch
                img_gen = preprocess.reshape_patch_back(
                    np.concatenate(img_gen), FLAGS.patch_size)
                ims_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_watch), FLAGS.patch_size)
                ims_rev_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_rev_watch), FLAGS.patch_size)

                # MSE per frame (channel 0 only). Even frames are scored
                # against themselves; odd frames use the model output.
                for i in range(seq_len):
                    x = test_ims[:, i, :, :, 0]
                    if i % 2 == 1:
                        # With gen_num == 5 img_gen is indexed by i // 2
                        # (presumably it holds only the odd frames).
                        gen_idx = i // 2 if FLAGS.gen_num == 5 else i
                        gx = img_gen[:, gen_idx, :, :, 0]
                    else:
                        gx = test_ims[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                        ssim[i] += score

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(filename))
                    os.mkdir(path)
                    for i in range(seq_len):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(seq_len):
                        name = 'pd' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        if FLAGS.gen_num == 5 and i % 2 == 0:
                            img_pd = test_ims[0, i, :, :, :]
                        elif FLAGS.gen_num == 5:
                            img_pd = img_gen[0, i // 2, :, :, :]
                        else:  # gen_num == 11
                            img_pd = img_gen[0, i, :, :, :]
                        img_pd = np.uint8(np.clip(img_pd, 0, 1) * 255)
                        cv2.imwrite(file_name, img_pd)

                        name = 'zwgt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_zwgt = np.uint8(ims_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_zwgt)

                        name = 'zwgtrev' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        zwgtrev = np.uint8(ims_rev_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, zwgtrev)
                test_input_handle.next()

            num_seqs = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / num_seqs
            print('mse per seq: ' + str(avg_mse))
            for frame_mse in img_mse:
                print(frame_mse / num_seqs)
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
            sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
            print('psnr per frame: ' + str(np.mean(psnr)))
            for value in psnr:
                print(value)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for value in fmae:
                print(value)
            print('ssim per frame: ' + str(np.mean(ssim)))
            for value in ssim:
                print(value)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for value in sharp:
                print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
def main(argv=None):
    """Train on down-sampled traffic data and evaluate on test and validation.

    Deletes and recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``, reads
    the frame size from ``train_input_handle.dims`` and trains with a random
    per-frame mask whose all-ones probability ``eta`` decays from 1 to 0 over
    the first 50000 iterations. Every ``FLAGS.test_interval`` iterations it:

    * evaluates the test handle (MSE / PSNR / frame MAE / sharpness) and saves
      images for the first 10 batches;
    * runs the model on every daily file of the validation directory at the
      city's start indices, resizes channel 1 of the predictions back to the
      label resolution and prints ``masked_mse_np`` against the labels.

    Args:
        argv: Unused (presumably present for ``tf.app.run``).
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(
        FLAGS.train_data_paths, FLAGS.dataset_name,
        'train_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    valid_data_paths = os.path.join(
        FLAGS.valid_data_paths, FLAGS.dataset_name,
        'valid_speed_down_sample{}.npz'.format(FLAGS.down_sample))

    out_len = FLAGS.seq_length - FLAGS.input_length

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_size, True, FLAGS.down_sample, FLAGS.input_length, out_len)

    # Start indices of the windows to predict, within the 288 time bins
    # (0 to 287) of each daily test file. They depend on the time zone:
    # Berlin lies in UTC+2, Istanbul and Moscow in UTC+3.
    utc_plus2 = [30, 69, 126, 186, 234]
    utc_plus3 = [57, 114, 174, 222, 258]
    indices = utc_plus2 if FLAGS.dataset_name == 'Berlin' else utc_plus3

    dims = train_input_handle.dims
    FLAGS.img_height = dims[-2]
    FLAGS.img_width = dims[-1]

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    patch_ch = int(FLAGS.patch_size ** 2 * FLAGS.img_channel)
    mask_shape = (FLAGS.batch_size, out_len - 1, int(FLAGS.img_height),
                  int(FLAGS.img_width), patch_ch)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((FLAGS.batch_size, out_len - 1))
        true_token = random_flip < eta
        # mask_true[i, j] is all ones where true_token[i, j] is set, else zeros.
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0

        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr), flush=True)
            print('training loss: ' + str(cost), flush=True)

        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            psnr = [0] * out_len
            fmae = [0] * out_len
            sharp = [0] * out_len
            # All-zero mask for evaluation; also reused for validation below.
            mask_true = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

                # MSE per frame (channel 0 only)
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(out_len):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = np.clip(img_gen[0, i, :, :, :], 0, 1)
                        img_pd = np.uint8(img_pd * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()

            num_seqs = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / num_seqs
            print('mse per seq: ' + str(avg_mse), flush=True)
            for frame_mse in img_mse:
                print(frame_mse / num_seqs)
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            sharp = np.asarray(sharp, dtype=np.float32) / num_seqs
            print('psnr per frame: ' + str(np.mean(psnr)), flush=True)
            for value in psnr:
                print(value, flush=True)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for value in fmae:
                print(value, flush=True)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for value in sharp:
                print(value, flush=True)

            # test with file
            # NOTE(review): the validation directory is resolved under
            # FLAGS.train_data_paths rather than FLAGS.valid_data_paths;
            # confirm this is intended.
            valid_data_path = os.path.join(
                FLAGS.train_data_paths, FLAGS.dataset_name,
                '{}_validation'.format(FLAGS.dataset_name))
            files = list_filenames(valid_data_path)
            if not files:
                # np.concatenate([]) below would raise on an empty directory.
                print('no validation files found in ' + valid_data_path, flush=True)
            else:
                output_all = []
                labels_all = []
                num_partitions = int(np.ceil(len(indices) / FLAGS.batch_size))
                for f in files:
                    valid_file = os.path.join(valid_data_path, f)
                    valid_input, raw_output = datasets_factory.test_validation_provider(
                        valid_file, indices, down_sample=FLAGS.down_sample,
                        seq_len=FLAGS.input_length, horizon=out_len)
                    # np.float (an alias of the builtin float, i.e. float64) was
                    # removed in NumPy 1.24.
                    valid_input = valid_input.astype(np.float64) / 255.0
                    labels_all.append(raw_output)
                    # NOTE(review): unlike the test loop above, valid_input is
                    # not passed through preprocess.reshape_patch; confirm.
                    for i in range(num_partitions):
                        valid_input_i = valid_input[i * FLAGS.batch_size:
                                                    (i + 1) * FLAGS.batch_size]
                        num_input_i = valid_input_i.shape[0]
                        if num_input_i < FLAGS.batch_size:
                            # Zero-pad the final partial batch to batch_size.
                            zeros_fill_in = np.zeros(
                                (FLAGS.batch_size - num_input_i, FLAGS.seq_length,
                                 FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel))
                            valid_input_i = np.concatenate(
                                [valid_input_i, zeros_fill_in], axis=0)
                        img_gen = model.test(valid_input_i, mask_true)
                        output_all.append(img_gen[0][:num_input_i])
                output_all = np.concatenate(output_all, axis=0)
                labels_all = np.concatenate(labels_all, axis=0)

                # Labels are indexed channel-last below (labels_all[..., 1]), so
                # axes -3 / -2 are height / width. cv2.resize takes dsize as
                # (width, height).
                origin_height = labels_all.shape[-3]
                origin_width = labels_all.shape[-2]
                output_resize = []
                for sample in output_all:
                    frames = [
                        np.expand_dims(
                            cv2.resize(frame[1, :, :], (origin_width, origin_height)),
                            axis=0)
                        for frame in sample
                    ]
                    output_resize.append(np.stack(frames, axis=0))
                output_resize = np.stack(output_resize, axis=0)
                output_resize *= 255.0
                labels_all = np.expand_dims(labels_all[..., 1], axis=2)
                valid_mse = masked_mse_np(output_resize, labels_all, np.nan)
                print("validation mse is ", valid_mse, flush=True)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
def main(argv=None):
    """Train on per-city traffic files and keep the best model by validation MSE.

    Appends the dataset name, ``seq_length`` and ``num_hidden`` to
    ``FLAGS.save_dir`` and creates the output directories. An existing save
    directory is reused as ``FLAGS.pretrained_model``. Each iteration loads a
    batch of samples, keeps the first ``FLAGS.img_channel`` channels and
    trains on it in mini-batches of ``FLAGS.batch_size``. Every
    ``FLAGS.test_interval`` iterations it evaluates the test handle at the
    city's start indices, prints per-frame and per-channel MSE, and calls
    ``model.save_to_best_mode()`` whenever the overall masked MSE improves.

    Args:
        argv: Unused (presumably present for ``tf.app.run``).
    """
    FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Reuse the existing run directory as the pretrained-model location
        # (presumably to resume training).
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    # NOTE(review): node_pos is never used below; the load is kept so a
    # missing file still fails at startup as before.
    node_pos = np.load(node_pos_file_2in1)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_validation')

    out_len = FLAGS.seq_length - FLAGS.input_length

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length, out_len)

    # Start indices of the windows to predict, within the 288 time bins
    # (0 to 287) of each daily test file. They depend on the time zone:
    # Berlin lies in UTC+2, Istanbul and Moscow in UTC+3.
    utc_plus2 = [30, 69, 126, 186, 234]
    utc_plus3 = [57, 114, 174, 222, 258]
    indices = utc_plus2 if FLAGS.dataset_name == 'Berlin' else utc_plus3

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    delta = 0.2
    eta = 1
    # Start at +inf so the first evaluation always saves a best model
    # (with 1.0, a validation MSE >= 1 would never be saved).
    min_val_loss = float('inf')

    patch_ch = int(FLAGS.patch_size_height * FLAGS.patch_size_width * FLAGS.img_channel)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()
        imss = imss[..., :FLAGS.img_channel]
        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width,
                                        FLAGS.patch_size_height)
        num_samples = imss.shape[0]
        for bi in range(0, num_samples, FLAGS.batch_size):
            ims = imss[bi:bi + FLAGS.batch_size]
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]

            # NOTE(review): eta decays once per mini-batch, not per iteration.
            # Clamping at 0 gives the same masks as letting it go negative,
            # because random_sample never returns a value below 0.
            if itr < 10:
                eta = max(eta - delta, 0.0)
            else:
                eta = 0.0
            random_flip = np.random.random_sample((batch_size, out_len - 1))
            true_token = random_flip < eta
            # mask_true[i, j] is all ones where true_token[i, j] is set.
            mask_true = np.zeros((batch_size, out_len - 1, int(FLAGS.img_height),
                                  int(FLAGS.img_width), patch_ch))
            mask_true[true_token] = 1.0

            cost = model.train(ims, lr, mask_true, batch_size)
            if FLAGS.reverse_input:
                ims_rev = ims[:, ::-1]
                cost += model.train(ims_rev, lr, mask_true, batch_size)
                cost = cost / 2

            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)
        train_input_handle.next()

        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            batch_size = len(indices)
            test_input_handle.begin(do_shuffle=False)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            psnr = [0] * out_len
            fmae = [0] * out_len
            mask_true = np.zeros((batch_size, out_len - 1, FLAGS.img_height,
                                  FLAGS.img_width, patch_ch))
            gt_list = []
            pred_list = []
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_test_batch(indices)
                # get the selected channels
                test_ims = test_ims[..., :FLAGS.img_channel]
                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, :])
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)
                img_gen = model.test(test_dat, mask_true, batch_size)
                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                pred_list.append(img_gen)

                # MSE per frame (all selected channels)
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, :]
                    gx = img_gen[:, i, :, :, :]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    frame_mse = np.square(x - gx).sum()
                    img_mse[i] += frame_mse
                    avg_mse += frame_mse
                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                test_input_handle.next()

            # Number of values summed per output frame: samples times
            # un-patched pixels times channels.
            values_per_frame = (batch_id * batch_size * FLAGS.img_height * FLAGS.img_width
                                * FLAGS.patch_size_height * FLAGS.patch_size_width
                                * FLAGS.img_channel)
            avg_mse = avg_mse / (values_per_frame * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for frame_total in img_mse:
                print(frame_total / values_per_frame, flush=True)

            gt_list = np.stack(gt_list, axis=0)
            pred_list = np.stack(pred_list, axis=0)
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            volume_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=np.nan)
            print("The output mse is ", mse, flush=True)
            print("The volume mse is ", volume_mse, flush=True)
            # Channel 1 only exists when at least two channels were selected;
            # indexing it with img_channel == 1 raised IndexError.
            if FLAGS.img_channel >= 2:
                speed_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1],
                                          null_val=np.nan)
                print("The speed mse is ", speed_mse, flush=True)
            if FLAGS.img_channel >= 3:
                direction_mse = masked_mse_np(pred_list[..., 2], gt_list[..., 2],
                                              null_val=np.nan)
                print("The direction mse is ", direction_mse, flush=True)
            if mse < min_val_loss:
                min_val_loss = mse
                print("Current Min Val Loss is ", min_val_loss)
                model.save_to_best_mode()

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
def test(model, test_input_handle, configs, itr=None):
    """Evaluate ``model`` on every batch yielded by ``test_input_handle``.

    For each predicted frame (``total_length - input_length`` of them) this
    accumulates MSE, frame MAE, PSNR, SSIM and Laplacian sharpness. It also
    writes ground-truth and predicted PNGs for the first
    ``configs.num_save_samples`` batches under
    ``configs.gen_frm_dir/<itr>/<batch_id>/``, and prints all metrics.

    Args:
        model: Object exposing ``test(patched_frames, real_input_flag)``.
        test_input_handle: Iterable of tensors with a ``.numpy()`` method,
            laid out (batch, seq, channels, height, width).
        configs: Config namespace (gen_frm_dir, total_length, input_length,
            batch_size, img_width, patch_size, img_channel, num_save_samples).
        itr: Used only to name the output directory.

    Returns:
        Tuple ``(avg_mse, ssim, psnr, fmae, sharp)``. ``avg_mse`` is the MSE
        summed over predicted frames, averaged per sequence; the others are
        float32 arrays with one entry per predicted frame.

    Raises:
        ValueError: if ``test_input_handle`` yields no batches.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    # exist_ok: with the default itr=None every call writes to ".../None",
    # so a plain os.mkdir failed from the second call on.
    os.makedirs(res_path, exist_ok=True)

    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    # All-zero flag for the predicted steps, in patched channels-last layout.
    # Assumes square frames: img_width is used for both spatial axes.
    patch_hw = configs.img_width // configs.patch_size
    real_input_flag = np.zeros(
        (configs.batch_size, output_length - 1, patch_hw, patch_hw,
         configs.patch_size ** 2 * configs.img_channel))

    for test_input in test_input_handle:
        # test_input is (batch, seq, channels, height, width); convert it to
        # channels-last for patching, metrics and image saving.
        test_ims = np.transpose(test_input.numpy(), (0, 1, 3, 4, 2))
        batch_id = batch_id + 1

        # Patch the channels-last array. Previously the raw channels-first
        # tensor was passed here, which does not match the layout of test_ims
        # that the predictions are compared against.
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_out = img_gen[:, -output_length:]

        # MSE per frame
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = img_out[:, i, :, :, :]
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.clip(gx, 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True,
                                        multichannel=True)
                ssim[i] += score
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.makedirs(path, exist_ok=True)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(output_length):
                name = 'pd' + str(i + 1 + configs.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = np.clip(img_out[0, i, :, :, :], 0, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)

    if batch_id == 0:
        raise ValueError('test_input_handle yielded no batches')

    num_seqs = batch_id * configs.batch_size
    avg_mse = avg_mse / num_seqs
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_seqs)

    ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp, dtype=np.float32) / num_seqs

    print('ssim per frame: ' + str(np.mean(ssim)))
    for value in ssim:
        print(value)
    print('psnr per frame: ' + str(np.mean(psnr)))
    for value in psnr:
        print(value)
    print('fmae per frame: ' + str(np.mean(fmae)))
    for value in fmae:
        print(value)
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for value in sharp:
        print(value)

    return avg_mse, ssim, psnr, fmae, sharp
def main(argv=None):
    """Train on joint sequences and periodically evaluate per action.

    Each training step builds the model input as the first
    ``FLAGS.input_length`` frames, then the last observed frame repeated over
    the prediction horizon, then the future frames. The same is done for the
    time-reversed sequence, and the two costs are averaged. Only the first
    22 entries of axis 2 are fed to the model (presumably joints).

    Every ``FLAGS.test_interval`` iterations, one ``<action>.npy`` file per
    action is loaded from ``FLAGS.test_data_paths`` and evaluated. The metrics
    are MSE, frame MAE, per-joint MSE/MAE and a 32-point mean Euclidean error
    (printed as MPJPE). Samples and a ``test_result`` summary are written as
    ``.mat`` files under ``FLAGS.gen_dir/<itr>``. The model is saved every
    ``FLAGS.snapshot_interval`` iterations, and ``FLAGS.save_dir`` is copied
    to ``FLAGS.bak_dir/<n>`` every five snapshot intervals.
    """
    if not tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.MakeDirs(FLAGS.save_dir)
    if not tf.gfile.Exists(FLAGS.gen_dir):
        tf.gfile.MakeDirs(FLAGS.gen_dir)
    # print() calls: the former Python 2 print statements are a SyntaxError
    # alongside the Python 3-only code elsewhere in this file.
    print('start training !',
          time.strftime('%Y-%m-%d %H:%M:%S\n\n\n', time.localtime(time.time())))

    batch_total = FLAGS.batch_size * FLAGS.n_gpu
    horizon = FLAGS.seq_length - FLAGS.input_length

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        batch_total, FLAGS.joints_number, FLAGS.input_length,
        FLAGS.seq_length, is_training=True)
    print('Initializing models')
    model = Model()

    lr = FLAGS.lr
    train_time = 0
    test_time_all = 0
    folder = 1
    path_bak = FLAGS.bak_dir
    # Iterations below this feed the raw sequence instead of the padded one;
    # with 0 the last-frame padding is always applied.
    pretrain_iter = 0

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        # Step decay: multiply the learning rate by 0.95 every 20000 iterations.
        if itr % 20000 == 0:
            lr = lr * 0.95
        start_time = time.time()
        ims = train_input_handle.get_batch()
        ims = ims[:, :, 0:22, :]

        # Forward-time sample: observed frames + repeated last observed frame
        # + future frames.
        if itr < pretrain_iter:
            inputs1 = ims
        else:
            inputs1 = ims[:, 0:FLAGS.input_length, :, :]
            tem = ims[:, FLAGS.input_length - 1]
            tem = np.expand_dims(tem, axis=1)
            tem = np.repeat(tem, horizon, axis=1)
            inputs1 = np.concatenate((inputs1, tem), axis=1)
        inputs2 = ims[:, FLAGS.input_length:]
        inputs = np.concatenate((inputs1, inputs2), axis=1)
        ims_list = np.split(inputs, FLAGS.n_gpu)
        cost = model.train(ims_list, lr, 1)

        # Same construction on the time-reversed sequence.
        imv1 = ims[:, ::-1]
        if itr >= pretrain_iter:
            imv_rev1 = imv1[:, 0:FLAGS.input_length, :, :]
            tem = imv1[:, FLAGS.input_length - 1]
            tem = np.expand_dims(tem, axis=1)
            tem = np.repeat(tem, horizon, axis=1)
            imv_rev1 = np.concatenate((imv_rev1, tem), axis=1)
        else:
            imv_rev1 = imv1
        imv_rev2 = imv1[:, FLAGS.input_length:]
        ims_rev1 = np.concatenate((imv_rev1, imv_rev2), axis=1)
        ims_rev1 = np.split(ims_rev1, FLAGS.n_gpu)
        cost += model.train(ims_rev1, lr, 1)
        cost = cost / 2
        train_time += time.time() - start_time

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr) + ' lr: ' + str(lr) +
                  ' training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('train time:' + str(train_time))
            print('test...')
            actions = ('walking eating smoking discussion directions greeting '
                       'phoning posing purchases sitting sittingdown '
                       'takingphoto waiting walkingdog walkingtogether').split(' ')
            res_path = os.path.join(FLAGS.gen_dir, str(itr))
            if not tf.gfile.Exists(res_path):
                os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            test_time = 0
            joint_mse = np.zeros((25, 32))
            joint_mae = np.zeros((25, 32))
            mpjpe = np.zeros([1, horizon])
            # Accumulates only the first four actions in ``actions``.
            mpjpe_l = np.zeros([1, horizon])
            img_mse = [0] * horizon
            fmae = [0] * horizon

            for action_idx, s in enumerate(actions):
                start_time1 = time.time()
                batch_id = batch_id + 1
                mpjpe1 = np.zeros([1, horizon])
                tem = np.load(FLAGS.test_data_paths + '/' + s + '.npy')
                # Integer division: np.repeat needs an integer count, and '/'
                # yields a float under Python 3.
                tem = np.repeat(tem, batch_total // 8, axis=0)
                test_ims = tem[:, 0:FLAGS.seq_length, :, :]
                test_ims1 = test_ims
                test_ims = test_ims[:, :, 0:22, :]
                test_dat = test_ims[:, 0:FLAGS.input_length, :, :]
                tem = test_dat[:, FLAGS.input_length - 1]
                tem = np.expand_dims(tem, axis=1)
                tem = np.repeat(tem, horizon, axis=1)
                test_dat1 = np.concatenate((test_dat, tem), axis=1)
                test_dat2 = test_ims[:, FLAGS.input_length:]
                test_dat = np.concatenate((test_dat1, test_dat2), axis=1)
                test_dat = np.split(test_dat, FLAGS.n_gpu)
                img_gen = model.test(test_dat, 0)
                test_time += time.time() - start_time1

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                gt_frm = test_ims1[:, FLAGS.input_length:]
                img_gen = recoverh36m_3d.recoverh36m_3d(gt_frm, img_gen)

                # Errors per predicted frame.
                for i in range(horizon):
                    x = gt_frm[:, i, :]
                    gx = img_gen[:, i, :]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    mse = np.square(x - gx).sum()
                    # Mean over 32 points of the Euclidean error; summed over
                    # the batch here, divided by batch_total below.
                    for j in range(batch_total):
                        tem1 = 0
                        for k in range(32):
                            tem1 += np.sqrt(np.square(x[j, k] - gx[j, k]).sum())
                        mpjpe1[0, i] += tem1 / 32
                    img_mse[i] += mse
                    avg_mse += mse
                    for j in range(32):
                        xi = x[:, j]
                        gxi = gx[:, j]
                        joint_mse[i, j] += np.square(xi - gxi).sum()
                        joint_mae[i, j] += metrics.batch_mae_frame_float1(
                            gxi, xi)

                # save prediction examples
                path = os.path.join(res_path, str(batch_id))
                if not tf.gfile.Exists(path):
                    os.mkdir(path)
                for i in range(FLAGS.seq_length):
                    file_name = os.path.join(path, 'gt' + str(i + 1) + '.mat')
                    io.savemat(file_name, {'joint': test_ims[0, i, :, :]})
                for i in range(horizon):
                    name = 'pd' + str(i + 1 + FLAGS.input_length) + '.mat'
                    file_name = os.path.join(path, name)
                    io.savemat(file_name, {'joint': img_gen[0, i, :, :]})

                mpjpe1 = mpjpe1 / batch_total
                print('current action mpjpe: ', s)
                for value in mpjpe1[0]:
                    print(value)
                mpjpe += mpjpe1
                if action_idx <= 3:
                    print('four actions', s)
                    mpjpe_l += mpjpe1

            test_time_all += test_time
            joint_mae = np.asarray(joint_mae, dtype=np.float32) / batch_id
            joint_mse = np.asarray(joint_mse, dtype=np.float32) / (
                batch_id * batch_total)
            avg_mse = avg_mse / (batch_id * batch_total)
            print('mse per seq: ' + str(avg_mse))
            mpjpe = mpjpe / batch_id
            print('mean per joints position error: ' + str(np.mean(mpjpe)))
            for i in range(horizon):
                print(mpjpe[0, i])
            mpjpe_l = mpjpe_l / 4
            print('mean mpjpe for four actions: ' + str(np.mean(mpjpe_l)))
            for i in range(horizon):
                print(mpjpe_l[0, i])
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            print('fmae per frame: ' + str(np.mean(fmae)))
            print('current test time:' + str(test_time))
            print('all test time: ' + str(test_time_all))
            filename = os.path.join(res_path, 'test_result')
            io.savemat(filename, {
                'joint_mse': joint_mse,
                'joint_mae': joint_mae,
                'mpjpe': mpjpe
            })

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
            print('model saving done! ',
                  time.strftime('%Y-%m-%d %H:%M:%S\n\n\n',
                                time.localtime(time.time())))
        if itr % (5 * FLAGS.snapshot_interval) == 0:
            bakfile = path_bak + '/' + str(folder)
            shutil.copytree(FLAGS.save_dir, bakfile)
            folder = folder + 1
        train_input_handle.next()
def main(argv=None):
    """Evaluate a model on the test split and print per-frame metrics.

    Runs ``model.test`` with an all-zero mask over every test batch. The first
    channel of each predicted frame is scored (MSE, frame MAE, PSNR, SSIM and
    Laplacian sharpness). Ground-truth/predicted PNGs are saved for the first
    10 batches under ``<FLAGS.gen_frm_dir>/test``.

    ``xrange`` (a NameError under the Python 3 this file targets) has been
    replaced by ``range``.
    """
    # Added because it errors otherwise (translated from Dutch:
    # "toegevoegd anders error").
    tf.disable_eager_execution()
    # load data (only the test handle is used)
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.test_data_paths,
        FLAGS.batch_size, FLAGS.img_width)
    print("Initializing models")
    model = Model()

    print('test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(FLAGS.gen_frm_dir, 'test')
    os.mkdir(res_path)

    output_length = FLAGS.seq_length - FLAGS.input_length
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    # All-zero mask passed to model.test for every batch.
    mask_true = np.zeros(
        (FLAGS.batch_size, output_length - 1,
         int(FLAGS.img_height / FLAGS.patch_size),
         int(FLAGS.img_width / FLAGS.patch_size),
         FLAGS.patch_size**2 * FLAGS.img_channel))

    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
        img_gen = model.test(test_dat, mask_true)

        # concat outputs of different gpus along batch
        img_gen = np.concatenate(img_gen)
        img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

        # Metrics per predicted frame (first channel only).
        for i in range(output_length):
            x = test_ims[:, i + FLAGS.input_length, :, :, 0]
            gx = img_gen[:, i, :, :, 0]
            # MAE is measured before the prediction is clipped to [0, 1].
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.maximum(gx, 0)
            gx = np.minimum(gx, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(FLAGS.batch_size):
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= 10:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(FLAGS.seq_length):
                file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(output_length):
                name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = img_gen[0, i, :, :, :]
                img_pd = np.maximum(img_pd, 0)
                img_pd = np.minimum(img_pd, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    num_sequences = batch_id * FLAGS.batch_size
    avg_mse = avg_mse / num_sequences
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / num_sequences)

    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / num_sequences
    sharp = np.asarray(sharp, dtype=np.float32) / num_sequences

    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])
    print('fmae per frame: ' + str(np.mean(fmae)))
    for i in range(output_length):
        print(fmae[i])
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for i in range(output_length):
        print(sharp[i])
def main(argv=None):
    """Run inference on the test split and report metrics for every frame.

    Odd frame indices are masked out (mask 0) and even ones kept (mask 1).
    ``model.test`` receives the patched sequence, its time reversal and the
    mask. Metrics use channel 0:
    * Even frames use the ground truth as the "prediction".
    * Odd frames use generated frames. With ``FLAGS.gen_num == 5`` the
      generated frames are densely packed (index ``i // 2``); with 11 they are
      indexed by ``i``.

    Per-frame MSE, MAE, PSNR, SSIM and sharpness are printed, and PNGs are
    written under ``<FLAGS.gen_frm_dir>/images0/<fileName>/``.

    Raises:
        ValueError: if ``FLAGS.gen_num`` is neither 5 nor 11 (previously
            this surfaced as an unbound local ``gx``).
    """
    # Only these two prediction layouts are handled below; fail before any
    # output directory is deleted.
    if FLAGS.gen_num not in (5, 11):
        raise ValueError(
            'FLAGS.gen_num must be 5 or 11, got %r' % (FLAGS.gen_num,))

    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size,
        [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)
    print('Initializing models')
    model = Model()

    test_input_handle.begin(do_shuffle=False)
    totalDataLen = int(test_input_handle.total())
    print('totalDataLen=', totalDataLen)

    for itr in range(1):
        print('inference...')
        res_path = os.path.join(FLAGS.gen_frm_dir, 'images' + str(itr))
        os.mkdir(res_path)
        avg_mse = 0
        batch_id = 0
        img_mse = [0] * FLAGS.seq_length
        ssim = [0] * FLAGS.seq_length
        psnr = [0] * FLAGS.seq_length
        fmae = [0] * FLAGS.seq_length
        sharp = [0] * FLAGS.seq_length

        # Keep even frames (0, 2, 4, ...), zero out odd frames (1, 3, 5, ...).
        mask_true = np.ones(
            (FLAGS.batch_size, FLAGS.seq_length, FLAGS.img_height,
             FLAGS.img_width, FLAGS.img_channel))
        mask_true[:, 1::2] = 0
        mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

        while batch_id < totalDataLen:
            batch_id = batch_id + 1
            print('test get_batch:')
            test_ims, fileName = test_input_handle.get_batch(False)
            test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
            # The model.test call always takes a time-reversed sequence, so
            # build it unconditionally. Previously it was only bound when
            # FLAGS.reverse_input was set, and the call crashed otherwise.
            test_ims_rev = test_dat[:, ::-1]
            img_gen, _, _ = model.test(test_dat, test_ims_rev, mask_true, itr)
            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

            # Metrics per frame (channel 0).
            for i in range(FLAGS.seq_length):
                x = test_ims[:, i, :, :, 0]
                if i % 2 == 0:
                    # Even frames are not generated; the ground truth is used.
                    gx = test_ims[:, i, :, :, 0]
                elif FLAGS.gen_num == 5:
                    # Only odd frames were generated, packed densely.
                    gx = img_gen[:, i // 2, :, :, 0]
                else:  # FLAGS.gen_num == 11
                    gx = img_gen[:, i, :, :, 0]
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b], real_frm[b],
                                            full=True)
                    ssim[i] += score

            # save prediction examples
            # NOTE(review): `<` means the last batch is never saved; confirm
            # whether `<=` was intended.
            if batch_id < totalDataLen:
                path = os.path.join(res_path, str(fileName))
                os.mkdir(path)
                for i in range(FLAGS.seq_length):
                    file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    cv2.imwrite(file_name, img_gt)
                for i in range(FLAGS.seq_length):
                    file_name = os.path.join(path, 'pd' + str(i + 1) + '.png')
                    if FLAGS.gen_num == 5:
                        if i % 2 == 1:
                            img_pd = img_gen[0, i // 2, :, :, :]
                        else:
                            img_pd = test_ims[0, i, :, :, :]
                    else:  # FLAGS.gen_num == 11: every generated frame
                        img_pd = img_gen[0, i, :, :, :]
                    img_pd = np.maximum(img_pd, 0)
                    img_pd = np.minimum(img_pd, 1)
                    img_pd = np.uint8(img_pd * 255)
                    cv2.imwrite(file_name, img_pd)
            test_input_handle.next()

        num_sequences = batch_id * FLAGS.batch_size
        avg_mse = avg_mse / num_sequences
        print('mse per seq: ' + str(avg_mse))
        for i in range(FLAGS.seq_length):
            print(img_mse[i] / num_sequences)

        psnr = np.asarray(psnr, dtype=np.float32) / batch_id
        fmae = np.asarray(fmae, dtype=np.float32) / batch_id
        ssim = np.asarray(ssim, dtype=np.float32) / num_sequences
        sharp = np.asarray(sharp, dtype=np.float32) / num_sequences

        print('psnr per frame: ' + str(np.mean(psnr)))
        for i in range(FLAGS.seq_length):
            print(psnr[i])
        print('fmae per frame: ' + str(np.mean(fmae)))
        for i in range(FLAGS.seq_length):
            print(fmae[i])
        print('ssim per frame: ' + str(np.mean(ssim)))
        for i in range(FLAGS.seq_length):
            print(ssim[i])
        print('sharpness per frame: ' + str(np.mean(sharp)))
        for i in range(FLAGS.seq_length):
            print(sharp[i])
def test(model, test_input_handle, configs, save_name, hidden_state,
         cell_state, hidden_state_diff, cell_state_diff, st_memory,
         conv_lstm_c, MIMB_oc_w, MIMB_ct_w, MIMN_oc_w, MIMN_ct_w):
    """Evaluate ``model`` on the test handle, print metrics and save samples.

    The recurrent-state arguments are passed straight through to
    ``model.forward``. Unless ``save_name == 'test_result'``, at most 100
    batches are scored.

    Frames are channels-first (C, H, W); they are transposed with
    ``(1, 2, 0)`` before SSIM and before PNGs are written. Ground-truth values
    strictly above 10000 are reduced modulo 10000 before scoring.

    Samples for the first ``configs.num_save_samples`` batches go to
    ``<configs.gen_frm_dir>/<save_name>/``:
    * ``gt1..gtT`` are the ground truth.
    * ``pd0..pd(T-2)`` are the raw generated frames.
    """
    test_input_handle.begin(do_shuffle=False)
    res_path = configs.gen_frm_dir
    if not os.path.isdir(res_path):
        os.mkdir(res_path)

    output_length = configs.total_length - configs.input_length
    # Outside the final evaluation only a bounded number of batches is scored.
    max_batches = None if save_name == 'test_result' else 100
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    height = configs.img_height if configs.img_height > 0 else configs.img_width
    real_input_flag = np.zeros(
        (configs.batch_size, output_length - 1,
         configs.patch_size**2 * configs.img_channel,
         configs.img_width // configs.patch_size,
         height // configs.patch_size))

    with torch.no_grad():
        while not test_input_handle.no_batch_left():
            # Stop *before* counting another batch so that batch_id equals the
            # number of batches actually scored (it used to reach 101 and
            # skew every average below).
            if max_batches is not None and batch_id >= max_batches:
                break
            batch_id += 1
            test_ims = test_input_handle.get_batch()
            test_ims = test_ims[:, :configs.total_length]
            if len(test_ims.shape) > 3:
                test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
            else:
                test_dat = test_ims

            # TODO: debug output needs changing here. The model is currently
            # modified to return only img_gen; check what it originally
            # returned. (Translated from Korean.)
            test_dat_tensor = torch.tensor(test_dat, device=configs.device,
                                           requires_grad=False)
            img_gen = model.forward(test_dat_tensor, real_input_flag,
                                    hidden_state, cell_state,
                                    hidden_state_diff, cell_state_diff,
                                    st_memory, conv_lstm_c, MIMB_oc_w,
                                    MIMB_ct_w, MIMN_oc_w, MIMN_ct_w)
            img_gen = img_gen.clone().detach().to('cpu').numpy()
            if len(img_gen.shape) > 3:
                img_gen = preprocess.reshape_patch_back(
                    img_gen, configs.patch_size)
            # The save loop below reads total_length - 1 frames from img_gen,
            # so it can hold more than output_length frames. Score only the
            # trailing prediction horizon so frame i lines up with ground
            # truth frame i + input_length.
            img_out = img_gen[:, -output_length:]

            # Metrics per predicted frame.
            for i in range(output_length):
                x = test_ims[:, i + configs.input_length, :, :, :]
                x = x[:configs.batch_size * configs.n_gpu]
                x = x - np.where(x > 10000,
                                 np.floor_divide(x, 10000) * 10000,
                                 np.zeros_like(x))
                gx = img_out[:, i, :, :, :]
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(configs.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    gx_trans = np.transpose(gx[b], (1, 2, 0))
                    x_trans = np.transpose(x[b], (1, 2, 0))
                    score = structural_similarity(gx_trans, x_trans,
                                                  multichannel=True)
                    ssim[i] += score

            # save prediction examples
            if batch_id <= configs.num_save_samples:
                path = os.path.join(res_path, str(save_name))
                if not os.path.isdir(path):
                    os.mkdir(path)
                for i in range(configs.total_length):
                    file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    if configs.img_channel == 2:
                        # Channels-first (C, H, W): keep only the first
                        # channel. The old [:, :, :1] sliced the width axis.
                        img_gt = img_gt[:1]
                    img_gt = np.transpose(img_gt, (1, 2, 0))
                    cv2.imwrite(file_name, img_gt)
                for i in range(configs.total_length - 1):
                    file_name = os.path.join(path, 'pd' + str(i) + '.png')
                    img_pd = img_gen[0, i, :, :, :]
                    if configs.img_channel == 2:
                        # Same channel-axis fix as for the ground truth.
                        img_pd = img_pd[:1]
                    img_pd = np.maximum(img_pd, 0)
                    img_pd = np.minimum(img_pd, 1)
                    img_pd = np.uint8(img_pd * 255)
                    img_pd = np.transpose(img_pd, (1, 2, 0))
                    cv2.imwrite(file_name, img_pd)
            test_input_handle.next()

    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          'test...' + str(save_name))
    num_samples = batch_id * configs.batch_size * configs.n_gpu
    avg_mse = avg_mse / num_samples
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / num_samples)
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id)
    print('psnr per frame: ' + str(np.mean(psnr)))
    print('ssim per frame: ' + str(np.mean(ssim)))
    print('fmae per frame: ' + str(np.mean(fmae)))
    print('sharpness per frame: ' + str(np.mean(sharp)))
def main(argv=None):
    """Train on speed/heading frames encoded as 2-D vectors, and evaluate.

    Side effects on FLAGS:
    * appends a run-specific suffix to ``FLAGS.save_dir`` and the dataset
      name to ``FLAGS.gen_frm_dir``;
    * sets ``FLAGS.pretrained_model`` to the save dir when that dir already
      exists;
    * overwrites ``FLAGS.img_height`` / ``FLAGS.img_width`` with the patched
      spatial size of every training sub-batch. The test section relies on
      these values having been set.

    Input encoding (as implemented below):
    * channel 1 is used as speed;
    * channel 2 is multiplied by 255 and bucketed with ``// 85`` into heading
      classes 1-4 (assuming channel 2 lies in [0, 1]), with class 0 wherever
      channel 2 is exactly 0;
    * each pixel becomes ``speed / sqrt(2) * heading_table[class]``.

    Every ``FLAGS.test_interval`` iterations the validation handle is
    evaluated for the single heading class ``FLAGS.heading``, using the time
    offsets in ``indicies``. Several masked MSE variants are printed. The
    model is saved every ``FLAGS.snapshot_interval`` iterations.
    """
    # Heading class -> expected raw value of ground-truth channel 1 * 255
    # (used to filter the ground truth to one heading in the test section).
    heading_dict = {1: 1, 2:85, 3: 170, 4: 255, 0:0}
    heading = FLAGS.heading
    FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash' + 'L1+L2+VALID' + 'multi-task'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        # tf.io.gfile.rmtree(FLAGS.save_dir)
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Existing save dir: resume from it.
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        # tf.io.gfile.rmtree(FLAGS.gen_frm_dir)
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length, FLAGS.seq_length - FLAGS.input_length)

    # The following indices are the start indices of the 3 images to predict
    # in the 288 time bins (0 to 287) of each daily test file. They depend on
    # the time zone: Berlin is in UTC+2, whereas Istanbul and Moscow are in
    # UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    # Row k is the unit-direction sign pair for heading class k (row 0 = no
    # heading).
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]], dtype=np.float32)

    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    # dims = train_input_handle.dims
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    # eta: probability that a decoder-step mask is all ones (see true_token).
    delta = 0.00002
    # NOTE(review): `base` is only referenced by the commented-out schedule.
    base = 0.99998
    eta = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()
        # # print("imss shape is ", imss.shape)
        # Heading class per pixel: 1-4 from channel 2, forced to 0 where
        # channel 2 is zero, then mapped to a sign pair via heading_table.
        tem_data = imss.copy()
        heading_image = imss[:, :, :, :, 2]*255
        heading_image = (heading_image // 85).astype(np.int8) + 1
        heading_image[tem_data[:, :, :, :, 2] == 0] = 0
        heading_image = heading_table[heading_image]

        # Speed (channel 1) projected on both axes of the heading direction.
        speed_on_axis = np.expand_dims(imss[:, :, :, :, 1] / np.sqrt(2), axis=-1)
        imss = speed_on_axis * heading_image

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width, FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        # The loaded batch is split into sub-batches of FLAGS.batch_size.
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi+FLAGS.batch_size]
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # NOTE(review): eta is decremented once per sub-batch (inside this
            # loop), not once per iteration, so it reaches 0 (and can go
            # negative) before iteration 50000 when there are several
            # sub-batches.
            if itr < 50000:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample(
                (batch_size, FLAGS.seq_length-FLAGS.input_length-1))
            true_token = (random_flip < eta)
            #true_token = (random_flip < pow(base,itr))
            ones = np.ones((FLAGS.img_height, FLAGS.img_width, int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            zeros = np.zeros((int(FLAGS.img_height), int(FLAGS.img_width), int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            # Per (sample, step) all-ones or all-zeros mask chosen by true_token.
            mask_true = []
            for i in range(batch_size):
                for j in range(FLAGS.seq_length-FLAGS.input_length-1):
                    if true_token[i,j]:
                        mask_true.append(ones)
                    else:
                        mask_true.append(zeros)
            mask_true = np.array(mask_true)
            mask_true = np.reshape(mask_true, (batch_size, FLAGS.seq_length-FLAGS.input_length-1, int(FLAGS.img_height), int(FLAGS.img_width), int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            cost, _ = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                # Also train on the time-reversed sub-batch; report the mean.
                ims_rev = ims[:,::-1]
                cost2, _ = model.train(ims_rev, lr, mask_true, batch_size)
                cost = (cost + cost2) / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr), flush=True)
            print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            # Moving-average speed threshold used to decide when to fall back
            # to the moving-average heading.
            epsilon = 0.2
            # Presumably one test sample per start index; verify against
            # get_test_batch.
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle = False)
            # res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            # os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            gt_list = []
            pred_list = []
            pred_list_all = []
            # NOTE(review): pred_vec is collected but never read below.
            pred_vec = []
            move_avg = []
            # NOTE(review): ssim, psnr and sharp are never updated; fmae is
            # accumulated but never printed.
            img_mse, ssim, psnr, fmae, sharp= [],[],[],[],[]
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                img_mse.append(0)
                ssim.append(0)
                psnr.append(0)
                fmae.append(0)
                sharp.append(0)
            # All-zero mask at test time. Relies on FLAGS.img_height/width
            # having been set by the training loop above.
            mask_true = np.zeros((batch_size, FLAGS.seq_length-FLAGS.input_length-1,
                                  FLAGS.img_height, FLAGS.img_width,
                                  FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel))
            while(test_input_handle.no_batch_left() == False):
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # get the ground truth (channels 1: onward of future frames)
                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])
                # cvt the heading to 0, 1, 2, 3, 4
                tem_data = test_ims.copy()
                heading_image = test_ims[:, :, :, :, 2] * 255
                heading_image = (heading_image // 85).astype(np.int8) + 1
                heading_image[tem_data[:, :, :, :, 2] == 0] = 0
                cvt_heading = heading_image.copy()

                # convert the data into speed vectors, keeping only pixels of
                # the selected heading class (others become class 0)
                heading_selected = np.zeros_like(heading_image, np.int8)
                heading_selected[heading_image == heading] = heading
                heading_image = heading_selected
                heading_image = heading_table[heading_image]
                speed_on_axis = np.expand_dims(test_ims[:, :, :, :, 1] / np.sqrt(2), axis=-1)
                test_ims = speed_on_axis * heading_image

                # Moving-average baseline from the observed frames, kept only
                # where the future ground-truth heading equals `heading`.
                mavg_results_all = cast_moving_avg(tem_data[:, :FLAGS.input_length, ...])
                mavg_results = np.zeros_like(mavg_results_all)
                # heading_image = np.expand_dims(heading_image, axis=-1)
                mavg_results[cvt_heading[:, FLAGS.input_length:, ...] == heading] = \
                    mavg_results_all[cvt_heading[:, FLAGS.input_length:, ...] == heading]
                move_avg.append(mavg_results)

                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size_width, FLAGS.patch_size_height)

                img_gen = model.test(test_dat, mask_true, batch_size)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                # reshape the prediction so it has ndims=5
                img_gen = np.reshape(img_gen, (img_gen.shape[0], FLAGS.seq_length - FLAGS.input_length,
                                               FLAGS.img_height, FLAGS.img_width, -1))
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                # print("Image Generates Shape is ", img_gen.shape)
                # MSE per frame
                img_gen_list = []
                img_gen_origin_list = []
                for i in range(FLAGS.seq_length - FLAGS.input_length):
                    x = tem_data[:,i + FLAGS.input_length,:,:, 1:]
                    gx = img_gen[:,i,:, :, :]
                    # print("img_gen shape is ", gx.shape)
                    # Back from a 2-D vector to (speed, heading): magnitude,
                    # and the heading value whose heading_table row matches
                    # the signs of the two components.
                    val_results_speed = np.sqrt(gx[..., 0] ** 2 + gx[..., 1] ** 2)
                    # print("val speed: ", val_results_speed, flush=True)
                    val_results_heading = np.zeros_like(gx[..., 1])
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] > 0)] = 85.0 / 255.0
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] < 0)] = 255.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] < 0)] = 170.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] > 0)] = 1.0 / 255.0

                    gen_speed_heading = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_origin_list.append(gen_speed_heading)

                    # Take the moving-average heading where the moving-average
                    # speed is below epsilon.
                    val_results_heading[mavg_results[:, i, :, :, 1] < epsilon] = \
                        mavg_results[:, i, :, :, 2][mavg_results[:, i, :, :, 1] < epsilon]

                    gx = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_list.append(gx)
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                img_gen_list = np.stack(img_gen_list, axis=1)
                img_gen_origin_list = np.stack(img_gen_origin_list, axis=1)
                pred_list_all.append(img_gen_origin_list)
                pred_list.append(img_gen_list)
                pred_vec.append(img_gen)
                test_input_handle.next()

            # Normalised by element count (batches x samples x pixels x
            # channels x frames), not per sequence.
            avg_mse = avg_mse / (batch_id*batch_size*FLAGS.img_height *
                                 FLAGS.img_width * FLAGS.patch_size_height *
                                 FLAGS.patch_size_width * FLAGS.img_channel * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                print(img_mse[i] / (batch_id*batch_size*FLAGS.img_height * FLAGS.img_width * FLAGS.patch_size_height * FLAGS.patch_size_width * FLAGS.img_channel))

            gt_list_all = np.stack(gt_list, axis=0)
            # GT filtered to the direction required
            gt_list = np.zeros_like(gt_list_all)
            gt_list[gt_list_all[..., 1]*255 == heading_dict[heading]] = \
                gt_list_all[gt_list_all[..., 1]*255 == heading_dict[heading]]

            pred_list = np.stack(pred_list, axis=0)
            pred_list_all = np.stack(pred_list_all, axis=0)

            print("Evaluate on every pixels....")
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=np.nan)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=np.nan)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for Transformation...")
            mse = masked_mse_np(pred_list, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for No Transformation...")
            mse = masked_mse_np(pred_list_all, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list_all[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list_all[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for MAVG...")
            # Evaluate on large gt speeds for direction
            move_avg = np.stack(move_avg, axis=0)
            mse = masked_mse_np(move_avg[..., 1:], gt_list, null_val=0.0)
            speed_mse = masked_mse_np(move_avg[..., 1], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            # NOTE(review): despite its name, this mask is computed from the
            # moving-average speed (move_avg[..., 1]), not the ground truth.
            large_gt_speed = move_avg[..., 1] >= epsilon
            # Overwrites move_avg in place (it is not used afterwards).
            move_avg[..., 2][large_gt_speed] = pred_list_all[large_gt_speed, 1]
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print(f"The direction of combined mavg and large speed~({epsilon}) prediction is ", direction_mse)

            direction_mse = masked_mse_np(pred_list_all[large_gt_speed, 1], gt_list[large_gt_speed, 1], null_val=0.0)
            print("The direction mse on large speed gt is ", direction_mse)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
def _define_board_items(logger):
    """Register every TensorBoard item that main() later writes.

    Scalars: loss, lr, mse, psnr, fmae, ssim, sharp. The "image" item holds a
    single ground-truth/prediction pair stacked vertically, hence a height of
    ``2 * FLAGS.img_width``.
    """
    for scalar_name in ("loss", "lr", "mse", "psnr", "fmae", "ssim", "sharp"):
        logger.define_item(scalar_name, Logger.Scalar, ())
    logger.define_item(
        "image", Logger.Image,
        (1, 2 * FLAGS.img_width, FLAGS.img_width, FLAGS.img_channel),
        dtype='uint8')


def _evaluate_last_frame(model, test_input_handle, logger, itr):
    """Run one full pass over the test set and report last-frame metrics.

    This model variant predicts a single frame per sequence, compared against
    ``test_ims[:, -1]``. MSE, PSNR, frame MAE, SSIM and Laplacian sharpness are
    accumulated for that frame. They are printed and written to ``logger`` at
    step ``itr``. One example ground-truth/prediction pair from the first test
    batch is also logged as an image.

    Args:
        model: object exposing ``test(batch)`` returning predictions in patch
            layout.
        test_input_handle: test data iterator (``begin``/``no_batch_left``/
            ``get_batch``/``next``).
        logger: Logger with the items registered by ``_define_board_items``.
        itr: current training iteration, used as the logging step and as the
            name of the per-test output directory.
    """
    print('test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
    os.mkdir(res_path)

    # Only the final frame is predicted, so each per-frame metric has exactly
    # one slot. Sizing these by seq_length - input_length (as before) left
    # zero-valued slots that dragged every reported mean down.
    num_pred = 1
    img_mse = [0] * num_pred
    ssim = [0] * num_pred
    psnr = [0] * num_pred
    fmae = [0] * num_pred
    sharp = [0] * num_pred
    avg_mse = 0
    batch_id = 0

    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
        img_gen = model.test(test_dat)
        # model.test output has no time axis; insert a singleton one so
        # reshape_patch_back can undo the patch layout.
        img_gen = preprocess.reshape_patch_back(
            img_gen[:, np.newaxis, :, :, :], FLAGS.patch_size)

        k = 0  # index of the single predicted (last) frame
        x = test_ims[:, -1, :, :, 0]
        # NOTE(review): this indexing assumes img_gen is (batch, H, W, C)
        # here; confirm against reshape_patch_back's output rank.
        gx = img_gen[:, :, :, 0]
        fmae[k] += metrics.batch_mae_frame_float(gx, x)
        gx = np.maximum(gx, 0)
        gx = np.minimum(gx, 1)
        mse = np.square(x - gx).sum()
        img_mse[k] += mse
        avg_mse += mse
        real_frm = np.uint8(x * 255)
        pred_frm = np.uint8(gx * 255)
        psnr[k] += metrics.batch_psnr(pred_frm, real_frm)
        for b in range(FLAGS.batch_size):
            # Positional 3 is cv2.Laplacian's ddepth (cv2.CV_16S).
            sharp[k] += np.max(
                cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
            score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
            ssim[k] += score

        # Log one randomly chosen sample from the first batch:
        # ground truth on top, prediction below.
        if batch_id == 1:
            sel = np.random.randint(FLAGS.batch_size)
            img_seq_pd = img_gen[sel]
            img_seq_gt = test_ims[sel, -1]
            h, w = img_gen.shape[1:3]
            out_img = np.zeros((1, h * 2, w, FLAGS.img_channel),
                               dtype='uint8')
            for row, img in enumerate([img_seq_gt, img_seq_pd]):
                # Amplify x10 for visibility and clamp to [0, 255] BEFORE the
                # uint8 cast. Casting first overflows values above 255 and
                # makes the clamp a no-op.
                img = np.uint8(np.clip(img * 10, 0, 255))
                out_img[0, (row * h):(row * h + h), :] = img
            logger.add("image", out_img, itr)

        test_input_handle.next()

    if batch_id == 0:
        # Nothing to average over; avoid ZeroDivisionError.
        print('test set produced no batches; skipping metrics')
        return

    n_batches = batch_id
    n_samples = batch_id * FLAGS.batch_size

    avg_mse = avg_mse / n_samples
    logger.add('mse', avg_mse, itr)
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / n_samples)

    # psnr/fmae gain one value per batch (presumably batch means from the
    # metrics helpers); ssim/sharp gain one value per sample (loop over b).
    psnr = np.asarray(psnr, dtype=np.float32) / n_batches
    fmae = np.asarray(fmae, dtype=np.float32) / n_batches
    ssim = np.asarray(ssim, dtype=np.float32) / n_samples
    sharp = np.asarray(sharp, dtype=np.float32) / n_samples

    for label, tag, values in (('psnr', 'psnr', psnr),
                               ('fmae', 'fmae', fmae),
                               ('ssim', 'ssim', ssim),
                               ('sharpness', 'sharp', sharp)):
        print(label + ' per frame: ' + str(np.mean(values)))
        logger.add(tag, np.mean(values), itr)
        for value in values:
            print(value)


def main(argv=None):
    """Train the model, periodically testing, logging and checkpointing.

    FLAGS.save_dir and FLAGS.gen_frm_dir are wiped and recreated at start-up.
    Each iteration trains on one batch (plus its time-reversed copy when
    FLAGS.reverse_input is set). Every FLAGS.test_interval iterations the test
    set is evaluated, and every FLAGS.snapshot_interval iterations the model
    is saved.

    Args:
        argv: unused.
    """
    # Start every run from empty output directories.
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width, FLAGS.seq_length)

    print("Initializing models")
    model = Model()
    lr = FLAGS.lr

    # Prepare tensorboard logging
    logger = Logger(os.path.join(FLAGS.gen_frm_dir, 'board'), model.sess)
    _define_board_items(logger)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        logger.add('lr', lr, itr)
        cost = model.train(ims, lr)
        if FLAGS.reverse_input:
            # Also train on the time-reversed sequence and average the two
            # losses. Uses the same two-argument train() call as above; the
            # former extra `mask_true` argument was undefined here.
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr)
            cost = cost / 2
        logger.add('loss', cost, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            _evaluate_last_frame(model, test_input_handle, logger, itr)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()