Ejemplo n.º 1
0
def main(argv=None):
    """Serve inference requests read line by line from stdin until EOF.

    Each input line is expected to be ``"<input_path>,<output_path>"``.
    The array ``Level3File(input_path).sym_block[0][0]['data']`` is loaded
    as float32 and resized to width ``FLAGS.img_width``, keeping its aspect
    ratio. It is then wrapped with batch, sequence and channel axes,
    patched with ``preprocess.reshape_patch`` and passed to
    ``model.inference``. The prediction is un-patched, resized back to the
    input's original ``(h, w)`` and written to ``output_path`` with
    ``imsave(..., metadata={'axis': 'YX'})``.

    For every request either ``done`` or ``failed: <error>`` is printed, so
    a driving process can read one status line per request.

    Args:
        argv: Unused.
    """
    model = Model()

    while True:
        try:
            line = input()
        except EOFError:
            # stdin was closed: no more requests, stop cleanly instead of
            # dying with a traceback.
            break
        try:
            # Tolerate stray whitespace / '\r' around the two paths.
            inf, outf = (part.strip() for part in line.split(','))
            img = np.array(Level3File(inf).sym_block[0][0]['data'],
                           dtype='float32')
            h, w = img.shape
            # Scale to the model's input width while keeping the aspect ratio.
            nw = FLAGS.img_width
            nh = h * nw // w
            # cv2.resize takes dsize as (width, height), not (rows, cols).
            img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
            img = img[np.newaxis, np.newaxis, :, :, np.newaxis]
            img = preprocess.reshape_patch(img, FLAGS.patch_size)

            pred = model.inference(img)
            pred = preprocess.reshape_patch_back(pred[:, np.newaxis, :],
                                                 FLAGS.patch_size)
            # Back to the original h rows x w columns; dsize is (width, height).
            pred = cv2.resize(pred[0, 0, :, :, 0], (w, h),
                              interpolation=cv2.INTER_CUBIC)

            imsave(outf, pred, metadata={'axis': 'YX'})
            print('done')

        except Exception as e:
            # Report and keep serving: one bad request must not stop the loop.
            print('failed:', e)
Ejemplo n.º 2
0
def main(argv=None):
    """Predict the target frames of every daily test file and write them out.

    Side effects on the global ``FLAGS``:

    * ``FLAGS.save_dir`` is extended with ``dataset_name``, ``seq_length``
      and ``num_hidden``.
    * ``FLAGS.best_model`` is set to
      ``<save_dir>/best_channels<img_channel>.ckpt``.
    * ``FLAGS.pretrained_model`` is set to ``FLAGS.save_dir``.

    These are presumably read by ``Model()`` to restore weights (not visible
    here).

    Every file in ``<valid_data_paths>/<dataset>/<dataset>_<mode>`` is an
    HDF5 file with an ``array`` dataset. For each of five time-zone
    dependent start indices, a window of ``FLAGS.seq_length`` frames (the
    first ``FLAGS.input_length`` of them are inputs) is cut out and the
    first ``FLAGS.img_channel`` channels are kept. The generated frames are
    clipped to [0, 1], scaled to uint8 and written to
    ``./Results/predrnn/t<test_time>_<mode>/<dataset>/<dataset>_test/<file>``.

    Squared errors against the ground truth are accumulated, but they are
    only reported by the commented-out evaluation section at the end.

    Args:
        argv: Unused.
    """

    # FLAGS.save_dir += FLAGS.dataset_name
    # FLAGS.gen_frm_dir += FLAGS.dataset_name
    # if tf.io.gfile.exists(FLAGS.save_dir):
    #     tf.io.gfile.rmtree(FLAGS.save_dir)
    # tf.io.gfile.makedirs(FLAGS.save_dir)
    # if tf.io.gfile.exists(FLAGS.gen_frm_dir):
    #     tf.io.gfile.rmtree(FLAGS.gen_frm_dir)
    # tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length) + FLAGS.num_hidden
    print(FLAGS.save_dir)
    # FLAGS.best_model = FLAGS.save_dir + '/best.ckpt'
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}.ckpt'
    # FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    # FLAGS.save_dir += FLAGS.dataset_name
    FLAGS.pretrained_model = FLAGS.save_dir

    # node_pos is only used by the commented-out node evaluation below; it is
    # still loaded here, so a missing file fails fast.
    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    node_pos = np.load(node_pos_file_2in1)

    test_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_' + FLAGS.mode)
    sub_files = preprocess.list_filenames(test_data_paths, [])

    output_path = f'./Results/predrnn/t{FLAGS.test_time}_{FLAGS.mode}/'
    # output_path = f'./Results/predrnn/t14/'
    preprocess.create_directory_structure(output_path)
    # The following indices are the start indices of the images to predict in
    # the 288 time bins (0 to 287) of each daily test file. They are time-zone
    # dependent: Berlin lies in UTC+2, whereas Istanbul and Moscow lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    print("Initializing models", flush=True)
    model = Model()

    step = 6
    se_total = 0.
    se_1 = 0.
    se_2 = 0.
    se_3 = 0.
    # NOTE(review): gt_list / pred_list keep every prediction of the whole
    # test set in memory, yet they are only consumed by the commented-out
    # evaluation below.
    gt_list = []
    pred_list = []
    mavg_list = []
    for f in sub_files:
        with h5py.File(os.path.join(test_data_paths, f), 'r') as h5_file:
            data = h5_file['array'][()]
            # Query the Moving Average Data (only consumed by the
            # commented-out moving-average baseline).
            prev_data = [data[y - step:y] for y in indicies]
            prev_data = np.stack(prev_data, axis=0)
            # type casting
            # prev_data = prev_data.astype(np.float32) / 255.0
            # mavg_pred = cast_moving_avg(prev_data)
            # mavg_list.append(mavg_pred)

            # get relevant training data pieces
            data = [
                data[y - FLAGS.input_length:y + FLAGS.seq_length -
                     FLAGS.input_length] for y in indicies
            ]
            data = np.stack(data, axis=0)
            # select the data channel as wished
            data = data[..., :FLAGS.img_channel]

            # all validation data is applied
            # data = np.reshape(data,(-1, FLAGS.seq_length,
            #                     FLAGS.img_height*FLAGS.patch_size_height, FLAGS.img_width*FLAGS.patch_size_width, 3))
            # type casting
            test_dat = data.astype(np.float32) / 255.0
            test_dat = preprocess.reshape_patch(test_dat,
                                                FLAGS.patch_size_width,
                                                FLAGS.patch_size_height)
            batch_size = data.shape[0]
            # All-zero mask: no ground truth is revealed for predicted frames.
            mask_true = np.zeros(
                (batch_size, FLAGS.seq_length - FLAGS.input_length - 1,
                 FLAGS.img_height, FLAGS.img_width, FLAGS.patch_size_height *
                 FLAGS.patch_size_width * FLAGS.img_channel))
            img_gen = model.test(test_dat, mask_true, batch_size)
            # concat outputs of different gpus along batch
            # img_gen = np.concatenate(img_gen)
            img_gen = img_gen[0]
            img_gen = np.maximum(img_gen, 0)
            img_gen = np.minimum(img_gen, 1)
            img_gen = preprocess.reshape_patch_back(img_gen,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)
            img_gt = data[:, FLAGS.input_length:, ...].astype(
                np.float32) / 255.0

            gt_list.append(img_gt)
            pred_list.append(img_gen)
            se_total += np.sum((img_gt - img_gen)**2)

            se_1 += np.sum((img_gt[..., 0] - img_gen[..., 0])**2)
            # Channel 1 only exists when FLAGS.img_channel > 1; indexing it
            # unconditionally raised IndexError for single-channel runs.
            if img_gt.shape[-1] > 1:
                se_2 += np.sum((img_gt[..., 1] - img_gen[..., 1])**2)
            # se_3 += np.sum((img_gt[..., 2] - img_gen[..., 2]) ** 2)

            img_gen = np.uint8(img_gen * 255)
            outfile = os.path.join(output_path, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_test', f)
            preprocess.write_data(img_gen, outfile)

    # mse = se_total / (len(indicies) * len(sub_files) * 495 * 436 * 3 * 3)
    #
    # mse1 = se_1 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # mse2 = se_2 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # # mse3 = se_3 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # print(FLAGS.dataset_name)
    # print("MSE: ", mse)
    # print("MSE_vol: ", mse1)
    # print("MSE_sp: ", mse2)
    # # print("MSE_hd: ", mse3)
    #
    # pred_list = np.stack(pred_list, axis=0)
    # gt_list = np.stack(gt_list, axis=0)
    # mavg_list = np.stack(mavg_list, axis=0)
    #
    # array_mse = masked_mse_np(mavg_list, gt_list, np.nan)
    # print(f'MAVG {step} MSE: ', array_mse)
    #
    # # adapt pred on non_zero mavg pred only
    # pred_list_copy = np.zeros_like(pred_list)
    # pred_list_copy[mavg_list > 0] = pred_list[mavg_list > 0]
    #
    # array_mse = masked_mse_np(pred_list_copy, gt_list, np.nan)
    # print(f'PRED+MAVG {step} MSE: ', array_mse)
    #
    # # Evaluate on nodes
    # # Check MSE on node_pos
    # img_gt_node = gt_list[:, :, :, node_pos[:, 0], node_pos[:, 1], :].astype(np.float32)
    # img_gen_node = pred_list[:, :, :, node_pos[:, 0], node_pos[:, 1], :].astype(np.float32)
    # mse_node_all = masked_mse_np(img_gen_node, img_gt_node, np.nan)
    # mse_node_volume = masked_mse_np(img_gen_node[..., 0], img_gt_node[..., 0], np.nan)
    # mse_node_speed = masked_mse_np(img_gen_node[..., 1], img_gt_node[..., 1], np.nan)
    # mse_node_direction = masked_mse_np(img_gen_node[..., 2], img_gt_node[..., 2], np.nan)
    # print("Results on Node Pos: ")
    # print("MSE: ", mse_node_all)
    # print("Volume mse: ", mse_node_volume)
    # print("Speed mse: ", mse_node_speed)
    # print("Direction mse: ", mse_node_direction)
    #
    # print("Evaluating on Condensed Graph....")
    # seq_length = np.shape(gt_list)[2]
    # img_height = np.shape(gt_list)[3]
    # img_width = np.shape(gt_list)[4]
    # num_channels = np.shape(gt_list)[5]
    # gt_list = np.reshape(gt_list, [-1, seq_length,
    #                             int(img_height / FLAGS.patch_size_height), FLAGS.patch_size_height,
    #                             int(img_width / FLAGS.patch_size_width), FLAGS.patch_size_width,
    #                             num_channels])
    # gt_list = np.transpose(gt_list, [0, 1, 2, 4, 3, 5, 6])
    #
    # pred_list = np.reshape(pred_list, [-1, seq_length,
    #                                int(img_height / FLAGS.patch_size_height), FLAGS.patch_size_height,
    #                                int(img_width / FLAGS.patch_size_width), FLAGS.patch_size_width,
    #                                num_channels])
    # pred_list = np.transpose(pred_list, [0, 1, 2, 4, 3, 5, 6])
    #
    # node_pos = preprocess.construct_road_network_from_grid_condense(FLAGS.patch_size_height, FLAGS.patch_size_width,
    #                                                                 test_data_paths)
    #
    # img_gt_node = gt_list[:, :, node_pos[:, 0], node_pos[:, 1], ...].astype(np.float32)
    # img_gen_node = pred_list[:, :, node_pos[:, 0], node_pos[:, 1], ...].astype(np.float32)
    # mse_node_all = masked_mse_np(img_gen_node, img_gt_node, np.nan)
    # mse_node_volume = masked_mse_np(img_gen_node[..., 0], img_gt_node[..., 0], np.nan)
    # mse_node_speed = masked_mse_np(img_gen_node[..., 1], img_gt_node[..., 1], np.nan)
    # mse_node_direction = masked_mse_np(img_gen_node[..., 2], img_gt_node[..., 2], np.nan)
    # print("MSE: ", mse_node_all)
    # print("Volume mse: ", mse_node_volume)
    # print("Speed mse: ", mse_node_speed)
    # print("Direction mse: ", mse_node_direction)

    print("Finished...")
def main(argv=None):
    """Train the model and periodically evaluate it on validation batches.

    ``FLAGS.save_dir``, ``FLAGS.gen_frm_dir`` and ``FLAGS.log_dir`` are
    wiped and recreated. Each iteration calls ``model.train`` with these
    arguments, in order:

    1. the patched training sequence;
    2. its time-reversed copy;
    3. the learning rate;
    4. a per-frame mask;
    5. the iteration number.

    Even frames are always revealed in the mask. Odd frames are revealed
    with probability ``eta``, which decreases by ``delta`` per iteration
    and is 0 from iteration 80000 on.

    Every ``FLAGS.test_interval`` iterations, 11 validation batches are run
    with every odd frame masked, and per-frame MSE / MAE / PSNR / SSIM /
    sharpness are printed. Ground truth, predictions and the model's watch
    outputs for the first 10 batches are written under
    ``FLAGS.gen_frm_dir/<itr>/<filename>``. A checkpoint is saved every
    ``FLAGS.snapshot_interval`` iterations.

    Args:
        argv: Unused.

    Raises:
        ValueError: if ``FLAGS.gen_num`` is neither 5 nor 11, the only
            layouts of the generated frames the evaluation code understands.
    """
    # Validate before wiping any directory: with any other value the
    # evaluation code previously failed with a NameError, and only after
    # FLAGS.test_interval training iterations.
    if FLAGS.gen_num not in (5, 11):
        raise ValueError(
            'FLAGS.gen_num must be 5 or 11, got {!r}'.format(FLAGS.gen_num))

    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir, FLAGS.log_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.0000125
    eta = 1

    # Shape of one patched frame: a (height / p, width / p) grid whose
    # channels hold the p * p pixels of each patch. Integer division is
    # required because numpy rejects float dimensions (`/` yields floats in
    # Python 3).
    patched_frame = (FLAGS.img_height // FLAGS.patch_size,
                     FLAGS.img_width // FLAGS.patch_size,
                     FLAGS.patch_size ** 2 * FLAGS.img_channel)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        print('train get_batch:')
        ims, filename = train_input_handle.get_batch(False)
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 80000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(
            (FLAGS.batch_size, FLAGS.seq_length))
        # Even frames (0, 2, 4, ...) are always revealed; odd frames only
        # when their random draw falls below eta.
        reveal = random_flip < eta
        reveal[:, ::2] = True
        mask_true = np.zeros((FLAGS.batch_size, FLAGS.seq_length)
                             + patched_frame)
        mask_true[reveal] = 1.0

        # model.train always takes the time-reversed sequence, so build it
        # unconditionally. It used to exist only when FLAGS.reverse_input
        # was set, and any other setting crashed with a NameError.
        ims_rev = ims[:, ::-1]
        cost = model.train(ims, ims_rev, lr, mask_true, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * FLAGS.seq_length
            ssim = [0] * FLAGS.seq_length
            psnr = [0] * FLAGS.seq_length
            fmae = [0] * FLAGS.seq_length
            sharp = [0] * FLAGS.seq_length

            # At test time every odd frame is hidden and must be generated.
            test_mask = np.ones((FLAGS.batch_size,
                                 FLAGS.seq_length,
                                 FLAGS.img_height,
                                 FLAGS.img_width,
                                 FLAGS.img_channel))
            test_mask[:, 1::2] = 0
            test_mask = preprocess.reshape_patch(test_mask, FLAGS.patch_size)

            # A fixed number of 11 batches is evaluated (batch_id 1..11).
            while batch_id <= 10:
                batch_id += 1
                print('test get_batch:')
                test_ims, filename = test_input_handle.get_batch(False)
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                test_ims_rev = test_dat[:, ::-1]

                img_gen, ims_watch, ims_rev_watch = model.test(
                    test_dat, test_ims_rev, test_mask, itr)

                # concat outputs of different gpus along batch
                img_gen = preprocess.reshape_patch_back(
                    np.concatenate(img_gen), FLAGS.patch_size)
                ims_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_watch), FLAGS.patch_size)
                ims_rev_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_rev_watch), FLAGS.patch_size)

                # Per-frame metrics. Even frames are the given ground truth
                # and odd frames come from the generator. With gen_num == 5,
                # img_gen only holds the odd frames, so frame i is at i // 2.
                for i in range(FLAGS.seq_length):
                    x = test_ims[:, i, :, :, 0]
                    if i % 2 == 1:
                        gen_idx = i // 2 if FLAGS.gen_num == 5 else i
                        gx = img_gen[:, gen_idx, :, :, 0]
                    else:
                        gx = test_ims[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b],
                                                full=True)
                        ssim[i] += score

                # save prediction examples of the first 10 batches
                if batch_id <= 10:
                    path = os.path.join(res_path, str(filename))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(os.path.join(path, name), img_gt)

                    for i in range(FLAGS.seq_length):
                        # gen_num == 11: every frame is taken from img_gen.
                        # gen_num == 5: odd frames from img_gen (at i // 2),
                        # even frames are the ground truth.
                        if FLAGS.gen_num == 11:
                            img_pd = img_gen[0, i, :, :, :]
                        elif i % 2 == 1:
                            img_pd = img_gen[0, i // 2, :, :, :]
                        else:
                            img_pd = test_ims[0, i, :, :, :]
                        img_pd = np.uint8(np.clip(img_pd, 0, 1) * 255)
                        name = 'pd' + str(i + 1) + '.png'
                        cv2.imwrite(os.path.join(path, name), img_pd)

                        name = 'zwgt' + str(i + 1) + '.png'
                        img_zwgt = np.uint8(ims_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(os.path.join(path, name), img_zwgt)

                        name = 'zwgtrev' + str(i + 1) + '.png'
                        zwgtrev = np.uint8(ims_rev_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(os.path.join(path, name), zwgtrev)

                test_input_handle.next()

            n_samples = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / n_samples
            print('mse per seq: ' + str(avg_mse))
            for value in img_mse:
                print(value / n_samples)
            # psnr / fmae are accumulated once per batch; ssim / sharpness
            # once per sample.
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / n_samples
            sharp = np.asarray(sharp, dtype=np.float32) / n_samples
            for label, values in (('psnr', psnr), ('fmae', fmae),
                                  ('ssim', ssim), ('sharpness', sharp)):
                print(label + ' per frame: ' + str(np.mean(values)))
                for value in values:
                    print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Ejemplo n.º 4
0
def main(argv=None):
    """Train the model with scheduled sampling and periodically evaluate it.

    ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir`` are wiped and recreated.
    Each training step builds a mask over the
    ``seq_length - input_length - 1`` frames fed back during prediction.
    A frame's ground truth is revealed with probability ``eta``, which
    decreases by ``delta`` per iteration and is 0 from iteration 50000 on.
    When ``FLAGS.reverse_input`` is set, the model is also trained on the
    time-reversed sequence and the two costs are averaged.

    Every ``FLAGS.test_interval`` iterations the whole validation set is run
    with nothing revealed, and per-frame MSE / MAE / PSNR / SSIM / sharpness
    are printed. Ground truth and prediction frames of the first 10 batches
    are written under ``FLAGS.gen_frm_dir/<itr>/<batch_id>``. A checkpoint
    is saved every ``FLAGS.snapshot_interval`` iterations.

    Args:
        argv: Unused.
    """
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    # Both spatial dimensions use img_width, so frames are assumed square.
    # Each frame is fed as a grid of patch_size x patch_size patches stacked
    # along the channels. Integer division is required because numpy
    # rejects float dimensions (`/` yields floats in Python 3).
    grid = FLAGS.img_width // FLAGS.patch_size
    patch_channels = FLAGS.patch_size ** 2 * FLAGS.img_channel
    output_length = FLAGS.seq_length - FLAGS.input_length
    mask_shape = (FLAGS.batch_size, output_length - 1, grid, grid,
                  patch_channels)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(mask_shape[:2])
        true_token = random_flip < eta
        # Whole frames are either revealed (1) or hidden (0).
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0

        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * output_length
            ssim = [0] * output_length
            psnr = [0] * output_length
            fmae = [0] * output_length
            sharp = [0] * output_length
            # No ground truth is revealed for fed-back frames at test time.
            test_mask = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, test_mask)

                # concat outputs of different gpus along batch
                img_gen = preprocess.reshape_patch_back(
                    np.concatenate(img_gen), FLAGS.patch_size)
                # MSE per frame
                for i in range(output_length):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b],
                                                real_frm[b],
                                                full=True)
                        ssim[i] += score

                # save prediction examples of the first 10 batches
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(os.path.join(path, name), img_gt)
                    for i in range(output_length):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        img_pd = np.uint8(
                            np.clip(img_gen[0, i, :, :, :], 0, 1) * 255)
                        cv2.imwrite(os.path.join(path, name), img_pd)
                test_input_handle.next()

            if batch_id == 0:
                # Empty validation set: the averages below would divide by 0.
                print('no test batches; skipping metrics')
            else:
                n_samples = batch_id * FLAGS.batch_size
                avg_mse = avg_mse / n_samples
                print('mse per seq: ' + str(avg_mse))
                for value in img_mse:
                    print(value / n_samples)
                # psnr / fmae are accumulated once per batch; ssim /
                # sharpness once per sample.
                psnr = np.asarray(psnr, dtype=np.float32) / batch_id
                fmae = np.asarray(fmae, dtype=np.float32) / batch_id
                ssim = np.asarray(ssim, dtype=np.float32) / n_samples
                sharp = np.asarray(sharp, dtype=np.float32) / n_samples
                for label, values in (('psnr', psnr), ('fmae', fmae),
                                      ('ssim', ssim), ('sharpness', sharp)):
                    print(label + ' per frame: ' + str(np.mean(values)))
                    for value in values:
                        print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Ejemplo n.º 5
0
def main(argv=None):
    """Train the model with scheduled sampling and keep the best checkpoint.

    Derives the checkpoint / output directories from ``FLAGS``, loads the
    training and validation data, then runs ``FLAGS.max_iterations``
    iterations. Each batch returned by the training handle is split into
    minibatches of at most ``FLAGS.batch_size`` samples. Every
    ``FLAGS.test_interval`` iterations the model is evaluated on the
    validation split (per-frame MSE plus a masked MSE overall and per
    channel); whenever the overall masked MSE improves,
    ``model.save_to_best_mode()`` is called. A snapshot is saved every
    ``FLAGS.snapshot_interval`` iterations.

    Args:
        argv: Unused.
    """
    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length) + FLAGS.num_hidden
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Reuse an existing checkpoint directory; presumably Model() restores
        # weights from FLAGS.pretrained_model -- confirm in Model.
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    # NOTE(review): node_pos is loaded but not used below; the load still
    # fails early if the file is missing.
    node_pos = np.load(node_pos_file_2in1)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length,
        FLAGS.seq_length - FLAGS.input_length)

    # Start indices (within the 288 daily time bins, 0..287) of each window
    # of images to predict in a daily test file. They depend on the time
    # zone: Berlin is UTC+2, Istanbul and Moscow are UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    num_pred = FLAGS.seq_length - FLAGS.input_length  # frames to predict
    # Depth of one patched frame (patch area times channels).
    patch_depth = int(FLAGS.patch_size_height * FLAGS.patch_size_width *
                      FLAGS.img_channel)

    # Scheduled sampling: a mask entry of 1 (drawn with probability eta)
    # selects the ground-truth frame for that step during training.
    delta = 0.2
    eta = 1
    # Start at +inf so the first evaluation always records a best model; a
    # fixed threshold (previously 1.0) never saves if the MSE stays above it.
    min_val_loss = np.inf

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()
        imss = imss[..., :FLAGS.img_channel]  # keep only the selected channels

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width,
                                        FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi + FLAGS.batch_size]
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # eta decays per minibatch during the first iterations and is
            # forced to 0 afterwards.
            if itr < 10:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample((batch_size, num_pred - 1))
            true_token = (random_flip < eta)
            # Expand each (sample, step) flag into an all-ones or all-zeros
            # frame of the patched shape (float64, like np.ones / np.zeros).
            mask_true = np.broadcast_to(
                true_token[:, :, np.newaxis, np.newaxis, np.newaxis],
                (batch_size, num_pred - 1, int(FLAGS.img_height),
                 int(FLAGS.img_width), patch_depth)).astype(np.float64)
            cost = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                ims_rev = ims[:, ::-1]
                cost += model.train(ims_rev, lr, mask_true, batch_size)
                cost = cost / 2

            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle=False)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * num_pred
            # NOTE(review): psnr and fmae are accumulated but never reported.
            psnr = [0] * num_pred
            fmae = [0] * num_pred
            # All-zero mask: presumably the model never sees ground truth for
            # the predicted steps at test time.
            mask_true = np.zeros(
                (batch_size, num_pred - 1, FLAGS.img_height, FLAGS.img_width,
                 patch_depth))

            gt_list = []
            pred_list = []

            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # get the selected channels
                test_ims = test_ims[..., :FLAGS.img_channel]

                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, :])
                test_dat = preprocess.reshape_patch(test_ims,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)

                img_gen = model.test(test_dat, mask_true, batch_size)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                pred_list.append(img_gen)
                # Per-frame metrics on predictions clipped to [0, 1].
                for i in range(num_pred):
                    x = test_ims[:, i + FLAGS.input_length, :, :, :]
                    gx = img_gen[:, i, :, :, :]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)

                test_input_handle.next()

            # Number of scalar values in one predicted frame summed over all
            # evaluated samples (patch dims multiply back to full resolution).
            frame_elems = (batch_id * batch_size * FLAGS.img_height *
                           FLAGS.img_width * FLAGS.patch_size_height *
                           FLAGS.patch_size_width * FLAGS.img_channel)
            avg_mse = avg_mse / (frame_elems * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for frame_mse in img_mse:
                print(frame_mse / frame_elems, flush=True)

            gt_list = np.stack(gt_list, axis=0)
            pred_list = np.stack(pred_list, axis=0)
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            print("The output mse is ", mse, flush=True)
            volume_mse = masked_mse_np(pred_list[..., 0],
                                       gt_list[..., 0],
                                       null_val=np.nan)
            print("The volume mse is ", volume_mse, flush=True)
            # Channel 1 only exists when at least two channels are kept
            # (see the [..., :FLAGS.img_channel] slicing above).
            if FLAGS.img_channel >= 2:
                speed_mse = masked_mse_np(pred_list[..., 1],
                                          gt_list[..., 1],
                                          null_val=np.nan)
                print("The speed mse is ", speed_mse, flush=True)
            if FLAGS.img_channel == 3:
                direction_mse = masked_mse_np(pred_list[..., 2],
                                              gt_list[..., 2],
                                              null_val=np.nan)
                print("The direction mse is ", direction_mse, flush=True)

            if mse < min_val_loss:
                min_val_loss = mse
                print("Current Min Val Loss is ", min_val_loss)
                model.save_to_best_mode()

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
Ejemplo n.º 6
0
def main(argv=None):
    """Train on down-sampled speed data and evaluate periodically.

    Recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir`` (any existing
    contents are deleted), loads the down-sampled train/valid ``.npz`` files
    and runs ``FLAGS.max_iterations`` training iterations with scheduled
    sampling. Every ``FLAGS.test_interval`` iterations it prints per-frame
    MSE, PSNR, frame MAE and sharpness on the validation handle, writes up to
    10 example sequences as PNGs under ``FLAGS.gen_frm_dir/<itr>``, and
    prints a masked MSE against the raw validation files resized back to
    their original resolution. A snapshot is saved every
    ``FLAGS.snapshot_interval`` iterations.

    Args:
        argv: Unused.
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(
        FLAGS.train_data_paths, FLAGS.dataset_name,
        'train_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    valid_data_paths = os.path.join(
        FLAGS.valid_data_paths, FLAGS.dataset_name,
        'valid_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_size, True, FLAGS.down_sample, FLAGS.input_length,
        FLAGS.seq_length - FLAGS.input_length)

    # Start indices (within the 288 daily time bins, 0..287) of each window
    # of images to predict in a daily test file. They depend on the time
    # zone: Berlin is UTC+2, Istanbul and Moscow are UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    dims = train_input_handle.dims
    FLAGS.img_height = dims[-2]
    FLAGS.img_width = dims[-1]
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    num_pred = FLAGS.seq_length - FLAGS.input_length  # frames to predict
    # NOTE(review): the masks use the full img_height/img_width while inputs
    # are patched with FLAGS.patch_size; these only agree when patch_size == 1
    # -- confirm.
    patch_depth = int(FLAGS.patch_size**2 * FLAGS.img_channel)

    # Scheduled sampling: eta (probability that a mask entry is 1) decays
    # by delta per iteration until iteration 50000, then is 0.
    delta = 0.00002
    eta = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(
            (FLAGS.batch_size, num_pred - 1))
        true_token = (random_flip < eta)
        # Expand each (sample, step) flag into an all-ones or all-zeros frame
        # (float64, like np.ones / np.zeros).
        mask_true = np.broadcast_to(
            true_token[:, :, np.newaxis, np.newaxis, np.newaxis],
            (FLAGS.batch_size, num_pred - 1, int(FLAGS.img_height),
             int(FLAGS.img_width), patch_depth)).astype(np.float64)
        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr), flush=True)
            print('training loss: ' + str(cost), flush=True)

        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * num_pred
            psnr = [0] * num_pred
            fmae = [0] * num_pred
            sharp = [0] * num_pred
            mask_true = np.zeros(
                (FLAGS.batch_size, num_pred - 1,
                 FLAGS.img_height, FLAGS.img_width, patch_depth))
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size)
                # Per-frame metrics on channel 0, predictions clipped to [0, 1].
                for i in range(num_pred):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(num_pred):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = img_gen[0, i, :, :, :]
                        img_pd = np.maximum(img_pd, 0)
                        img_pd = np.minimum(img_pd, 1)
                        img_pd = np.uint8(img_pd * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()
            avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(num_pred):
                print(img_mse[i] / (batch_id * FLAGS.batch_size))
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            sharp = np.asarray(
                sharp, dtype=np.float32) / (FLAGS.batch_size * batch_id)
            print('psnr per frame: ' + str(np.mean(psnr)), flush=True)
            for i in range(num_pred):
                print(psnr[i], flush=True)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for i in range(num_pred):
                print(fmae[i], flush=True)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for i in range(num_pred):
                print(sharp[i], flush=True)

            # Evaluate on the raw validation files at their original size.
            # NOTE(review): this reads from FLAGS.train_data_paths rather than
            # FLAGS.valid_data_paths -- confirm that is intended.
            valid_data_path = os.path.join(
                FLAGS.train_data_paths, FLAGS.dataset_name,
                '{}_validation'.format(FLAGS.dataset_name))
            files = list_filenames(valid_data_path)
            output_all = []
            labels_all = []
            for f in files:
                valid_file = os.path.join(valid_data_path, f)
                valid_input, raw_output = datasets_factory.test_validation_provider(
                    valid_file,
                    indicies,
                    down_sample=FLAGS.down_sample,
                    seq_len=FLAGS.input_length,
                    horizon=num_pred)
                # np.float was removed in NumPy 1.24; np.float64 is the same
                # type it aliased.
                valid_input = valid_input.astype(np.float64) / 255.0
                labels_all.append(raw_output)
                num_tests = len(indicies)
                num_partitions = int(np.ceil(num_tests / FLAGS.batch_size))
                for i in range(num_partitions):
                    valid_input_i = valid_input[i * FLAGS.batch_size:(i + 1) *
                                                FLAGS.batch_size]
                    num_input_i = valid_input_i.shape[0]
                    if num_input_i < FLAGS.batch_size:
                        # Zero-pad a short final batch up to FLAGS.batch_size;
                        # the padded rows are dropped from the output below.
                        zeros_fill_in = np.zeros(
                            (FLAGS.batch_size - num_input_i, FLAGS.seq_length,
                             FLAGS.img_height, FLAGS.img_width,
                             FLAGS.img_channel))
                        valid_input_i = np.concatenate(
                            [valid_input_i, zeros_fill_in], axis=0)
                    # NOTE(review): unlike the loop above, valid_input_i is not
                    # passed through preprocess.reshape_patch; equivalent only
                    # when FLAGS.patch_size == 1 -- confirm.
                    img_gen = model.test(valid_input_i, mask_true)
                    output_all.append(img_gen[0][:num_input_i])

            output_all = np.concatenate(output_all, axis=0)
            labels_all = np.concatenate(labels_all, axis=0)
            # labels_all appears to be (..., height, width, channels) (channel
            # 1 is selected below); cv2.resize takes dsize as (width, height).
            origin_width = labels_all.shape[-2]
            origin_height = labels_all.shape[-3]
            output_resize = []
            for i in range(output_all.shape[0]):
                output_i = []
                for j in range(output_all.shape[1]):
                    # NOTE(review): selects index 1 on axis 2 of the model
                    # output -- confirm this matches the model's layout.
                    tmp_data = output_all[i, j, 1, :, :]
                    tmp_data = cv2.resize(tmp_data,
                                          (origin_width, origin_height))
                    output_i.append(np.expand_dims(tmp_data, axis=0))
                output_resize.append(np.stack(output_i, axis=0))
            output_resize = np.stack(output_resize, axis=0)

            output_resize *= 255.0
            labels_all = np.expand_dims(labels_all[..., 1], axis=2)
            valid_mse = masked_mse_np(output_resize, labels_all, np.nan)

            print("validation mse is ", valid_mse, flush=True)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Ejemplo n.º 7
0
def test(model, test_input_handle, configs, itr=None):
    """Evaluate ``model`` on ``test_input_handle`` and save example frames.

    The last ``configs.total_length - configs.input_length`` frames the model
    produces for each batch are compared with the ground truth. Per-frame
    MSE, SSIM, PSNR, frame MAE and sharpness are printed, and up to
    ``configs.num_save_samples`` batches are written as PNGs under
    ``configs.gen_frm_dir/<itr>``.

    Args:
        model: Object exposing ``test(patched_frames, real_input_flag)``.
        test_input_handle: Iterable of tensors with ``.numpy()`` whose arrays
            are (batch, seq, channels, height, width).
        configs: Run configuration (lengths, image/patch sizes, output dir).
        itr: Used as the name of the output sub-directory.

    Returns:
        ``(avg_mse, ssim, psnr, fmae, sharp)``: mean per-sequence MSE and
        per-frame float32 NumPy arrays of the other metrics.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)
    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    num_samples = 0  # sequences actually evaluated (last batch may be short)
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    # All-zero flag in the channel-last patched layout, e.g.
    # (4, 20 - 10 - 1, 140 // 4, 140 // 4, 2**2 * 1).
    real_input_flag = np.zeros(
        (configs.batch_size, output_length - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size**2 * configs.img_channel))

    for test_input in test_input_handle:
        # (batch, seq, channels, height, width) -> channel-last.
        test_ims = np.transpose(test_input.numpy(), (0, 1, 3, 4, 2))
        batch_id += 1
        cur_batch = test_ims.shape[0]
        num_samples += cur_batch

        # Patch the channel-last NumPy array (previously the raw channel-first
        # tensor was passed) so its layout matches real_input_flag and the
        # ground truth compared below.
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)

        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_out = img_gen[:, -output_length:]
        # Per-frame metrics on predictions clipped to [0, 1].
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = img_out[:, i, :, :, :]
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.maximum(gx, 0)
            gx = np.minimum(gx, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(cur_batch):
                score, _ = compare_ssim(pred_frm[b],
                                        real_frm[b],
                                        full=True,
                                        multichannel=True)
                ssim[i] += score
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(output_length):
                name = 'pd' + str(i + 1 + configs.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = img_out[0, i, :, :, :]
                img_pd = np.maximum(img_pd, 0)
                img_pd = np.minimum(img_pd, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)

    avg_mse = avg_mse / num_samples
    print('mse per seq: ' + str(avg_mse))
    for i in range(output_length):
        print(img_mse[i] / num_samples)

    ssim = np.asarray(ssim, dtype=np.float32) / num_samples
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp, dtype=np.float32) / num_samples
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(output_length):
        print(ssim[i])
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(output_length):
        print(psnr[i])
    print('fmae per frame: ' + str(np.mean(fmae)))
    for i in range(output_length):
        print(fmae[i])
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for i in range(output_length):
        print(sharp[i])

    return avg_mse, ssim, psnr, fmae, sharp
Ejemplo n.º 8
0
def main(argv=None):
    """Build the model and run one evaluation pass over the test split.

    Prints per-frame MSE, PSNR, frame MAE, SSIM and sharpness for the
    predicted frames (channel 0, clipped to [0, 1]) and writes the first 10
    batches as PNGs under ``FLAGS.gen_frm_dir/test``.

    Args:
        argv: Unused.
    """
    tf.disable_eager_execution()  # added; otherwise an error is raised

    # load data
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.test_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print("Initializing models")
    model = Model()

    print('test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(FLAGS.gen_frm_dir, 'test')
    os.mkdir(res_path)
    num_pred = FLAGS.seq_length - FLAGS.input_length  # frames to predict
    avg_mse = 0
    batch_id = 0
    img_mse = [0] * num_pred
    ssim = [0] * num_pred
    psnr = [0] * num_pred
    fmae = [0] * num_pred
    sharp = [0] * num_pred
    mask_true = np.zeros(
        (FLAGS.batch_size, num_pred - 1,
         int(FLAGS.img_height / FLAGS.patch_size),
         int(FLAGS.img_width / FLAGS.patch_size),
         FLAGS.patch_size**2 * FLAGS.img_channel))
    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
        img_gen = model.test(test_dat, mask_true)

        # concat outputs of different gpus along batch
        img_gen = np.concatenate(img_gen)
        img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)
        # Per-frame metrics on channel 0 (range, not Python 2's xrange).
        for i in range(num_pred):
            x = test_ims[:, i + FLAGS.input_length, :, :, 0]
            gx = img_gen[:, i, :, :, 0]
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.maximum(gx, 0)
            gx = np.minimum(gx, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(FLAGS.batch_size):
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= 10:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(FLAGS.seq_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(num_pred):
                name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = img_gen[0, i, :, :, :]
                img_pd = np.maximum(img_pd, 0)
                img_pd = np.minimum(img_pd, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()
    avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
    print('mse per seq: ' + str(avg_mse))
    for i in range(num_pred):
        print(img_mse[i] / (batch_id * FLAGS.batch_size))
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / (FLAGS.batch_size * batch_id)
    sharp = np.asarray(sharp, dtype=np.float32) / (FLAGS.batch_size * batch_id)
    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(num_pred):
        print(psnr[i])
    print('fmae per frame: ' + str(np.mean(fmae)))
    for i in range(num_pred):
        print(fmae[i])
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(num_pred):
        print(ssim[i])
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for i in range(num_pred):
        print(sharp[i])
def main(argv=None):
    """Evaluate a model that predicts the odd-indexed frames of a sequence.

    Recreates ``FLAGS.gen_frm_dir`` and ``FLAGS.log_dir`` (existing contents
    are deleted), runs the model over ``test_input_handle.total()`` batches
    and prints per-frame MSE, PSNR, frame MAE, SSIM and sharpness. For odd
    frame indices the metric compares ground truth with the model output;
    for even indices it compares ground truth with itself. ``FLAGS.gen_num``
    selects how model outputs are indexed: 5 means one output per odd frame
    (output ``i // 2`` for frame ``i``), and 11 means one output per frame
    (output ``i``).

    Args:
        argv: Unused.

    Raises:
        ValueError: If ``FLAGS.gen_num`` is neither 5 nor 11. This is checked
            before any directory is deleted.
    """
    if FLAGS.gen_num not in (5, 11):
        raise ValueError(
            'FLAGS.gen_num must be 5 or 11, got {}'.format(FLAGS.gen_num))

    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    # load data
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size,
        [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()

    test_input_handle.begin(do_shuffle=False)
    totalDataLen = int(test_input_handle.total())
    print('totalDataLen=', totalDataLen)
    for itr in range(1):
        print('inference...')
        res_path = os.path.join(FLAGS.gen_frm_dir, 'images' + str(itr))
        os.mkdir(res_path)
        avg_mse = 0
        batch_id = 0
        img_mse = [0] * FLAGS.seq_length
        ssim = [0] * FLAGS.seq_length
        psnr = [0] * FLAGS.seq_length
        fmae = [0] * FLAGS.seq_length
        sharp = [0] * FLAGS.seq_length

        # Even time steps (0, 2, 4, ...) stay 1; odd steps (1, 3, 5, ...) are
        # zeroed -- presumably the frames the model has to fill in.
        mask_true = np.ones(
            (FLAGS.batch_size, FLAGS.seq_length, FLAGS.img_height,
             FLAGS.img_width, FLAGS.img_channel))
        mask_true[:, 1::2] = 0
        mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

        while batch_id < totalDataLen:
            batch_id += 1
            print('test get_batch:')
            test_ims, fileName = test_input_handle.get_batch(False)
            test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)

            # model.test always takes a time-reversed copy. It used to be
            # defined only when FLAGS.reverse_input was set, which made the
            # call below raise NameError otherwise.
            test_ims_rev = test_dat[:, ::-1]

            # The two extra "watch" outputs are not used here.
            img_gen, _, _ = model.test(test_dat, test_ims_rev, mask_true, itr)

            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)
            # Per-frame metrics on channel 0, predictions clipped to [0, 1].
            for i in range(FLAGS.seq_length):
                x = test_ims[:, i, :, :, 0]
                if i % 2 == 1:
                    gen_idx = i // 2 if FLAGS.gen_num == 5 else i
                    gx = img_gen[:, gen_idx, :, :, 0]
                else:
                    gx = test_ims[:, i, :, :, 0]

                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[i] += score

            # save prediction examples
            # NOTE(review): `<` skips saving the final batch; confirm whether
            # `<=` was intended.
            if batch_id < totalDataLen:
                path = os.path.join(res_path, str(fileName))
                os.mkdir(path)
                for i in range(FLAGS.seq_length):
                    name = 'gt' + str(i + 1) + '.png'
                    file_name = os.path.join(path, name)
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    cv2.imwrite(file_name, img_gt)

                for i in range(FLAGS.seq_length):
                    name = 'pd' + str(i + 1) + '.png'
                    file_name = os.path.join(path, name)
                    if FLAGS.gen_num == 5:
                        # Only odd frames are generated; even frames are
                        # copied from the ground truth.
                        if i % 2 == 1:
                            img_pd = img_gen[0, i // 2, :, :, :]
                        else:
                            img_pd = test_ims[0, i, :, :, :]
                    else:  # gen_num == 11: every frame has a model output
                        img_pd = img_gen[0, i, :, :, :]

                    img_pd = np.maximum(img_pd, 0)
                    img_pd = np.minimum(img_pd, 1)
                    img_pd = np.uint8(img_pd * 255)
                    cv2.imwrite(file_name, img_pd)

            test_input_handle.next()
        avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
        print('mse per seq: ' + str(avg_mse))
        for i in range(FLAGS.seq_length):
            print(img_mse[i] / (batch_id * FLAGS.batch_size))
        psnr = np.asarray(psnr, dtype=np.float32) / batch_id
        fmae = np.asarray(fmae, dtype=np.float32) / batch_id
        ssim = np.asarray(ssim,
                          dtype=np.float32) / (FLAGS.batch_size * batch_id)
        sharp = np.asarray(sharp,
                           dtype=np.float32) / (FLAGS.batch_size * batch_id)
        print('psnr per frame: ' + str(np.mean(psnr)))
        for i in range(FLAGS.seq_length):
            print(psnr[i])
        print('fmae per frame: ' + str(np.mean(fmae)))
        for i in range(FLAGS.seq_length):
            print(fmae[i])
        print('ssim per frame: ' + str(np.mean(ssim)))
        for i in range(FLAGS.seq_length):
            print(ssim[i])
        print('sharpness per frame: ' + str(np.mean(sharp)))
        for i in range(FLAGS.seq_length):
            print(sharp[i])
Ejemplo n.º 10
0
def test(model, test_input_handle, configs, save_name, hidden_state,
         cell_state, hidden_state_diff, cell_state_diff, st_memory,
         conv_lstm_c, MIMB_oc_w, MIMB_ct_w, MIMN_oc_w, MIMN_ct_w):
    """Evaluate ``model`` on the test iterator, print metrics, save samples.

    Each batch (patch-reshaped when it has more than 3 dims) is fed through
    ``model.forward`` together with the recurrent-state arguments.  For every
    predicted frame the following are accumulated and printed at the end:
    MSE, frame MAE, PSNR, sharpness (max of the absolute Laplacian response)
    and SSIM.  Ground-truth (``gt*.png``) and predicted (``pd*.png``) frames
    of the first ``configs.num_save_samples`` batches are written under
    ``configs.gen_frm_dir/<save_name>``.

    Unless ``save_name == 'test_result'`` at most 100 batches are evaluated.

    Args:
        model: object whose ``forward`` accepts the input tensor, the
            all-zero ``real_input_flag`` and the state arguments below.
        test_input_handle: iterator exposing ``begin``, ``no_batch_left``,
            ``get_batch`` and ``next``.
        configs: run configuration (sequence lengths, image/patch sizes,
            batch size, ``n_gpu``, ``device``, output paths, ...).
        save_name: sub-directory for saved frames; also selects whether the
            100-batch cap applies.
        hidden_state, cell_state, hidden_state_diff, cell_state_diff,
        st_memory, conv_lstm_c, MIMB_oc_w, MIMB_ct_w, MIMN_oc_w, MIMN_ct_w:
            recurrent states passed through unchanged to ``model.forward``.
    """
    test_input_handle.begin(do_shuffle=False)
    res_path = configs.gen_frm_dir
    os.makedirs(res_path, exist_ok=True)

    output_length = configs.total_length - configs.input_length
    avg_mse = 0
    batch_id = 0
    # Per-predicted-frame accumulators.
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    height = configs.img_height if configs.img_height > 0 else configs.img_width

    # All-zero flag handed to model.forward; presumably means "never feed
    # ground truth back in" (pure autoregressive test) -- confirm in model.
    real_input_flag = np.zeros(
        (configs.batch_size, output_length - 1,
         configs.patch_size**2 * configs.img_channel,
         configs.img_width // configs.patch_size,
         height // configs.patch_size))

    with torch.no_grad():
        while not test_input_handle.no_batch_left():
            # Outside the final 'test_result' run only 100 batches are used.
            # The cap is checked *before* incrementing so that batch_id equals
            # the number of batches actually processed: the averages below
            # divide by it (the old check left it one too high on early exit).
            if save_name != 'test_result' and batch_id >= 100:
                break
            batch_id += 1

            test_ims = test_input_handle.get_batch()
            test_ims = test_ims[:, :configs.total_length]

            if len(test_ims.shape) > 3:
                test_dat = preprocess.reshape_patch(test_ims,
                                                    configs.patch_size)
            else:
                test_dat = test_ims

            # TODO (translated from Korean): the debug output has to be changed
            # here -- the model is currently modified so that only img_gen is
            # returned; check what it originally returned.
            test_dat_tensor = torch.tensor(test_dat,
                                           device=configs.device,
                                           requires_grad=False)
            img_gen = model.forward(test_dat_tensor, real_input_flag,
                                    hidden_state, cell_state,
                                    hidden_state_diff, cell_state_diff,
                                    st_memory, conv_lstm_c, MIMB_oc_w,
                                    MIMB_ct_w, MIMN_oc_w, MIMN_ct_w)
            img_gen = img_gen.clone().detach().to('cpu').numpy()

            if len(img_gen.shape) > 3:
                img_gen = preprocess.reshape_patch_back(
                    img_gen, configs.patch_size)

            # Metrics per predicted frame.
            for i in range(output_length):
                x = test_ims[:, i + configs.input_length, :, :, :]
                x = x[:configs.batch_size * configs.n_gpu]
                # Strip whole multiples of 10000 from values above 10000
                # (presumably an encoded offset in the data -- confirm).
                x = x - np.where(x > 10000,
                                 np.floor_divide(x, 10000) * 10000,
                                 np.zeros_like(x))
                gx = img_gen[:, i, :, :, :]
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse
                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)

                for b in range(configs.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    # Frames are channel-first; SSIM is given channel-last.
                    gx_trans = np.transpose(gx[b], (1, 2, 0))
                    x_trans = np.transpose(x[b], (1, 2, 0))
                    score = structural_similarity(gx_trans,
                                                  x_trans,
                                                  multichannel=True)
                    ssim[i] += score

            # Save prediction examples.
            if batch_id <= configs.num_save_samples:
                path = os.path.join(res_path, str(save_name))
                os.makedirs(path, exist_ok=True)

                for i in range(configs.total_length):
                    file_name = os.path.join(path, 'gt' + str(i + 1) + '.png')
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    if configs.img_channel == 2:
                        # Frames are (C, H, W): keep only the first channel,
                        # since cv2.imwrite cannot store 2-channel images.
                        img_gt = img_gt[:1]
                    cv2.imwrite(file_name, np.transpose(img_gt, (1, 2, 0)))

                for i in range(configs.total_length - 1):
                    file_name = os.path.join(path, 'pd' + str(i) + '.png')
                    img_pd = img_gen[0, i, :, :, :]
                    if configs.img_channel == 2:
                        img_pd = img_pd[:1]  # first channel of (C, H, W)
                    img_pd = np.maximum(img_pd, 0)
                    img_pd = np.minimum(img_pd, 1)
                    img_pd = np.uint8(img_pd * 255)
                    cv2.imwrite(file_name, np.transpose(img_pd, (1, 2, 0)))
            test_input_handle.next()

    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          'test...' + str(save_name))

    if batch_id == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('no test batches were evaluated')
        return

    num_samples = batch_id * configs.batch_size * configs.n_gpu
    avg_mse = avg_mse / num_samples
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_samples)

    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp,
                       dtype=np.float32) / (configs.batch_size * batch_id)

    print('psnr per frame: ' + str(np.mean(psnr)))
    print('ssim per frame: ' + str(np.mean(ssim)))
    print('fmae per frame: ' + str(np.mean(fmae)))
    print('sharpness per frame: ' + str(np.mean(sharp)))
def _heading_classes_from_frames(frames):
    """Return int8 heading class ids (0-4) from channel 2 of 5-D `frames`.

    The channel is scaled by 255 and bucketed with ``// 85`` then shifted by
    one (so 1, 85, 170, 255 map to 1, 2, 3, 4); pixels whose raw channel-2
    value is exactly 0 get class 0.
    """
    classes = (frames[:, :, :, :, 2] * 255 // 85).astype(np.int8) + 1
    classes[frames[:, :, :, :, 2] == 0] = 0
    return classes


def _scheduled_sampling_mask(true_token, frame_shape):
    """Expand a 2-D boolean array into a float64 mask of whole frames.

    Every True entry becomes an all-ones frame of ``frame_shape`` and every
    False entry an all-zeros frame; the result has shape
    ``true_token.shape + frame_shape``.
    """
    token = true_token[:, :, np.newaxis, np.newaxis, np.newaxis]
    full_shape = true_token.shape + tuple(frame_shape)
    return np.broadcast_to(token, full_shape).astype(np.float64)


def _vectors_to_speed_heading(vectors):
    """Convert (..., 2) speed vectors to ``(speed, heading)`` arrays.

    Speed is the Euclidean norm.  Heading uses the input encoding
    (value / 255): quadrant (+, +) -> 85, (+, -) -> 255, (-, -) -> 170,
    (-, +) -> 1; vectors lying on an axis keep heading 0.
    """
    vx = vectors[..., 0]
    vy = vectors[..., 1]
    speed = np.sqrt(vx ** 2 + vy ** 2)
    heading = np.zeros_like(vy)
    heading[(vx > 0) & (vy > 0)] = 85.0 / 255.0
    heading[(vx > 0) & (vy < 0)] = 255.0 / 255.0
    heading[(vx < 0) & (vy < 0)] = 170.0 / 255.0
    heading[(vx < 0) & (vy > 0)] = 1.0 / 255.0
    return speed, heading


def _print_speed_direction_mse(pred, pred_speed, pred_direction, gt, null_val):
    """Print masked MSE of the full output, the speed and the direction."""
    mse = masked_mse_np(pred, gt, null_val=null_val)
    speed_mse = masked_mse_np(pred_speed, gt[..., 0], null_val=null_val)
    direction_mse = masked_mse_np(pred_direction, gt[..., 1], null_val=null_val)
    print("The output mse is ", mse)
    print("The speed mse is ", speed_mse)
    print("The direction mse is ", direction_mse)


def main(argv=None):
    """Train the speed-vector model and evaluate it on one heading class.

    Training frames are turned into 2-D speed vectors: channel 1 (used as
    speed) divided by sqrt(2), times the (+-1, +-1) direction that
    ``heading_table`` assigns to the heading class derived from channel 2.
    Every ``FLAGS.test_interval`` iterations the test split is evaluated
    with only pixels of heading class ``FLAGS.heading`` kept in the input;
    predictions are converted back to (speed, heading) and masked MSEs are
    printed for several variants and for a moving-average baseline
    (``cast_moving_avg``).  A checkpoint is written every
    ``FLAGS.snapshot_interval`` iterations.
    """
    # Heading class id -> raw heading value (0-255) in the ground truth.
    heading_dict = {1: 1, 2: 85, 3: 170, 4: 255, 0: 0}
    heading = FLAGS.heading
    FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash' + 'L1+L2+VALID' + 'multi-task'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # An existing save_dir is reused as the pretrained-model path
        # (presumably to resume training -- confirm in Model()).
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length, FLAGS.seq_length - FLAGS.input_length)

    # Start indices (among the 288 daily time bins, 0-287) of the sequences
    # to predict in each daily test file. They depend on the time zone:
    # Berlin lies in UTC+2 whereas Istanbul and Moscow lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    # Heading class id -> (x, y) direction signs; class 0 -> zero vector.
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]], dtype=np.float32)

    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    # Scheduled sampling: eta drops by `delta` per training sub-batch until
    # iteration 50000, after which it is 0.
    delta = 0.00002
    eta = 1

    patch_channels = FLAGS.patch_size_height * FLAGS.patch_size_width * FLAGS.img_channel

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()

        # Encode every pixel as a speed vector (all heading classes).
        heading_image = heading_table[_heading_classes_from_frames(imss)]
        speed_on_axis = np.expand_dims(imss[:, :, :, :, 1] / np.sqrt(2), axis=-1)
        imss = speed_on_axis * heading_image

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width, FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi + FLAGS.batch_size]
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            if itr < 50000:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample(
                (batch_size, FLAGS.seq_length - FLAGS.input_length - 1))
            true_token = (random_flip < eta)
            mask_true = _scheduled_sampling_mask(
                true_token,
                (int(FLAGS.img_height), int(FLAGS.img_width), int(patch_channels)))
            cost, _ = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                # Also train on the time-reversed sequence; average the losses.
                ims_rev = ims[:, ::-1]
                cost2, _ = model.train(ims_rev, lr, mask_true, batch_size)
                cost = (cost + cost2) / 2

            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            # Moving-average speed threshold below which the moving-average
            # heading replaces the predicted heading.
            epsilon = 0.2
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle=False)
            output_length = FLAGS.seq_length - FLAGS.input_length
            avg_mse = 0
            batch_id = 0
            gt_list = []
            pred_list = []
            pred_list_all = []
            pred_vec = []
            move_avg = []
            img_mse = [0] * output_length
            fmae = [0] * output_length
            mask_true = np.zeros((batch_size,
                                  output_length - 1,
                                  FLAGS.img_height,
                                  FLAGS.img_width,
                                  patch_channels))
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # Ground truth: channels 1 onwards of the predicted frames.
                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])
                tem_data = test_ims.copy()
                cvt_heading = _heading_classes_from_frames(test_ims)

                # Speed vectors restricted to the selected heading class.
                heading_selected = np.zeros_like(cvt_heading, np.int8)
                heading_selected[cvt_heading == heading] = heading
                heading_image = heading_table[heading_selected]
                speed_on_axis = np.expand_dims(test_ims[:, :, :, :, 1] / np.sqrt(2), axis=-1)
                test_ims = speed_on_axis * heading_image

                # Moving-average baseline, kept only on the selected heading.
                mavg_results_all = cast_moving_avg(tem_data[:, :FLAGS.input_length, ...])
                mavg_results = np.zeros_like(mavg_results_all)
                target_heading = cvt_heading[:, FLAGS.input_length:, ...] == heading
                mavg_results[target_heading] = mavg_results_all[target_heading]
                move_avg.append(mavg_results)

                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size_width, FLAGS.patch_size_height)
                img_gen = model.test(test_dat, mask_true, batch_size)
                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                # reshape the prediction to ndims=5
                img_gen = np.reshape(img_gen, (img_gen.shape[0], output_length,
                                               FLAGS.img_height, FLAGS.img_width, -1))
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)

                img_gen_list = []
                img_gen_origin_list = []
                for i in range(output_length):
                    x = tem_data[:, i + FLAGS.input_length, :, :, 1:]
                    speed, direction = _vectors_to_speed_heading(img_gen[:, i, :, :, :])
                    # Untransformed prediction (np.stack copies the arrays).
                    img_gen_origin_list.append(np.stack([speed, direction], axis=-1))

                    # Where the moving-average speed is small, use the
                    # moving-average direction instead.
                    slow = mavg_results[:, i, :, :, 1] < epsilon
                    direction[slow] = mavg_results[:, i, :, :, 2][slow]
                    gx = np.stack([speed, direction], axis=-1)
                    img_gen_list.append(gx)

                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                pred_list_all.append(np.stack(img_gen_origin_list, axis=1))
                pred_list.append(np.stack(img_gen_list, axis=1))
                pred_vec.append(img_gen)
                test_input_handle.next()

            frame_norm = (batch_id * batch_size * FLAGS.img_height *
                          FLAGS.img_width * FLAGS.patch_size_height *
                          FLAGS.patch_size_width * FLAGS.img_channel)
            avg_mse = avg_mse / (frame_norm * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for frame_mse in img_mse:
                print(frame_mse / frame_norm)

            gt_list_all = np.stack(gt_list, axis=0)
            # GT filtered to the requested heading.
            target_gt = gt_list_all[..., 1] * 255 == heading_dict[heading]
            gt_list = np.zeros_like(gt_list_all)
            gt_list[target_gt] = gt_list_all[target_gt]

            pred_list = np.stack(pred_list, axis=0)
            pred_list_all = np.stack(pred_list_all, axis=0)

            print("Evaluate on every pixels....")
            _print_speed_direction_mse(pred_list, pred_list[..., 0], pred_list[..., 1], gt_list, np.nan)

            print("Evaluate on valid pixels for Transformation...")
            _print_speed_direction_mse(pred_list, pred_list[..., 0], pred_list[..., 1], gt_list, 0.0)

            print("Evaluate on valid pixels for No Transformation...")
            _print_speed_direction_mse(pred_list_all, pred_list_all[..., 0], pred_list_all[..., 1], gt_list, 0.0)

            print("Evaluate on valid pixels for MAVG...")
            move_avg = np.stack(move_avg, axis=0)
            _print_speed_direction_mse(move_avg[..., 1:], move_avg[..., 1], move_avg[..., 2], gt_list, 0.0)

            # Where the *moving-average* speed is >= epsilon, replace the
            # moving-average direction with the untransformed prediction.
            # NOTE: the mask is derived from the moving-average speed, not the
            # ground-truth speed, despite the wording of the last message.
            large_mavg_speed = move_avg[..., 1] >= epsilon
            move_avg[..., 2][large_mavg_speed] = pred_list_all[large_mavg_speed, 1]
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print(f"The direction of combined mavg and large speed~({epsilon}) prediction is ", direction_mse)

            direction_mse = masked_mse_np(pred_list_all[large_mavg_speed, 1], gt_list[large_mavg_speed, 1], null_val=0.0)
            print("The direction mse on large speed gt is ", direction_mse)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
Ejemplo n.º 12
0
def main(argv=None):
    """Train the model, evaluate it periodically and log to TensorBoard.

    ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir`` are wiped and recreated.
    Every iteration trains on one patch-reshaped batch (plus its time
    reversal when ``FLAGS.reverse_input`` is set).  Every
    ``FLAGS.test_interval`` iterations the test split is evaluated: the
    model's single predicted frame is compared with the last ground-truth
    frame (channel 0) for MSE, frame MAE, PSNR, sharpness and SSIM, and a
    GT/prediction image pair is logged.  A checkpoint is saved every
    ``FLAGS.snapshot_interval`` iterations.
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width, FLAGS.seq_length)

    print("Initializing models")
    model = Model()
    lr = FLAGS.lr

    # Prepare tensorboard logging
    logger = Logger(os.path.join(FLAGS.gen_frm_dir, 'board'), model.sess)
    for scalar_name in ("loss", "lr", "mse", "psnr", "fmae", "ssim", "sharp"):
        logger.define_item(scalar_name, Logger.Scalar, ())
    logger.define_item(
        "image",
        Logger.Image,
        (1, 2 * FLAGS.img_width, FLAGS.img_width, FLAGS.img_channel),
        dtype='uint8')

    # Evaluation compares exactly one predicted frame with the last GT frame,
    # so the per-frame metric lists have a single entry. (Sizing them by
    # seq_length - input_length left zero entries that diluted the means.)
    num_eval_frames = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        logger.add('lr', lr, itr)
        cost = model.train(ims, lr)
        if FLAGS.reverse_input:
            # Also train on the time-reversed sequence and average the losses.
            # (Previously passed an undefined `mask_true` -> NameError.)
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr)
            cost = cost / 2
        logger.add('loss', cost, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * num_eval_frames
            ssim = [0] * num_eval_frames
            psnr = [0] * num_eval_frames
            fmae = [0] * num_eval_frames
            sharp = [0] * num_eval_frames
            while not test_input_handle.no_batch_left():
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat)

                img_gen = preprocess.reshape_patch_back(
                    img_gen[:, np.newaxis, :, :, :], FLAGS.patch_size)
                # The code below indexes img_gen as (batch, H, W, C). If
                # reshape_patch_back keeps the singleton time axis
                # (batch, 1, H, W, C), drop it; a 4-D result is unchanged.
                if img_gen.ndim == 5:
                    img_gen = img_gen[:, 0]

                # Metrics for the single evaluated frame (index 0).
                x = test_ims[:, -1, :, :, 0]
                gx = img_gen[:, :, :, 0]
                fmae[0] += metrics.batch_mae_frame_float(gx, x)
                gx = np.maximum(gx, 0)
                gx = np.minimum(gx, 1)
                mse = np.square(x - gx).sum()
                img_mse[0] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[0] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    sharp[0] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[0] += score

                # Log one random GT/prediction pair (stacked vertically)
                # from the first test batch.
                if batch_id == 1:
                    sel = np.random.randint(FLAGS.batch_size)
                    img_seq_pd = img_gen[sel]
                    img_seq_gt = test_ims[sel, -1]
                    h, w = img_gen.shape[1:3]
                    out_img = np.zeros((1, h * 2, w * 1, FLAGS.img_channel),
                                       dtype='uint8')
                    for i, img in enumerate([img_seq_gt, img_seq_pd]):
                        # Scale by 10 and clip to [0, 255] *before* the cast:
                        # float->uint8 casts do not saturate, so clipping
                        # after the cast (as before) let large values wrap.
                        img = np.uint8(np.clip(img * 10, 0, 255))
                        out_img[0, (i * h):(i * h + h), :] = img
                    logger.add("image", out_img, itr)

                test_input_handle.next()
            avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
            logger.add('mse', avg_mse, itr)
            print('mse per seq: ' + str(avg_mse))
            for frame_mse in img_mse:
                print(frame_mse / (batch_id * FLAGS.batch_size))
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / \
                (FLAGS.batch_size * batch_id)
            sharp = np.asarray(sharp, dtype=np.float32) / \
                (FLAGS.batch_size * batch_id)
            print('psnr per frame: ' + str(np.mean(psnr)))
            logger.add('psnr', np.mean(psnr), itr)
            for value in psnr:
                print(value)
            print('fmae per frame: ' + str(np.mean(fmae)))
            logger.add('fmae', np.mean(fmae), itr)
            for value in fmae:
                print(value)
            print('ssim per frame: ' + str(np.mean(ssim)))
            logger.add('ssim', np.mean(ssim), itr)
            for value in ssim:
                print(value)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            logger.add('sharp', np.mean(sharp), itr)
            for value in sharp:
                print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Ejemplo n.º 13
0
def main(argv=None):

    # FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash'
    heading_dict = {1: 1, 2: 85, 3: 170, 4: 255, 0: 0}
    heading = FLAGS.heading
    loss_func = FLAGS.loss_func
    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length
    ) + FLAGS.num_hidden + 'squash' + FLAGS.loss_func + str(heading)
    FLAGS.pretrained_model = FLAGS.save_dir

    test_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_' + FLAGS.mode)
    sub_files = preprocess.list_filenames(test_data_paths, [])

    output_path = f'./Results/predrnn/t{FLAGS.test_time}_{FLAGS.mode}/{FLAGS.loss_func}'
    # output_path = f'./Results/predrnn/t14/'
    preprocess.create_directory_structure(output_path)

    # The following indicies are the start indicies of the 3 images to predict in the 288 time bins (0 to 287)
    # in each daily test file. These are time zone dependent. Berlin lies in UTC+2 whereas Istanbul and Moscow
    # lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]],
                             dtype=np.float32)

    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    # dims = train_input_handle.dims
    print("Initializing models", flush=True)
    model = Model()

    avg_mse = 0
    gt_list = []
    pred_list = []
    pred_list_all = []
    pred_vec = []
    move_avg = []

    for f in sub_files:
        with h5py.File(os.path.join(test_data_paths, f), 'r') as h5_file:
            data = h5_file['array'][()]
            # get relevant training data pieces
            data = [
                data[y - FLAGS.input_length:y + FLAGS.seq_length -
                     FLAGS.input_length] for y in indicies
            ]
            test_ims = np.stack(data, axis=0)
            batch_size = len(indicies)
            gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])

            tem_data = test_ims.copy()
            heading_image = test_ims[:, :, :, :, 2]
            heading_image = (heading_image // 85).astype(np.int8) + 1
            heading_image[tem_data[:, :, :, :, 2] == 0] = 0

            # convert the data into speed vectors
            heading_selected = np.zeros_like(heading_image, np.int8)
            heading_selected[heading_image == heading] = heading
            heading_image = heading_selected
            heading_image = heading_table[heading_image]
            speed_on_axis = np.expand_dims(
                test_ims[:, :, :, :, 1].astype(np.float32) / 255.0 /
                np.sqrt(2),
                axis=-1)
            test_ims = speed_on_axis * heading_image

            test_dat = preprocess.reshape_patch(test_ims,
                                                FLAGS.patch_size_width,
                                                FLAGS.patch_size_height)

            mask_true = np.zeros(
                (batch_size, FLAGS.seq_length - FLAGS.input_length - 1,
                 FLAGS.img_height, FLAGS.img_width, FLAGS.patch_size_height *
                 FLAGS.patch_size_width * FLAGS.img_channel))

            img_gen = model.test(test_dat, mask_true, batch_size)

            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)

            outfile = os.path.join(output_path, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_test',
                                   f'{heading}' + f)
            preprocess.write_data(img_gen, outfile)