Example #1
0
def main():
    """Train an MIM model with scheduled sampling.

    Parses the command-line arguments, builds the model on ``args.device``
    (falling back to the CPU when CUDA is unavailable), optionally restores
    ``args.pretrained_model``, and runs ``args.max_iterations`` training
    steps through ``trainer.trainer``. It snapshots every
    ``args.snapshot_interval`` steps, evaluates every ``args.test_interval``
    steps, and prints the loss every ``args.display_interval`` steps.
    """
    args = parse_args()

    # CUDA present -> honour the requested device; otherwise run on the CPU.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    model = MIM(args).to(device)
    print(model)
    print('The model is loaded!\n')

    # Per the original note, training batches come out as "n 64 64 1";
    # confirm against datasets_factory.
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size * args.n_gpu,
        args.img_width,
        seq_length=args.total_length,
        is_training=True)

    if args.pretrained_model:
        model.load(args.pretrained_model)

    # Initial scheduled-sampling probability (the original note says 1.0).
    eta = args.sampling_start_value

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    mse_criterion = torch.nn.MSELoss()

    for itr in range(1, args.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)

        batch = train_input_handle.get_batch()
        if args.reverse_img:
            # Mirror along axis 3 before patching.
            batch_reversed = preprocess.reshape_patch(batch[:, :, :, ::-1],
                                                      args.patch_size)
        else:
            batch_reversed = None
        batch = preprocess.reshape_patch(batch, args.patch_size)
        eta, real_input_flag = schedule_sampling(eta, itr, args)

        loss = trainer.trainer(model, batch, real_input_flag, args, itr,
                               batch_reversed, device, optimizer,
                               mse_criterion)

        if itr % args.snapshot_interval == 0:
            model.save(itr)
        if itr % args.test_interval == 0:
            trainer.test(model, test_input_handle, args, itr)
        if itr % args.display_interval == 0:
            stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print(stamp, 'itr: ' + str(itr))
            print('training loss: ' + str(loss))

        train_input_handle.next()
        del loss
Example #2
0
def main(argv=None):
    """Smoke-test ``Model.train`` on random data for 20 iterations.

    Each iteration feeds one random sequence of shape (1, 20, 64, 64, 1),
    patched with patch size 4, plus a scheduled-sampling mask of shape
    (1, 9, 16, 16, 16). The mask shape presumably matches the patched frames
    (64 / 4 = 16 per side, 4 * 4 * 1 = 16 channels); confirm against Model.
    It prints the iteration number and training cost every step.
    ``argv`` is unused.
    """
    print("Initializing models")
    model = Model()
    lr = 0.001

    delta = 0.00002
    eta = 1

    batch_size, mask_steps = 1, 9
    mask_shape = (batch_size, mask_steps, 16, 16, 16)

    for itr in range(1, 21):
        ims = np.random.rand(1, 20, 64, 64, 1)
        ims = preprocess.reshape_patch(ims, 4)

        # Linear decay of the sampling probability. The else branch only
        # matters if the loop is ever extended past 50000 iterations.
        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((batch_size, mask_steps))
        true_token = random_flip < eta
        # Boolean indexing over the leading (batch, step) axes sets each
        # selected (16, 16, 16) slab to 1; the others stay 0.
        mask_true = np.zeros(mask_shape, 'float32')
        mask_true[true_token] = 1
        cost, _ = model.train(ims, lr, mask_true)

        print('itr: ' + str(itr))
        print('training loss: ' + str(cost))
Example #3
0
def get_datasets(data_file,
                 n_train=8192,
                 n_valid=2048,
                 seq_len=36,
                 patch_size=1,
                 **kwargs):
    """Build the train and validation ``ClimateDataset`` objects.

    Reads the ``'temp'`` array from the ``data_file`` archive as float32 and
    applies ``reshape_patch`` with ``patch_size``. It then splits along the
    first (time) axis: the first ``n_train + seq_len - 1`` steps go to
    training and every remaining step goes to validation. Extra ``kwargs``
    are forwarded to both ``ClimateDataset`` constructors.

    Returns ``(train_set, valid_set, {})``.

    NOTE(review): ``n_valid`` is accepted but not used; the validation set is
    everything after the training span.

    FIXME: assumes overlapping sequences in the train/val split.
    """
    with np.load(data_file) as archive:
        series = archive['temp'].astype(np.float32)

    # reshape_patch works on (N, T, C, H, W), but the stored series merges the
    # sample and sequence axes into (T, H, W). Wrap it in singleton batch and
    # channel axes, patch it, then strip the batch axis again.
    series = reshape_patch(series[np.newaxis, :, np.newaxis, :, :],
                           patch_size)[0]

    split_at = n_train + seq_len - 1
    return (ClimateDataset(series[:split_at], seq_len=seq_len, **kwargs),
            ClimateDataset(series[split_at:], seq_len=seq_len, **kwargs),
            {})
Example #4
0
def main(argv=None):
    """Serve single-file predictions for ``input,output`` lines read from stdin.

    For each line, the function:

    - reads ``sym_block[0][0]['data']`` from ``Level3File(input)`` as float32;
    - resizes it to width ``FLAGS.img_width``, keeping the aspect ratio;
    - adds singleton axes to get (1, 1, H, W, 1) and patches it;
    - runs ``model.inference``, un-patches the result and resizes it back to
      the source (h, w);
    - writes it with ``imsave`` (metadata axis 'YX') and prints ``done``.

    Any per-line error is reported as ``failed: <error>`` and the loop keeps
    going. The function returns when stdin reaches EOF. ``argv`` is unused.
    """
    model = Model()

    while True:
        try:
            line = input()
        except EOFError:
            # stdin closed: stop serving instead of dying with a traceback.
            break
        try:
            inf, outf = line.split(',')
            img = np.array(Level3File(inf).sym_block[0][0]['data'],
                           dtype='float32')
            h, w = img.shape
            nw = FLAGS.img_width
            nh = h * nw // w
            # cv2.resize takes dsize as (width, height); the result is
            # nh rows by nw columns.
            img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
            img = img[np.newaxis, np.newaxis, :, :, np.newaxis]
            img = preprocess.reshape_patch(img, FLAGS.patch_size)

            pred = model.inference(img)
            pred = preprocess.reshape_patch_back(pred[:, np.newaxis, :],
                                                 FLAGS.patch_size)
            # Back to the source resolution: dsize (w, h) yields an (h, w)
            # array, matching the input.
            pred = cv2.resize(pred[0, 0, :, :, 0], (w, h),
                              interpolation=cv2.INTER_CUBIC)

            imsave(outf, pred, metadata={'axis': 'YX'})
            print('done')

        except Exception as e:
            # Deliberate best effort per request: report and keep serving.
            print('failed:', e)
Example #5
0
    def __init__(self,
                 data_file,
                 n_samples=None,
                 sample_shape=(20, 1, 64, 64),
                 patch_size=4):
        """Load ``input_raw_data`` from an .npz archive into ``self.data``.

        The raw array is reshaped into samples of ``sample_shape`` (any
        sequence of ints, e.g. a tuple or a list read from a config file).
        Only the first ``n_samples`` are kept when that is not None. The
        samples are passed through ``reshape_patch(…, patch_size)`` and
        stored as a torch tensor. ``data_file`` is kept on
        ``self.data_file``.
        """
        self.data_file = data_file

        # Load the data
        with np.load(data_file) as f:
            d = f['input_raw_data']

        # tuple() so list-valued shapes work too; (-1,) + list raises TypeError.
        d = d.reshape((-1, ) + tuple(sample_shape))
        if n_samples is not None:
            d = d[:n_samples]

        # The original PredRNN++ code applies this patch transform which
        # breaks the image up into patch_size patches stacked as channels.
        d = reshape_patch(d, patch_size)

        # Convert to Torch tensor
        self.data = torch.tensor(d)
Example #6
0
def main(argv=None):
    """Train ``Model`` with scheduled sampling and evaluate it periodically.

    Setup wipes and recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``.
    Then, for each of ``FLAGS.max_iterations`` iterations:

    - train on one patched batch, and also on the batch reversed along
      axis 1 when ``FLAGS.reverse_input`` is set (the two costs are averaged);
    - every ``FLAGS.display_interval`` steps, print the loss;
    - every ``FLAGS.test_interval`` steps, run the whole test set: print
      per-frame MSE, PSNR, MAE, SSIM and Laplacian sharpness, and dump the
      first 10 test batches as ground-truth / prediction PNGs;
    - every ``FLAGS.snapshot_interval`` steps, save the model.

    ``argv`` is unused.
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    # Geometry of a patched frame. Floor division is required: under
    # Python 3, "/" yields floats, which np.zeros / np.reshape reject as
    # dimensions.
    out_len = FLAGS.seq_length - FLAGS.input_length
    patch_hw = FLAGS.img_width // FLAGS.patch_size
    patch_ch = FLAGS.patch_size**2 * FLAGS.img_channel
    mask_shape = (FLAGS.batch_size, out_len - 1, patch_hw, patch_hw, patch_ch)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        # eta decays linearly from 1 and is pinned to 0 from iteration 50000.
        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((FLAGS.batch_size, out_len - 1))
        true_token = random_flip < eta
        # Whole frames are 1 where true_token is set and 0 elsewhere
        # (presumably 1 = feed the ground-truth frame).
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0
        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            ssim = [0] * out_len
            psnr = [0] * out_len
            fmae = [0] * out_len
            sharp = [0] * out_len
            # Evaluation uses an all-zero mask.
            mask_true = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size)
                # MSE per frame
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b],
                                                real_frm[b],
                                                full=True)
                        ssim[i] += score

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(out_len):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = np.clip(img_gen[0, i, :, :, :], 0, 1)
                        img_pd = np.uint8(img_pd * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()
            avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
            print('mse per seq: ' + str(avg_mse))
            for i in range(out_len):
                print(img_mse[i] / (batch_id * FLAGS.batch_size))
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim,
                              dtype=np.float32) / (FLAGS.batch_size * batch_id)
            sharp = np.asarray(
                sharp, dtype=np.float32) / (FLAGS.batch_size * batch_id)
            print('psnr per frame: ' + str(np.mean(psnr)))
            for i in range(out_len):
                print(psnr[i])
            print('fmae per frame: ' + str(np.mean(fmae)))
            for i in range(out_len):
                print(fmae[i])
            print('ssim per frame: ' + str(np.mean(ssim)))
            for i in range(out_len):
                print(ssim[i])
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for i in range(out_len):
                print(sharp[i])

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Example #7
0
def main(argv=None):
    """Train ``Model`` per city and keep the checkpoint with the best masked MSE.

    Setup:

    - Derives ``FLAGS.save_dir``, ``FLAGS.best_model`` and
      ``FLAGS.gen_frm_dir`` from the dataset name, sequence length and
      ``num_hidden``.
    - Creates missing directories. If ``save_dir`` already exists,
      ``FLAGS.pretrained_model`` is pointed at it.

    Each iteration fetches a file batch, keeps the first
    ``FLAGS.img_channel`` channels, patches it, and trains on chunks of
    ``FLAGS.batch_size``.

    Every ``FLAGS.test_interval`` iterations the model is evaluated on fixed
    daily start indices (UTC+2 list for Berlin, UTC+3 otherwise):

    - prints per-frame MSE and the masked MSE overall and per channel
      (volume, plus speed / direction when those channels are kept);
    - calls ``model.save_to_best_mode()`` whenever the overall masked MSE
      improves.

    Snapshots are saved every ``FLAGS.snapshot_interval`` iterations.
    ``argv`` is unused.
    """
    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length) + FLAGS.num_hidden
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # Existing run directory: point pretrained_model at it (presumably so
        # Model() restores from it).
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    # NOTE(review): node_pos is never used below; the load is kept so a
    # missing file still fails fast as before.
    node_pos = np.load(node_pos_file_2in1)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length,
        FLAGS.seq_length - FLAGS.input_length)

    # Start indices, within the 288 daily time bins (0 to 287), of the frames
    # to predict in each daily test file. They depend on the time zone:
    # Berlin is UTC+2, Istanbul and Moscow are UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    delta = 0.2
    eta = 1
    # Best validation loss so far. Starting at +inf guarantees the first
    # evaluation writes a best checkpoint. With 1.0 nothing was ever saved
    # unless the MSE dropped below 1.
    min_val_loss = np.inf

    out_len = FLAGS.seq_length - FLAGS.input_length
    patch_ch = int(FLAGS.patch_size_height * FLAGS.patch_size_width *
                   FLAGS.img_channel)

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()
        imss = imss[..., :FLAGS.img_channel]

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width,
                                        FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi + FLAGS.batch_size]
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # NOTE(review): eta is decremented once per sub-batch, not once
            # per iteration. It is ~0 after five sub-batches and negative
            # afterwards, which selects no tokens, just like eta = 0.
            if itr < 10:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample((batch_size, out_len - 1))
            true_token = random_flip < eta
            # Whole frames are 1 where true_token is set and 0 elsewhere.
            mask_true = np.zeros((batch_size, out_len - 1,
                                  int(FLAGS.img_height), int(FLAGS.img_width),
                                  patch_ch))
            mask_true[true_token] = 1.0
            cost = model.train(ims, lr, mask_true, batch_size)

            if FLAGS.reverse_input:
                ims_rev = ims[:, ::-1]
                cost += model.train(ims_rev, lr, mask_true, batch_size)
                cost = cost / 2

            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle=False)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            # NOTE(review): psnr and fmae are accumulated but never reported.
            psnr = [0] * out_len
            fmae = [0] * out_len
            # Evaluation uses an all-zero mask.
            mask_true = np.zeros((batch_size, out_len - 1, FLAGS.img_height,
                                  FLAGS.img_width, patch_ch))

            gt_list = []
            pred_list = []

            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # get the selected channels
                test_ims = test_ims[..., :FLAGS.img_channel]

                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, :])
                test_dat = preprocess.reshape_patch(test_ims,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)

                img_gen = model.test(test_dat, mask_true, batch_size)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                pred_list.append(img_gen)
                # MSE per frame
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, :]
                    gx = img_gen[:, i, :, :, :]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)

                test_input_handle.next()

            # Element count of one output frame summed over all test batches.
            frame_elems = (batch_id * batch_size * FLAGS.img_height *
                           FLAGS.img_width * FLAGS.patch_size_height *
                           FLAGS.patch_size_width * FLAGS.img_channel)
            avg_mse = avg_mse / (frame_elems * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(out_len):
                print(img_mse[i] / frame_elems, flush=True)

            gt_list = np.stack(gt_list, axis=0)
            pred_list = np.stack(pred_list, axis=0)
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            volume_mse = masked_mse_np(pred_list[..., 0],
                                       gt_list[..., 0],
                                       null_val=np.nan)

            print("The output mse is ", mse, flush=True)
            print("The volume mse is ", volume_mse, flush=True)
            # Channel 1 only exists when at least two channels are kept.
            # Indexing it unconditionally raised IndexError for
            # img_channel == 1.
            if FLAGS.img_channel >= 2:
                speed_mse = masked_mse_np(pred_list[..., 1],
                                          gt_list[..., 1],
                                          null_val=np.nan)
                print("The speed mse is ", speed_mse, flush=True)
            if FLAGS.img_channel == 3:
                direction_mse = masked_mse_np(pred_list[..., 2],
                                              gt_list[..., 2],
                                              null_val=np.nan)
                print("The direction mse is ", direction_mse, flush=True)

            if min_val_loss > mse:
                min_val_loss = mse
                print("Current Min Val Loss is ", min_val_loss)
                model.save_to_best_mode()

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
Example #8
0
def main(argv=None):
    """Train ``Model`` on down-sampled speed data and evaluate it periodically.

    Setup wipes and recreates ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir``,
    loads the ``train_/valid_speed_down_sample{N}.npz`` files, and takes the
    frame height / width from ``train_input_handle.dims``. Each iteration
    trains on one patched batch with scheduled sampling.

    Every ``FLAGS.test_interval`` iterations it:

    - prints per-frame MSE, PSNR, MAE and Laplacian sharpness over the test
      set, and dumps the first 10 test batches as PNGs;
    - runs a per-file validation on the daily start indices, resizes the
      predictions back to the label resolution, and prints their masked MSE
      against label channel 1 on a 0-255 scale.

    Snapshots are saved every ``FLAGS.snapshot_interval`` iterations.
    ``argv`` is unused.
    """
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(
        FLAGS.train_data_paths, FLAGS.dataset_name,
        'train_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    valid_data_paths = os.path.join(
        FLAGS.valid_data_paths, FLAGS.dataset_name,
        'valid_speed_down_sample{}.npz'.format(FLAGS.down_sample))
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_size, True, FLAGS.down_sample, FLAGS.input_length,
        FLAGS.seq_length - FLAGS.input_length)

    # Start indices, within the 288 daily time bins (0 to 287), of the frames
    # to predict in each daily test file. They depend on the time zone:
    # Berlin is UTC+2, Istanbul and Moscow are UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus2 if FLAGS.dataset_name == 'Berlin' else utcPlus3

    dims = train_input_handle.dims
    FLAGS.img_height = dims[-2]
    FLAGS.img_width = dims[-1]
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    delta = 0.00002
    eta = 1

    out_len = FLAGS.seq_length - FLAGS.input_length
    # NOTE(review): the mask uses the unpatched height/width but a
    # patch_size**2 channel count, which presumably only lines up when
    # patch_size == 1. Confirm.
    mask_shape = (FLAGS.batch_size, out_len - 1, int(FLAGS.img_height),
                  int(FLAGS.img_width),
                  int(FLAGS.patch_size**2 * FLAGS.img_channel))

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        # eta decays linearly from 1 and is pinned to 0 from iteration 50000.
        if itr < 50000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample((FLAGS.batch_size, out_len - 1))
        true_token = random_flip < eta
        # Whole frames are 1 where true_token is set and 0 elsewhere.
        mask_true = np.zeros(mask_shape)
        mask_true[true_token] = 1.0
        cost = model.train(ims, lr, mask_true)
        if FLAGS.reverse_input:
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr, mask_true)
            cost = cost / 2

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr), flush=True)
            print('training loss: ' + str(cost), flush=True)

        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * out_len
            psnr = [0] * out_len
            fmae = [0] * out_len
            sharp = [0] * out_len
            # Evaluation uses an all-zero mask; it is reused by the per-file
            # validation below.
            mask_true = np.zeros(mask_shape)
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat, mask_true)

                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen, FLAGS.patch_size)
                # MSE per frame
                for i in range(out_len):
                    x = test_ims[:, i + FLAGS.input_length, :, :, 0]
                    gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(batch_id))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)
                    for i in range(out_len):
                        name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                        file_name = os.path.join(path, name)
                        img_pd = np.clip(img_gen[0, i, :, :, :], 0, 1)
                        img_pd = np.uint8(img_pd * 255)
                        cv2.imwrite(file_name, img_pd)
                test_input_handle.next()
            avg_mse = avg_mse / (batch_id * FLAGS.batch_size)
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(out_len):
                print(img_mse[i] / (batch_id * FLAGS.batch_size))
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            sharp = np.asarray(
                sharp, dtype=np.float32) / (FLAGS.batch_size * batch_id)
            print('psnr per frame: ' + str(np.mean(psnr)), flush=True)
            for i in range(out_len):
                print(psnr[i], flush=True)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for i in range(out_len):
                print(fmae[i], flush=True)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for i in range(out_len):
                print(sharp[i], flush=True)

            # test with file
            # NOTE(review): this reads the *_validation directory under
            # FLAGS.train_data_paths, not FLAGS.valid_data_paths. Confirm
            # this is intended.
            valid_data_path = os.path.join(
                FLAGS.train_data_paths, FLAGS.dataset_name,
                '{}_validation'.format(FLAGS.dataset_name))
            files = list_filenames(valid_data_path)
            output_all = []
            labels_all = []
            num_tests = len(indicies)
            num_partitions = int(np.ceil(num_tests / FLAGS.batch_size))
            for f in files:
                valid_file = valid_data_path + '/' + f
                valid_input, raw_output = datasets_factory.test_validation_provider(
                    valid_file,
                    indicies,
                    down_sample=FLAGS.down_sample,
                    seq_len=FLAGS.input_length,
                    horizon=out_len)
                # np.float was removed in NumPy 1.24. It was an alias of the
                # builtin float, i.e. float64.
                valid_input = valid_input.astype(np.float64) / 255.0
                labels_all.append(raw_output)
                # NOTE(review): unlike the test loop above, these inputs are
                # not passed through preprocess.reshape_patch. They are
                # presumably only valid when patch_size == 1.
                for i in range(num_partitions):
                    valid_input_i = valid_input[i * FLAGS.batch_size:(i + 1) *
                                                FLAGS.batch_size]
                    num_input_i = valid_input_i.shape[0]
                    # Zero-pad a short final chunk up to FLAGS.batch_size; the
                    # padded rows are dropped from the output below.
                    if num_input_i < FLAGS.batch_size:
                        zeros_fill_in = np.zeros(
                            (FLAGS.batch_size - num_input_i, FLAGS.seq_length,
                             FLAGS.img_height, FLAGS.img_width,
                             FLAGS.img_channel))
                        valid_input_i = np.concatenate(
                            [valid_input_i, zeros_fill_in], axis=0)
                    img_gen = model.test(valid_input_i, mask_true)
                    output_all.append(img_gen[0][:num_input_i])

            output_all = np.concatenate(output_all, axis=0)
            labels_all = np.concatenate(labels_all, axis=0)
            # labels_all is laid out (..., H, W, C): channel 1 is taken from
            # the last axis below. cv2.resize takes dsize as (width, height),
            # so each resized frame comes out origin_height x origin_width.
            origin_height = labels_all.shape[-3]
            origin_width = labels_all.shape[-2]
            output_resize = []
            for i in range(output_all.shape[0]):
                output_i = []
                for j in range(output_all.shape[1]):
                    # NOTE(review): index 1 on axis 2 presumably selects the
                    # speed channel of a channel-first output. Confirm the
                    # layout of model.test's output.
                    tmp_data = output_all[i, j, 1, :, :]
                    tmp_data = cv2.resize(tmp_data,
                                          (origin_width, origin_height))
                    output_i.append(np.expand_dims(tmp_data, axis=0))
                output_resize.append(np.stack(output_i, axis=0))
            output_resize = np.stack(output_resize, axis=0)

            output_resize *= 255.0
            labels_all = np.expand_dims(labels_all[..., 1], axis=2)
            valid_mse = masked_mse_np(output_resize, labels_all, np.nan)

            print("validation mse is ", valid_mse, flush=True)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Example #9
0
def test(model, test_input_handle, configs, itr=None):
    """Evaluate ``model`` on every batch of ``test_input_handle`` and print metrics.

    For each predicted frame this accumulates MSE, MAE (``fmae``), PSNR, SSIM
    and a Laplacian-based sharpness score, prints the averages, and saves
    ground-truth / predicted PNGs for the first ``configs.num_save_samples``
    batches under ``configs.gen_frm_dir/<itr>``.

    Args:
        model: object exposing ``test(inputs, real_input_flag)``.
        test_input_handle: iterable of tensors with a ``.numpy()`` method.
            Per the original comment each batch is laid out as
            (batch, seq, channels, height, width) -- TODO confirm.
        configs: namespace providing ``gen_frm_dir``, ``total_length``,
            ``input_length``, ``batch_size``, ``img_width``, ``patch_size``,
            ``img_channel`` and ``num_save_samples``.
        itr: label used to name the output directory.

    Returns:
        Tuple ``(avg_mse, ssim, psnr, fmae, sharp)``: the mean per-sequence
        MSE and per-frame float32 arrays of the other metrics.

    Raises:
        ValueError: if ``test_input_handle`` yields no batches.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    os.mkdir(res_path)
    avg_mse = 0
    batch_id = 0

    # Number of frames the model predicts after the input frames (e.g. 20 - 10).
    output_length = configs.total_length - configs.input_length
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    # All-zero channel-last flags of shape
    # (batch, output_length - 1, W // p, W // p, p * p * channels); presumably
    # the model reads them to decide when to feed ground truth -- TODO confirm.
    real_input_flag = np.zeros(
        (configs.batch_size, output_length - 1,
         configs.img_width // configs.patch_size,
         configs.img_width // configs.patch_size,
         configs.patch_size**2 * configs.img_channel))

    for test_input in test_input_handle:
        # (batch, seq, channels, H, W) -> (batch, seq, H, W, channels)
        test_ims = np.transpose(test_input.numpy(), (0, 1, 3, 4, 2))
        batch_id = batch_id + 1

        # Patch the channel-last array: the prediction is compared below
        # against channel-last `test_ims` and `real_input_flag` is channel-last,
        # so the raw channel-first tensor must not be fed here.
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat, real_input_flag)

        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_out = img_gen[:, -output_length:]

        # Per-frame metrics on the predicted part of the sequence.
        for i in range(output_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            gx = img_out[:, i, :, :, :]
            # MAE is taken on the raw prediction, before clipping to [0, 1].
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.clip(gx, 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
                score, _ = compare_ssim(pred_frm[b],
                                        real_frm[b],
                                        full=True,
                                        multichannel=True)
                ssim[i] += score
                # Sharpness: peak absolute Laplacian response of the prediction.
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))

        # Save ground truth and clipped predictions for the first few batches.
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(output_length):
                name = 'pd' + str(i + 1 + configs.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = np.clip(img_out[0, i, :, :, :], 0, 1)
                cv2.imwrite(file_name, np.uint8(img_pd * 255))

    if batch_id == 0:
        raise ValueError('test_input_handle yielded no batches')

    num_samples = batch_id * configs.batch_size
    avg_mse = avg_mse / num_samples
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_samples)

    ssim = np.asarray(ssim, dtype=np.float32) / num_samples
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp, dtype=np.float32) / num_samples
    print('ssim per frame: ' + str(np.mean(ssim)))
    for value in ssim:
        print(value)
    print('psnr per frame: ' + str(np.mean(psnr)))
    for value in psnr:
        print(value)
    print('fmae per frame: ' + str(np.mean(fmae)))
    for value in fmae:
        print(value)
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for value in sharp:
        print(value)

    return avg_mse, ssim, psnr, fmae, sharp
Example #10
0
def main(argv=None):
    """Evaluate ``Model`` once on the test split and print per-frame metrics.

    Every test batch is fed to ``model.test`` with an all-zero mask. For each of
    the ``FLAGS.seq_length - FLAGS.input_length`` predicted frames it
    accumulates MSE, MAE (``fmae``), PSNR, SSIM and Laplacian sharpness on
    channel 0, prints the averages, and writes ground-truth / prediction PNGs
    for the first 10 batches under ``FLAGS.gen_frm_dir/test``.

    Args:
        argv: unused.
    """
    tf.disable_eager_execution()  # added; otherwise an error is raised

    # load data
    _, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.test_data_paths,
        FLAGS.batch_size, FLAGS.img_width)

    print("Initializing models")
    model = Model()

    print('test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(FLAGS.gen_frm_dir, 'test')
    os.mkdir(res_path)
    avg_mse = 0
    batch_id = 0
    # Number of frames predicted after the input frames.
    pred_length = FLAGS.seq_length - FLAGS.input_length
    img_mse = [0] * pred_length
    ssim = [0] * pred_length
    psnr = [0] * pred_length
    fmae = [0] * pred_length
    sharp = [0] * pred_length
    # All-zero mask of shape (batch, pred_length - 1, H // p, W // p, p*p*C).
    mask_true = np.zeros(
        (FLAGS.batch_size, pred_length - 1,
         FLAGS.img_height // FLAGS.patch_size,
         FLAGS.img_width // FLAGS.patch_size,
         FLAGS.patch_size**2 * FLAGS.img_channel))

    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
        img_gen = model.test(test_dat, mask_true)

        # concat outputs of different gpus along batch
        img_gen = np.concatenate(img_gen)
        img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

        # Per-frame metrics on channel 0 of the predicted frames.
        for i in range(pred_length):
            x = test_ims[:, i + FLAGS.input_length, :, :, 0]
            gx = img_gen[:, i, :, :, 0]
            # MAE is taken on the raw prediction, before clipping to [0, 1].
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.clip(gx, 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(FLAGS.batch_size):
                # Sharpness: peak absolute Laplacian response of the prediction.
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                ssim[i] += score

        # save prediction examples for the first 10 batches
        if batch_id <= 10:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            for i in range(FLAGS.seq_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                cv2.imwrite(file_name, img_gt)
            for i in range(pred_length):
                name = 'pd' + str(i + 1 + FLAGS.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = np.clip(img_gen[0, i, :, :, :], 0, 1)
                cv2.imwrite(file_name, np.uint8(img_pd * 255))
        test_input_handle.next()

    if batch_id == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('no test batches available')
        return

    num_samples = batch_id * FLAGS.batch_size
    avg_mse = avg_mse / num_samples
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_samples)
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / num_samples
    sharp = np.asarray(sharp, dtype=np.float32) / num_samples
    print('psnr per frame: ' + str(np.mean(psnr)))
    for value in psnr:
        print(value)
    print('fmae per frame: ' + str(np.mean(fmae)))
    for value in fmae:
        print(value)
    print('ssim per frame: ' + str(np.mean(ssim)))
    for value in ssim:
        print(value)
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for value in sharp:
        print(value)
def main(argv=None):
    """Run frame-interpolation inference over the test split and print metrics.

    Even-indexed frames are kept in the mask and odd-indexed frames are zeroed,
    so the odd frames are the ones the model has to generate. For every frame
    this accumulates MSE, MAE (``fmae``), PSNR, SSIM and Laplacian sharpness
    on channel 0. It prints per-frame averages and writes ``gt*.png`` /
    ``pd*.png`` images per sample under ``FLAGS.gen_frm_dir/images0/<name>``.

    ``FLAGS.gen_num`` selects how generated frames are indexed. With 5, frame
    ``i`` (odd) is ``img_gen[:, i // 2]``; with 11, it is ``img_gen[:, i]``.

    Args:
        argv: unused.

    Raises:
        ValueError: if ``FLAGS.gen_num`` is not 5 or 11, or if
            ``FLAGS.reverse_input`` is disabled, because ``model.test`` is
            always called with the reversed sequence.
    """
    # Validate before touching any directories; both settings would otherwise
    # fail later with an unbound local variable.
    if FLAGS.gen_num not in (5, 11):
        raise ValueError('gen_num must be 5 or 11, got %r' % (FLAGS.gen_num,))
    if not FLAGS.reverse_input:
        raise ValueError('reverse_input must be enabled: model.test needs '
                         'the reversed input sequence')

    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size,
        [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel],
        FLAGS.seq_length)

    print('Initializing models')
    model = Model()

    test_input_handle.begin(do_shuffle=False)
    totalDataLen = int(test_input_handle.total())
    print('totalDataLen=', totalDataLen)
    for itr in range(1):
        print('inference...')
        res_path = os.path.join(FLAGS.gen_frm_dir, 'images' + str(itr))
        os.mkdir(res_path)
        avg_mse = 0
        batch_id = 0
        img_mse = [0] * FLAGS.seq_length
        ssim = [0] * FLAGS.seq_length
        psnr = [0] * FLAGS.seq_length
        fmae = [0] * FLAGS.seq_length
        sharp = [0] * FLAGS.seq_length

        # Keep even time steps (1) and blank odd time steps (0) for every sample.
        mask_true = np.ones(
            (FLAGS.batch_size, FLAGS.seq_length, FLAGS.img_height,
             FLAGS.img_width, FLAGS.img_channel))
        mask_true[:, 1::2] = 0
        mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

        while batch_id < totalDataLen:
            batch_id = batch_id + 1
            print('test get_batch:')
            test_ims, sample_name = test_input_handle.get_batch(False)
            test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
            # Same sequence with the time axis reversed.
            test_ims_rev = test_dat[:, ::-1]

            # The two auxiliary "watch" outputs are not used by this evaluation.
            img_gen, _, _ = model.test(test_dat, test_ims_rev, mask_true, itr)

            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size)

            # Per-frame metrics on channel 0.
            for i in range(FLAGS.seq_length):
                x = test_ims[:, i, :, :, 0]
                if i % 2 == 0:
                    # Even frames were given to the model; score GT vs itself.
                    gx = test_ims[:, i, :, :, 0]
                elif FLAGS.gen_num == 5:
                    gx = img_gen[:, i // 2, :, :, 0]
                else:  # gen_num == 11
                    gx = img_gen[:, i, :, :, 0]

                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.clip(gx, 0, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                for b in range(FLAGS.batch_size):
                    # Sharpness: peak absolute Laplacian response.
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[i] += score

            # save prediction examples
            # NOTE(review): the final batch (batch_id == totalDataLen) is not
            # saved; confirm whether `<=` was intended.
            if batch_id < totalDataLen:
                path = os.path.join(res_path, str(sample_name))
                os.mkdir(path)
                for i in range(FLAGS.seq_length):
                    name = 'gt' + str(i + 1) + '.png'
                    file_name = os.path.join(path, name)
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    cv2.imwrite(file_name, img_gt)

                for i in range(FLAGS.seq_length):
                    name = 'pd' + str(i + 1) + '.png'
                    file_name = os.path.join(path, name)
                    if FLAGS.gen_num == 11:
                        # Every time step is written from the model output.
                        img_pd = img_gen[0, i, :, :, :]
                    elif i % 2 == 1:
                        img_pd = img_gen[0, i // 2, :, :, :]
                    else:
                        img_pd = test_ims[0, i, :, :, :]
                    img_pd = np.clip(img_pd, 0, 1)
                    cv2.imwrite(file_name, np.uint8(img_pd * 255))

            test_input_handle.next()

        if batch_id == 0:
            # Empty test split; avoid dividing by zero below.
            print('no test batches available')
            return

        num_samples = batch_id * FLAGS.batch_size
        avg_mse = avg_mse / num_samples
        print('mse per seq: ' + str(avg_mse))
        for frame_mse in img_mse:
            print(frame_mse / num_samples)
        psnr = np.asarray(psnr, dtype=np.float32) / batch_id
        fmae = np.asarray(fmae, dtype=np.float32) / batch_id
        ssim = np.asarray(ssim, dtype=np.float32) / num_samples
        sharp = np.asarray(sharp, dtype=np.float32) / num_samples
        print('psnr per frame: ' + str(np.mean(psnr)))
        for value in psnr:
            print(value)
        print('fmae per frame: ' + str(np.mean(fmae)))
        for value in fmae:
            print(value)
        print('ssim per frame: ' + str(np.mean(ssim)))
        for value in ssim:
            print(value)
        print('sharpness per frame: ' + str(np.mean(sharp)))
        for value in sharp:
            print(value)
Example #12
0
def test(model, test_input_handle, configs, save_name, hidden_state,
         cell_state, hidden_state_diff, cell_state_diff, st_memory,
         conv_lstm_c, MIMB_oc_w, MIMB_ct_w, MIMN_oc_w, MIMN_ct_w):
    """Evaluate ``model`` on the test split and print summary metrics.

    Each batch is run through ``model.forward`` under ``torch.no_grad()``
    together with the given state arguments. Per-frame MSE, MAE (``fmae``),
    PSNR, SSIM and Laplacian sharpness are accumulated over the predicted
    frames. ``gt*.png`` / ``pd*.png`` images are written for the first
    ``configs.num_save_samples`` batches into
    ``configs.gen_frm_dir/<save_name>``; every batch writes the same file
    names, so later ones overwrite earlier ones. Unless
    ``save_name == 'test_result'``, at most 100 batches are evaluated.

    Images are handled channel-first, (channels, height, width), as implied by
    the transposes to channel-last before SSIM and ``cv2.imwrite``.

    Args:
        model: module whose ``forward`` takes the batch, ``real_input_flag``
            and the ten state arguments and returns a tensor.
        test_input_handle: handle exposing ``begin``, ``no_batch_left``,
            ``get_batch`` and ``next``.
        configs: namespace with the configuration fields read below.
        save_name: output sub-directory name; also controls the batch cap.
        hidden_state, cell_state, hidden_state_diff, cell_state_diff,
        st_memory, conv_lstm_c, MIMB_oc_w, MIMB_ct_w, MIMN_oc_w, MIMN_ct_w:
            passed through unchanged to ``model.forward``.
    """
    test_input_handle.begin(do_shuffle=False)
    res_path = configs.gen_frm_dir
    if not os.path.isdir(res_path):
        os.mkdir(res_path)
    avg_mse = 0
    batch_id = 0
    pred_length = configs.total_length - configs.input_length
    img_mse = [0] * pred_length
    ssim = [0] * pred_length
    psnr = [0] * pred_length
    fmae = [0] * pred_length
    sharp = [0] * pred_length

    # Fall back to a square frame when no explicit height is configured.
    height = configs.img_height if configs.img_height > 0 else configs.img_width

    # Channel-first flags: (batch, pred_length - 1, p*p*C, W // p, H // p).
    real_input_flag = np.zeros(
        (configs.batch_size, pred_length - 1,
         configs.patch_size**2 * configs.img_channel,
         configs.img_width // configs.patch_size,
         height // configs.patch_size))

    with torch.no_grad():
        while not test_input_handle.no_batch_left():
            # Cap non-final evaluations at 100 batches. Checking before the
            # increment keeps batch_id equal to the number of batches actually
            # evaluated, which is what the averages below divide by.
            if save_name != 'test_result' and batch_id >= 100:
                break
            batch_id = batch_id + 1

            test_ims = test_input_handle.get_batch()
            test_ims = test_ims[:, :configs.total_length]

            if len(test_ims.shape) > 3:
                test_dat = preprocess.reshape_patch(test_ims,
                                                    configs.patch_size)
            else:
                test_dat = test_ims

            # test_dat = np.split(test_dat, configs.n_gpu)
            # TODO: the debug output needs to be changed here -- forward() is
            # currently modified to return only img_gen; check what it
            # originally returned.
            test_dat_tensor = torch.tensor(test_dat,
                                           device=configs.device,
                                           requires_grad=False)
            img_gen = model.forward(test_dat_tensor, real_input_flag,
                                    hidden_state, cell_state,
                                    hidden_state_diff, cell_state_diff,
                                    st_memory, conv_lstm_c, MIMB_oc_w,
                                    MIMB_ct_w, MIMN_oc_w, MIMN_ct_w)
            img_gen = img_gen.clone().detach().to('cpu').numpy()

            # concat outputs of different gpus along batch
            # img_gen = np.concatenate(img_gen)
            if len(img_gen.shape) > 3:
                img_gen = preprocess.reshape_patch_back(
                    img_gen, configs.patch_size)

            # Per-frame metrics over the predicted frames.
            for i in range(pred_length):
                x = test_ims[:, i + configs.input_length, :, :, :]
                x = x[:configs.batch_size * configs.n_gpu]
                # Values above 10000 are replaced by their remainder mod 10000.
                x = x - np.where(x > 10000,
                                 np.floor_divide(x, 10000) * 10000,
                                 np.zeros_like(x))
                gx = img_gen[:, i, :, :, :]
                # MAE is taken on the raw prediction, before clipping.
                fmae[i] += metrics.batch_mae_frame_float(gx, x)
                gx = np.clip(gx, 0, 1)
                mse = np.square(x - gx).sum()
                img_mse[i] += mse
                avg_mse += mse
                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[i] += metrics.batch_psnr(pred_frm, real_frm)

                # NOTE(review): x keeps batch_size * n_gpu samples, but SSIM
                # and sharpness only visit the first configs.batch_size;
                # confirm this is intended when n_gpu > 1.
                for b in range(configs.batch_size):
                    sharp[i] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    # SSIM expects channel-last images.
                    gx_trans = np.transpose(gx[b], (1, 2, 0))
                    x_trans = np.transpose(x[b], (1, 2, 0))
                    score = structural_similarity(gx_trans,
                                                  x_trans,
                                                  multichannel=True)
                    ssim[i] += score

            # save prediction examples
            if batch_id <= configs.num_save_samples:
                path = os.path.join(res_path, str(save_name))
                if not os.path.isdir(path):
                    os.mkdir(path)

                for i in range(configs.total_length):
                    name = 'gt' + str(i + 1) + '.png'
                    file_name = os.path.join(path, name)
                    img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                    if configs.img_channel == 2:
                        # Images are (C, H, W): keep only the first channel.
                        img_gt = img_gt[:1]
                    img_gt = np.transpose(img_gt, (1, 2, 0))
                    cv2.imwrite(file_name, img_gt)

                for i in range(configs.total_length - 1):
                    name = 'pd' + str(i) + '.png'
                    file_name = os.path.join(path, name)
                    img_pd = img_gen[0, i, :, :, :]
                    if configs.img_channel == 2:
                        # Images are (C, H, W): keep only the first channel.
                        img_pd = img_pd[:1]
                    img_pd = np.clip(img_pd, 0, 1)
                    img_pd = np.uint8(img_pd * 255)
                    img_pd = np.transpose(img_pd, (1, 2, 0))
                    cv2.imwrite(file_name, img_pd)
            test_input_handle.next()

    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          'test...' + str(save_name))

    if batch_id == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('no test batches evaluated')
        return

    total_samples = batch_id * configs.batch_size * configs.n_gpu
    avg_mse = avg_mse / total_samples
    print('mse per seq: ' + str(avg_mse))

    for frame_mse in img_mse:
        print(frame_mse / total_samples)

    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp,
                       dtype=np.float32) / (configs.batch_size * batch_id)

    print('psnr per frame: ' + str(np.mean(psnr)))
    print('ssim per frame: ' + str(np.mean(ssim)))
    print('fmae per frame: ' + str(np.mean(fmae)))
    print('sharpness per frame: ' + str(np.mean(sharp)))
Example #13
0
# Stand-alone training script fragment: a CausalLSTMStack followed by a 1x1
# Conv2d decoder, set up for MSE training with Adam.
# NOTE(review): `args`, `train_input_handle`, `preprocess` and
# `CausalLSTMStack` are defined outside this fragment.
model = CausalLSTMStack(3, 2, args.num_hidden)  #filter_size, num_dims
# 1x1 convolution from 16 to 16 channels, stride 1.
decoder = torch.nn.Conv2d(16, 16, 1, 1)
# tmp = np.random.rand(8, 20, 16, 16, 16)

###  run iters

model.cuda()
decoder.cuda()
loss_fn = torch.nn.MSELoss()
# One Adam optimizer over both the recurrent stack and the decoder.
optim = torch.optim.Adam(list(model.parameters()) + list(decoder.parameters()),
                         lr=args.lr)

# NOTE(review): as written, the loop body never evaluates `loss_fn`, calls
# backward() or steps `optim`, so no parameters are updated; the training
# step looks truncated -- confirm against the original source.
for itr in range(10000):
    ims = train_input_handle.get_batch()
    # Presumably folds patch_size x patch_size pixel blocks into channels.
    ims = preprocess.reshape_patch(ims, args.patch_size)
    #print(ims.shape)# (8, 20, 16, 16, 16)
    # Batch-major -> time-major, so ims[t] is the whole batch at step t.
    ims = np.swapaxes(ims, 0, 1)
    # Recurrent states start as None; presumably the model initializes them
    # on the first step -- TODO confirm.
    h, c, m, z = [None] * 4
    #print(ims.shape)# (20, 8, 16, 16, 16)
    # Swap axes 2 and 4; assuming a (T, B, H, W, C) layout this moves the
    # channel axis next to the batch axis -- TODO confirm.
    ims = np.swapaxes(ims, 2, 4)
    #print(ims.shape)
    #print(ims.shape)# (20, 8, 16, 16, 16)
    # Unroll the recurrence over every time step, threading the states through.
    for t in range(args.seq_length):
        tmp = torch.Tensor(ims[t])
        tmp = tmp.cuda()
        h, c, m, z = model(tmp, h, c, m, z)

    # Decode the last layer's hidden state; the permute moves the last axis to
    # position 1 (presumably NHWC -> NCHW for Conv2d -- TODO confirm).
    z = decoder(h[-1].permute(0, -1, 1, 2))  #.permute(0,2,3,1)
    # Target: the final time step of the input sequence.
    y = torch.Tensor(ims[-1])
    y = y.cuda()
def main(argv=None):
    """Train the frame-interpolation ``Model`` and periodically evaluate it.

    Training uses scheduled sampling. Even-indexed frames are always visible,
    and odd-indexed frames are visible with probability ``eta``. ``eta``
    starts at 1, decreases by ``delta`` per iteration, and is 0 from
    iteration 80000 on.

    Every ``FLAGS.test_interval`` iterations the model is evaluated on 11 test
    batches with all odd frames masked. Per-frame MSE, MAE (``fmae``), PSNR,
    SSIM and Laplacian sharpness are printed, and images for the first 10
    batches are written under ``FLAGS.gen_frm_dir/<itr>``. A checkpoint is
    saved every ``FLAGS.snapshot_interval`` iterations.

    Args:
        argv: unused.

    Raises:
        ValueError: if ``FLAGS.gen_num`` is not 5 or 11, or if
            ``FLAGS.reverse_input`` is disabled, because ``model.train`` and
            ``model.test`` always take the reversed sequence.
    """
    # Validate up front; both settings would otherwise fail later with an
    # unbound local variable.
    if FLAGS.gen_num not in (5, 11):
        raise ValueError('gen_num must be 5 or 11, got %r' % (FLAGS.gen_num,))
    if not FLAGS.reverse_input:
        raise ValueError('reverse_input must be enabled: model.train and '
                         'model.test always take the reversed sequence')

    # Start from empty output directories.
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir, FLAGS.log_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, [FLAGS.img_height, FLAGS.img_width, FLAGS.img_channel], FLAGS.seq_length)

    print('Initializing models')
    model = Model()
    lr = FLAGS.lr

    # Scheduled-sampling probability and its per-iteration decrement.
    delta = 0.0000125
    eta = 1

    # Shape of one patched frame. Floor division keeps these valid array
    # dimensions under Python 3 (true division would yield floats).
    patch_h = FLAGS.img_height // FLAGS.patch_size
    patch_w = FLAGS.img_width // FLAGS.patch_size
    patch_c = FLAGS.patch_size**2 * FLAGS.img_channel

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        print('train get_batch:')
        ims, _ = train_input_handle.get_batch(False)
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        if itr < 80000:
            eta -= delta
        else:
            eta = 0.0
        random_flip = np.random.random_sample(
            (FLAGS.batch_size, FLAGS.seq_length))
        # Even time steps are always shown (0 2 4 6 8 10); odd ones only when
        # their random draw falls below eta.
        visible = random_flip < eta
        visible[:, ::2] = True
        mask_true = np.zeros(
            (FLAGS.batch_size, FLAGS.seq_length, patch_h, patch_w, patch_c))
        mask_true[visible] = 1.0

        # Same batch with the time axis reversed.
        ims_rev = ims[:, ::-1]
        cost = model.train(ims, ims_rev, lr, mask_true, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            img_mse = [0] * FLAGS.seq_length
            ssim = [0] * FLAGS.seq_length
            psnr = [0] * FLAGS.seq_length
            fmae = [0] * FLAGS.seq_length
            sharp = [0] * FLAGS.seq_length

            # Keep even time steps (1) and blank all odd time steps (0).
            mask_true = np.ones((FLAGS.batch_size,
                                 FLAGS.seq_length,
                                 FLAGS.img_height,
                                 FLAGS.img_width,
                                 FLAGS.img_channel))
            mask_true[:, 1::2] = 0
            mask_true = preprocess.reshape_patch(mask_true, FLAGS.patch_size)

            # NOTE(review): this evaluates 11 batches (batch_id 1..11) but
            # saves images only for the first 10; confirm which was intended.
            while batch_id <= 10:
                batch_id = batch_id + 1
                print('test get_batch:')
                test_ims, filename = test_input_handle.get_batch(False)
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                test_ims_rev = test_dat[:, ::-1]

                img_gen, ims_watch, ims_rev_watch = model.test(
                    test_dat, test_ims_rev, mask_true, itr)

                # concat outputs of different gpus along batch
                img_gen = preprocess.reshape_patch_back(
                    np.concatenate(img_gen), FLAGS.patch_size)
                ims_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_watch), FLAGS.patch_size)
                ims_rev_watch = preprocess.reshape_patch_back(
                    np.concatenate(ims_rev_watch), FLAGS.patch_size)

                # Per-frame metrics on channel 0.
                for i in range(FLAGS.seq_length):
                    x = test_ims[:, i, :, :, 0]
                    if i % 2 == 0:
                        # Even frames were given; score GT against itself.
                        gx = test_ims[:, i, :, :, 0]
                    elif FLAGS.gen_num == 5:
                        # Model output holds only the odd frames.
                        gx = img_gen[:, i // 2, :, :, 0]
                    else:  # gen_num == 11
                        gx = img_gen[:, i, :, :, 0]
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.clip(gx, 0, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                    real_frm = np.uint8(x * 255)
                    pred_frm = np.uint8(gx * 255)
                    psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
                    for b in range(FLAGS.batch_size):
                        # Sharpness: peak absolute Laplacian response.
                        sharp[i] += np.max(
                            cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                        score, _ = compare_ssim(pred_frm[b], real_frm[b], full=True)
                        ssim[i] += score

                # save prediction examples
                if batch_id <= 10:
                    path = os.path.join(res_path, str(filename))
                    os.mkdir(path)
                    for i in range(FLAGS.seq_length):
                        name = 'gt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_gt)

                    for i in range(FLAGS.seq_length):
                        name = 'pd' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        if FLAGS.gen_num == 11:
                            # Every time step is written from the model output.
                            img_pd = img_gen[0, i, :, :, :]
                        elif i % 2 == 1:
                            img_pd = img_gen[0, i // 2, :, :, :]
                        else:
                            img_pd = test_ims[0, i, :, :, :]
                        img_pd = np.clip(img_pd, 0, 1)
                        cv2.imwrite(file_name, np.uint8(img_pd * 255))

                        name = 'zwgt' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        img_zwgt = np.uint8(ims_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, img_zwgt)

                        name = 'zwgtrev' + str(i + 1) + '.png'
                        file_name = os.path.join(path, name)
                        zwgtrev = np.uint8(ims_rev_watch[0, i, :, :, :] * 255)
                        cv2.imwrite(file_name, zwgtrev)

                test_input_handle.next()

            num_samples = batch_id * FLAGS.batch_size
            avg_mse = avg_mse / num_samples
            print('mse per seq: ' + str(avg_mse))
            for frame_mse in img_mse:
                print(frame_mse / num_samples)
            psnr = np.asarray(psnr, dtype=np.float32) / batch_id
            fmae = np.asarray(fmae, dtype=np.float32) / batch_id
            ssim = np.asarray(ssim, dtype=np.float32) / num_samples
            sharp = np.asarray(sharp, dtype=np.float32) / num_samples
            print('psnr per frame: ' + str(np.mean(psnr)))
            for value in psnr:
                print(value)
            print('fmae per frame: ' + str(np.mean(fmae)))
            for value in fmae:
                print(value)
            print('ssim per frame: ' + str(np.mean(ssim)))
            for value in ssim:
                print(value)
            print('sharpness per frame: ' + str(np.mean(sharp)))
            for value in sharp:
                print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
def main(argv=None):
    """Train ``Model`` on speed/heading traffic maps and evaluate one heading.

    Each training batch is converted from its speed (channel 1) and heading
    (channel 2) channels into a 2-component vector per pixel using
    ``heading_table``. It is then patched, split into chunks of
    ``FLAGS.batch_size``, and fed to ``model.train`` with a scheduled-sampling
    mask. Every ``FLAGS.test_interval`` iterations the test set is evaluated
    for the single heading ``FLAGS.heading``. Model predictions (with and
    without a moving-average heading substitution) and a moving-average
    baseline are compared against the ground truth, and several MSE figures
    are printed. A checkpoint is written every ``FLAGS.snapshot_interval``
    iterations.

    Args:
        argv: Not used by this function; presumably accepted so it can serve
            as an app-runner entry point (TODO confirm).

    Note:
        Mutates ``FLAGS.save_dir``, ``FLAGS.gen_frm_dir``,
        ``FLAGS.pretrained_model`` (when save_dir already exists),
        ``FLAGS.img_height`` and ``FLAGS.img_width``.
    """

    # Maps each heading code (0-4) to the raw heading pixel value
    # (heading channel * 255); used below to filter the ground truth.
    heading_dict = {1: 1, 2:85, 3: 170, 4: 255, 0:0}
    # The single heading code evaluated in the test phase.
    heading = FLAGS.heading
    # Run-specific checkpoint directory. Assumes FLAGS.num_hidden is a
    # string, since it is concatenated directly (TODO confirm).
    FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash' + 'L1+L2+VALID' + 'multi-task'
    FLAGS.gen_frm_dir += FLAGS.dataset_name
    if not tf.io.gfile.exists(FLAGS.save_dir):
        # tf.io.gfile.rmtree(FLAGS.save_dir)
        tf.io.gfile.makedirs(FLAGS.save_dir)
    else:
        # An existing save_dir is treated as a checkpoint to resume from.
        # FLAGS.pretrained_model is not read in this function; presumably
        # Model() consumes it (TODO confirm).
        FLAGS.pretrained_model = FLAGS.save_dir
    if not tf.io.gfile.exists(FLAGS.gen_frm_dir):
        # tf.io.gfile.rmtree(FLAGS.gen_frm_dir)
        tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    train_data_paths = os.path.join(FLAGS.train_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_training')
    valid_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name, FLAGS.dataset_name + '_validation')
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, train_data_paths, valid_data_paths,
        FLAGS.batch_file, True, FLAGS.input_length, FLAGS.seq_length - FLAGS.input_length)

    # The following indicies are the start indicies of the 3 images to predict in the 288 time bins (0 to 287)
    # in each daily test file. These are time zone dependent. Berlin lies in UTC+2 whereas Istanbul and Moscow
    # lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    # Row k is the 2-component direction for heading code k. Code 0
    # (no heading) maps to the zero vector; the others are (+-1, +-1).
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]], dtype=np.float32)

    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    # dims = train_input_handle.dims
    print("Initializing models", flush=True)
    model = Model()
    lr = FLAGS.lr

    # Scheduled-sampling parameters. eta is the probability of feeding the
    # ground-truth frame and shrinks by `delta` per training chunk.
    # `base` is only referenced by the commented-out exponential schedule.
    delta = 0.00002
    base = 0.99998
    eta = 1

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        imss = train_input_handle.get_batch()

        # # print("imss shape is ", imss.shape)
        tem_data = imss.copy()
        # Decode the heading channel (stored as value / 255) into codes via
        # floor(v / 85) + 1: 1 -> 1, 85 -> 2, 170 -> 3, 255 -> 4. Pixels whose
        # heading is exactly 0 get code 0.
        # NOTE(review): relies on v * 255 not rounding just below a multiple
        # of 85; confirm the input dtype.
        heading_image = imss[:, :, :, :, 2]*255
        heading_image = (heading_image // 85).astype(np.int8) + 1
        heading_image[tem_data[:, :, :, :, 2] == 0] = 0
        heading_image = heading_table[heading_image]

        # Channel 1 (speed) is scaled by 1/sqrt(2) so that multiplying by a
        # (+-1, +-1) direction row yields a vector whose norm is the speed.
        speed_on_axis = np.expand_dims(imss[:, :, :, :, 1] / np.sqrt(2), axis=-1)
        imss = speed_on_axis * heading_image

        imss = preprocess.reshape_patch(imss, FLAGS.patch_size_width, FLAGS.patch_size_height)
        num_batches = imss.shape[0]
        # get_batch() may return more samples than FLAGS.batch_size; train on
        # consecutive chunks of at most FLAGS.batch_size samples.
        for bi in range(0, num_batches, FLAGS.batch_size):
            ims = imss[bi:bi+FLAGS.batch_size]
            # FLAGS dims are overwritten with the *patched* spatial size; the
            # test phase below reuses them.
            FLAGS.img_height = ims.shape[2]
            FLAGS.img_width = ims.shape[3]
            batch_size = ims.shape[0]
            # eta is decremented once per chunk (not once per iteration), so it
            # can fall below 0 before itr reaches 50000. A non-positive eta
            # behaves like 0: random_flip < eta is then always False.
            if itr < 50000:
                eta -= delta
            else:
                eta = 0.0
            random_flip = np.random.random_sample(
                (batch_size, FLAGS.seq_length-FLAGS.input_length-1))
            true_token = (random_flip < eta)
            #true_token = (random_flip < pow(base,itr))
            ones = np.ones((FLAGS.img_height,
                            FLAGS.img_width,
                            int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            zeros = np.zeros((int(FLAGS.img_height),
                              int(FLAGS.img_width),
                              int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            # Scheduled-sampling mask: an all-ones frame where true_token is
            # set, an all-zeros frame elsewhere, for every (sample, step).
            mask_true = []
            for i in range(batch_size):
                for j in range(FLAGS.seq_length-FLAGS.input_length-1):
                    if true_token[i,j]:
                        mask_true.append(ones)
                    else:
                        mask_true.append(zeros)
            mask_true = np.array(mask_true)
            mask_true = np.reshape(mask_true, (batch_size,
                                               FLAGS.seq_length-FLAGS.input_length-1,
                                               int(FLAGS.img_height),
                                               int(FLAGS.img_width),
                                               int(FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel)))
            cost, _ = model.train(ims, lr, mask_true, batch_size)

            # Optionally also train on the time-reversed sequence (same mask)
            # and report the mean of both costs.
            if FLAGS.reverse_input:
                ims_rev = ims[:,::-1]
                cost2, _ = model.train(ims_rev, lr, mask_true, batch_size)
                cost = (cost + cost2) / 2

            # Printed once per chunk, so possibly several times per itr.
            if itr % FLAGS.display_interval == 0:
                print('itr: ' + str(itr), flush=True)
                print('training loss: ' + str(cost), flush=True)

        train_input_handle.next()
        if itr % FLAGS.test_interval == 0:
            print('test...', flush=True)
            # Moving-average speed threshold below which the moving-average
            # heading replaces the predicted heading.
            epsilon = 0.2
            # Assumes get_test_batch(indicies) yields one sample per start
            # index (TODO confirm).
            batch_size = len(indicies)
            test_input_handle.begin(do_shuffle = False)
            # res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            # os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            gt_list = []
            pred_list = []
            pred_list_all = []
            # pred_vec is collected but never read in this function.
            pred_vec = []
            move_avg = []
            img_mse, ssim, psnr, fmae, sharp= [],[],[],[],[]
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                img_mse.append(0)
                ssim.append(0)
                psnr.append(0)
                fmae.append(0)
                sharp.append(0)
            # All-zero mask: no ground-truth frames are fed during testing
            # (assuming the same mask semantics as in training).
            mask_true = np.zeros((batch_size,
                                  FLAGS.seq_length-FLAGS.input_length-1,
                                  FLAGS.img_height,
                                  FLAGS.img_width,
                                  FLAGS.patch_size_height*FLAGS.patch_size_width*FLAGS.img_channel))
            while(test_input_handle.no_batch_left() == False):
                batch_id = batch_id + 1
                test_ims = test_input_handle.get_test_batch(indicies)
                # get the ground truth
                gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])
                # cvt the heading to 0, 1, 2, 3, 4
                tem_data = test_ims.copy()
                heading_image = test_ims[:, :, :, :, 2] * 255
                heading_image = (heading_image // 85).astype(np.int8) + 1
                heading_image[tem_data[:, :, :, :, 2] == 0] = 0
                cvt_heading = heading_image.copy()

                # convert the data into speed vectors
                # Only pixels whose heading code equals the target heading keep
                # their direction; every other pixel becomes the zero vector.
                heading_selected = np.zeros_like(heading_image, np.int8)
                heading_selected[heading_image == heading] = heading
                heading_image = heading_selected
                heading_image = heading_table[heading_image]
                speed_on_axis = np.expand_dims(test_ims[:, :, :, :, 1] / np.sqrt(2), axis=-1)
                test_ims = speed_on_axis * heading_image

                # mavg filtered results
                # Moving-average baseline from the input frames, kept only at
                # output-frame pixels whose heading is the target heading.
                mavg_results_all = cast_moving_avg(tem_data[:, :FLAGS.input_length, ...])
                mavg_results = np.zeros_like(mavg_results_all)
                # heading_image = np.expand_dims(heading_image, axis=-1)
                mavg_results[cvt_heading[:, FLAGS.input_length:, ...] == heading] = \
                    mavg_results_all[cvt_heading[:, FLAGS.input_length:, ...] == heading]
                move_avg.append(mavg_results)

                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size_width, FLAGS.patch_size_height)
                img_gen = model.test(test_dat, mask_true, batch_size)
                # concat outputs of different gpus along batch
                img_gen = np.concatenate(img_gen)
                # reshape the prediction has ndims=5
                img_gen = np.reshape(img_gen, (img_gen.shape[0], FLAGS.seq_length - FLAGS.input_length,
                                               FLAGS.img_height, FLAGS.img_width, -1))
                img_gen = preprocess.reshape_patch_back(img_gen, FLAGS.patch_size_width, FLAGS.patch_size_height)
                # print("Image Generates Shape is ", img_gen.shape)
                # MSE per frame

                img_gen_list = []
                img_gen_origin_list = []
                for i in range(FLAGS.seq_length - FLAGS.input_length):
                    # Ground truth (channels 1: of the raw frame) vs prediction.
                    x = tem_data[:,i + FLAGS.input_length,:,:, 1:]
                    gx = img_gen[:,i,:, :, :]

                    # Decode predicted vectors back to (speed, heading): speed is
                    # the vector norm, heading comes from the sign quadrant (the
                    # inverse of heading_table). Pixels with a zero component
                    # keep heading 0.
                    # print("img_gen shape is ", gx.shape)
                    val_results_speed = np.sqrt(gx[..., 0] ** 2 + gx[..., 1] ** 2)
                    # print("val speed: ", val_results_speed, flush=True)
                    val_results_heading = np.zeros_like(gx[..., 1])
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] > 0)] = 85.0 / 255.0
                    val_results_heading[(gx[..., 0] > 0) & (gx[..., 1] < 0)] = 255.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] < 0)] = 170.0 / 255.0
                    val_results_heading[(gx[..., 0] < 0) & (gx[..., 1] > 0)] = 1.0 / 255.0

                    # Untransformed prediction, kept for the "No Transformation"
                    # evaluation below.
                    gen_speed_heading = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_origin_list.append(gen_speed_heading)

                    # Transformation according to moving average direction when mavg speed is small
                    val_results_heading[mavg_results[:, i, :, :, 1] < epsilon] = \
                        mavg_results[:, i, :, :, 2][mavg_results[:, i, :, :, 1] < epsilon]
                    gx = np.stack([val_results_speed, val_results_heading], axis=-1)
                    img_gen_list.append(gx)

                    # fmae uses the unclipped prediction; MSE uses it clipped
                    # to [0, 1].
                    fmae[i] += metrics.batch_mae_frame_float(gx, x)
                    gx = np.maximum(gx, 0)
                    gx = np.minimum(gx, 1)
                    mse = np.square(x - gx).sum()
                    img_mse[i] += mse
                    avg_mse += mse

                img_gen_list = np.stack(img_gen_list, axis=1)
                img_gen_origin_list = np.stack(img_gen_origin_list, axis=1)
                pred_list_all.append(img_gen_origin_list)
                pred_list.append(img_gen_list)
                pred_vec.append(img_gen)
                test_input_handle.next()

            # NOTE(review): raises ZeroDivisionError when the test handle
            # yields no batches (batch_id == 0).
            avg_mse = avg_mse / (batch_id*batch_size*FLAGS.img_height *
                                 FLAGS.img_width * FLAGS.patch_size_height *
                                 FLAGS.patch_size_width * FLAGS.img_channel * len(img_mse))
            print('mse per seq: ' + str(avg_mse), flush=True)
            for i in range(FLAGS.seq_length - FLAGS.input_length):
                print(img_mse[i] / (batch_id*batch_size*FLAGS.img_height *
                                 FLAGS.img_width * FLAGS.patch_size_height *
                                 FLAGS.patch_size_width * FLAGS.img_channel))

            gt_list_all = np.stack(gt_list, axis=0)
            # GT filtered to the direction required
            # NOTE(review): exact float equality on heading * 255; values that
            # do not round-trip exactly are treated as "other heading".
            gt_list = np.zeros_like(gt_list_all)
            gt_list[gt_list_all[..., 1]*255 == heading_dict[heading]] = \
                gt_list_all[gt_list_all[..., 1]*255 == heading_dict[heading]]

            pred_list = np.stack(pred_list, axis=0)
            pred_list_all = np.stack(pred_list_all, axis=0)

            # masked_mse_np presumably ignores entries equal to null_val
            # (NaN keeps every pixel, 0.0 drops empty ones) -- TODO confirm.
            print("Evaluate on every pixels....")
            mse = masked_mse_np(pred_list, gt_list, null_val=np.nan)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=np.nan)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=np.nan)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for Transformation...")
            mse = masked_mse_np(pred_list, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for No Transformation...")
            mse = masked_mse_np(pred_list_all, gt_list, null_val=0.0)
            speed_mse = masked_mse_np(pred_list_all[..., 0], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(pred_list_all[..., 1], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            print("Evaluate on valid pixels for MAVG...")
            # Evaluate on large gt speeds for direction
            # move_avg channels 1 and 2 are compared with gt channels 0 and 1
            # (speed and heading).
            move_avg = np.stack(move_avg, axis=0)
            mse = masked_mse_np(move_avg[..., 1:], gt_list, null_val=0.0)
            speed_mse = masked_mse_np(move_avg[..., 1], gt_list[..., 0], null_val=0.0)
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print("The output mse is ", mse)
            print("The speed mse is ", speed_mse)
            print("The direction mse is ", direction_mse)

            # NOTE(review): despite its name, this mask comes from the
            # moving-average speed, not the ground-truth speed.
            large_gt_speed = move_avg[..., 1] >= epsilon
            # Overwrites move_avg's heading channel in place (view assignment)
            # with the untransformed model heading where the mask is set.
            move_avg[..., 2][large_gt_speed] = pred_list_all[large_gt_speed, 1]
            direction_mse = masked_mse_np(move_avg[..., 2], gt_list[..., 1], null_val=0.0)
            print(f"The direction of combined mavg and large speed~({epsilon}) prediction is ", direction_mse)

            direction_mse = masked_mse_np(pred_list_all[large_gt_speed, 1], gt_list[large_gt_speed, 1], null_val=0.0)
            print("The direction mse on large speed gt is ", direction_mse)


        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)
Example #16
0
def main():
    """Train the MIM model with explicitly managed recurrent state.

    The recurrent variables are created here, or restored with
    ``loadVariables`` when ``args.pretrained_model`` is set. They are the
    hidden/cell states, their "diff" counterparts, the spatio-temporal
    memory, the ConvLSTM cell and the four MIMB/MIMN gate weights. They are
    passed to every ``model.forward`` call and detached with
    ``detachVariables`` after every optimizer step, so the next step builds
    a fresh computation graph.

    Each iteration runs one optimizer step on the patched batch itself and,
    when enabled, extra steps on:

    * ``args.reverse_img``: the raw frames flipped along axis 3 (presumably
      width), patched the same way as the original batch;
    * ``args.reverse_input``: the time-reversed batch and, when
      ``args.reverse_img`` is also set, the time-reversed flipped batch.

    Only the MSE loss is optimized. The displayed training loss is the mean
    over all steps of the iteration. Checkpoints (model + recurrent
    variables) and evaluation run every ``args.snapshot_interval`` and
    ``args.test_interval`` iterations respectively.
    """
    # Load hyper-parameters.
    args = parse_args()

    # Use the configured device when CUDA is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device(args.device)
    else:
        device = torch.device("cpu")

    model = MIM(args).to(device)
    print(model)
    print('The model is loaded!\n')

    # Load the datasets (original note: batches come out as n x 64 x 64 x 1).
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        args.dataset_name,
        args.train_data_paths,
        args.valid_data_paths,
        args.batch_size * args.n_gpu,
        args.img_width,
        seq_length=args.total_length,
        is_training=True)

    # Recurrent state shared by every forward pass (creation order preserved).
    cell_state = [init_state(args) for _ in range(4)]
    hidden_state = [init_state(args) for _ in range(4)]
    cell_state_diff = [init_state(args) for _ in range(3)]
    hidden_state_diff = [init_state(args) for _ in range(3)]
    st_memory = init_state(args)
    conv_lstm_c = init_state(args)

    MIMB_ct_weight = nn.Parameter(
        torch.randn((args.num_hidden[0] * 2, args.img_height, args.img_width),
                    device=device))
    MIMB_oc_weight = nn.Parameter(
        torch.randn((args.num_hidden[0], args.img_height, args.img_width),
                    device=device))
    MIMN_ct_weight = nn.Parameter(
        torch.randn((args.num_hidden[0] * 2, args.img_height, args.img_width),
                    device=device))
    MIMN_oc_weight = nn.Parameter(
        torch.randn((args.num_hidden[0], args.img_height, args.img_width),
                    device=device))

    if args.pretrained_model:
        (hidden_state, cell_state, hidden_state_diff, cell_state_diff,
         st_memory, conv_lstm_c, MIMB_ct_weight, MIMB_oc_weight,
         MIMN_ct_weight, MIMN_oc_weight) = loadVariables(args)
        model.load(args.pretrained_model)

    eta = args.sampling_start_value  # 1.0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def train_step(seq):
        """Run one forward/backward/update pass on ``seq``; return its loss.

        ``seq`` is a patched numpy batch whose frames ``1:`` are the targets.
        The returned loss is detached. The recurrent variables are detached
        afterwards so the next pass does not reuse this computation graph.
        """
        # .copy() makes flipped (negative-stride) views contiguous for torch.
        ims_tensor = torch.tensor(seq.copy(), device=device)
        gen_images = model.forward(ims_tensor, real_input_flag, hidden_state,
                                   cell_state, hidden_state_diff,
                                   cell_state_diff, st_memory, conv_lstm_c,
                                   MIMB_oc_weight, MIMB_ct_weight,
                                   MIMN_oc_weight, MIMN_ct_weight)
        gt_ims = torch.tensor(seq[:, 1:].copy(), device=device)

        # Only the MSE term is optimized. An optical-flow term
        # (DOFLoss.calc_optical_flow_loss) can be blended in here if wanted.
        loss = F.mse_loss(gen_images, gt_ims)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach graph and recurrent variables so the next step starts a
        # new computation graph.
        del gen_images
        detachVariables(hidden_state, cell_state, hidden_state_diff,
                        cell_state_diff, st_memory, conv_lstm_c,
                        MIMB_ct_weight, MIMB_oc_weight, MIMN_ct_weight,
                        MIMN_oc_weight)
        return loss.detach()

    for itr in range(1, args.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)

        # Fetch the input batch and apply reshape_patch.
        raw_ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(raw_ims, args.patch_size)
        eta, real_input_flag = schedule_sampling(eta, itr, args)

        step_losses = [train_step(ims)]

        ims_reverse = None
        if args.reverse_img:
            # Flip the raw frames *before* patching. Flipping the patched
            # tensor would leave the pixel order inside each patch unflipped
            # whenever patch_size > 1.
            ims_reverse = preprocess.reshape_patch(raw_ims[:, :, :, ::-1],
                                                   args.patch_size)
            step_losses.append(train_step(ims_reverse))

        if args.reverse_input:
            step_losses.append(train_step(ims[:, ::-1]))
            if args.reverse_img:
                step_losses.append(train_step(ims_reverse[:, ::-1]))

        # Mean loss over all forward passes of this iteration (single sync).
        loss_print = (sum(step_losses) / len(step_losses)).item()

        if itr % args.snapshot_interval == 0:
            # Every variable passed to detachVariables must be saved too.
            saveVariables(args, hidden_state, cell_state, hidden_state_diff,
                          cell_state_diff, st_memory, conv_lstm_c,
                          MIMB_ct_weight, MIMB_oc_weight, MIMN_ct_weight,
                          MIMN_oc_weight, itr)
            model.save(itr)

        if itr % args.test_interval == 0:
            trainer.test(model, test_input_handle, args, itr, hidden_state,
                         cell_state, hidden_state_diff, cell_state_diff,
                         st_memory, conv_lstm_c, MIMB_oc_weight,
                         MIMB_ct_weight, MIMN_oc_weight, MIMN_ct_weight)

        if itr % args.display_interval == 0:
            print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                  'itr: ' + str(itr))
            print('training loss: ' + str(loss_print))

        train_input_handle.next()
Example #17
0
def main(argv=None):
    """Train ``Model`` and periodically evaluate it, logging to TensorBoard.

    Existing ``FLAGS.save_dir`` and ``FLAGS.gen_frm_dir`` directories are
    deleted and recreated. Each iteration trains on one patched batch. With
    ``FLAGS.reverse_input`` it also trains on the time-reversed copy and
    averages the two costs.

    Every ``FLAGS.test_interval`` iterations the model output is compared
    with the last frame of each test sequence. MSE, PSNR, frame MAE, SSIM and
    sharpness are printed and logged, along with one ground-truth/prediction
    image pair. A checkpoint is written every ``FLAGS.snapshot_interval``
    iterations.

    Args:
        argv: Not used; presumably present for an app-runner entry point.
    """
    # Start from clean output directories (previous contents are deleted).
    for directory in (FLAGS.save_dir, FLAGS.gen_frm_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size, FLAGS.img_width, FLAGS.seq_length)

    print("Initializing models")
    model = Model()
    lr = FLAGS.lr

    # Prepare tensorboard logging
    logger = Logger(os.path.join(FLAGS.gen_frm_dir, 'board'), model.sess)
    for item_name in ("loss", "lr", "mse", "psnr", "fmae", "ssim", "sharp"):
        logger.define_item(item_name, Logger.Scalar, ())
    logger.define_item(
        "image",
        Logger.Image,
        (1, 2 * FLAGS.img_width, FLAGS.img_width, FLAGS.img_channel),
        dtype='uint8')

    # Number of per-frame metric slots. Only slot 0 is filled below, because
    # the model output is compared with the last frame only.
    n_out = FLAGS.seq_length - FLAGS.input_length

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        logger.add('lr', lr, itr)
        cost = model.train(ims, lr)
        if FLAGS.reverse_input:
            # Same call shape as above: this Model.train takes (ims, lr) only.
            ims_rev = ims[:, ::-1]
            cost += model.train(ims_rev, lr)
            cost = cost / 2
        logger.add('loss', cost, itr)

        if itr % FLAGS.display_interval == 0:
            print('itr: ' + str(itr))
            print('training loss: ' + str(cost))

        if itr % FLAGS.test_interval == 0:
            print('test...')
            test_input_handle.begin(do_shuffle=False)
            res_path = os.path.join(FLAGS.gen_frm_dir, str(itr))
            os.mkdir(res_path)
            avg_mse = 0
            batch_id = 0
            n_samples = 0
            img_mse = [0] * n_out
            ssim = [0] * n_out
            psnr = [0] * n_out
            fmae = [0] * n_out
            sharp = [0] * n_out
            while not test_input_handle.no_batch_left():
                batch_id += 1
                test_ims = test_input_handle.get_batch()
                test_dat = preprocess.reshape_patch(test_ims, FLAGS.patch_size)
                img_gen = model.test(test_dat)

                # concat outputs of different gpus along batch
                # img_gen = np.concatenate(img_gen)
                img_gen = preprocess.reshape_patch_back(
                    img_gen[:, np.newaxis, :, :, :], FLAGS.patch_size)

                # Compare the prediction with the last ground-truth frame.
                x = test_ims[:, -1, :, :, 0]
                gx = img_gen[:, :, :, 0]
                # fmae uses the raw prediction; the other metrics use it
                # clipped to [0, 1].
                fmae[0] += metrics.batch_mae_frame_float(gx, x)
                gx = np.clip(gx, 0, 1)
                mse = np.square(x - gx).sum()
                img_mse[0] += mse
                avg_mse += mse

                real_frm = np.uint8(x * 255)
                pred_frm = np.uint8(gx * 255)
                psnr[0] += metrics.batch_psnr(pred_frm, real_frm)
                # Iterate over the samples actually present, so a short final
                # batch cannot raise IndexError.
                for b in range(pred_frm.shape[0]):
                    sharp[0] += np.max(
                        cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
                    score, _ = compare_ssim(pred_frm[b],
                                            real_frm[b],
                                            full=True)
                    ssim[0] += score
                n_samples += pred_frm.shape[0]

                # Log one random ground-truth/prediction pair from the first
                # test batch, stacked vertically.
                if batch_id == 1:
                    sel = np.random.randint(img_gen.shape[0])
                    img_seq_pd = img_gen[sel]
                    img_seq_gt = test_ims[sel, -1]
                    h, w = img_gen.shape[1:3]
                    out_img = np.zeros((1, h * 2, w * 1, FLAGS.img_channel),
                                       dtype='uint8')
                    for k, img in enumerate([img_seq_gt, img_seq_pd]):
                        # Scale by 10 and clamp to [0, 255] *before* the uint8
                        # cast, so bright values saturate instead of wrapping.
                        out_img[0, (k * h):(k * h + h), :] = np.uint8(
                            np.clip(img * 10, 0, 255))
                    logger.add("image", out_img, itr)

                test_input_handle.next()

            if batch_id == 0:
                print('no test batches available; skipping evaluation metrics')
            else:
                avg_mse = avg_mse / n_samples
                logger.add('mse', avg_mse, itr)
                print('mse per seq: ' + str(avg_mse))
                for frame_mse in img_mse:
                    print(frame_mse / n_samples)
                # psnr/fmae are per-batch means; ssim/sharp are per-sample sums.
                psnr = np.asarray(psnr, dtype=np.float32) / batch_id
                fmae = np.asarray(fmae, dtype=np.float32) / batch_id
                ssim = np.asarray(ssim, dtype=np.float32) / n_samples
                sharp = np.asarray(sharp, dtype=np.float32) / n_samples
                for label, key, values in (
                        ('psnr per frame: ', 'psnr', psnr),
                        ('fmae per frame: ', 'fmae', fmae),
                        ('ssim per frame: ', 'ssim', ssim),
                        ('sharpness per frame: ', 'sharp', sharp)):
                    mean_value = np.mean(values)
                    print(label + str(mean_value))
                    logger.add(key, mean_value, itr)
                    for value in values:
                        print(value)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        train_input_handle.next()
Example #18
0
def _plot_loss_curve(values, title, png_path):
    """Plot ``values`` against their index and save the figure to ``png_path``.

    Figure number 1 is reused and explicitly closed afterwards, so the
    per-epoch re-plotting does not leave figures open.
    """
    plt.figure(1)
    plt.title(title)
    plt.plot(range(len(values)), values, label='loss')
    plt.legend()
    plt.savefig(png_path)
    plt.close(1)


def _save_loss_history(npz_path, history):
    """Save every recorded curve in ``history`` to a compressed ``.npz`` file.

    Each key of ``history`` becomes the name of one array in the archive,
    in the dict's insertion order.
    """
    np.savez_compressed(
        npz_path,
        **{name: np.array(values) for name, values in history.items()})


def train_wrapper(model):
    """Train ``model`` on the radar dataset with periodic validation/snapshots.

    All configuration is read from the module-level ``args`` namespace.

    Per epoch:
    - adjusts the learning rate every ``args.adjust_interval`` epochs;
    - trains over all batches, logging every ``args.display_interval`` batches;
    - re-plots the training-loss curve to ``<loss_dir>/train_loss.png``;
    - saves a model snapshot every ``args.snapshot_interval`` epochs;
    - runs ``trainer.test`` every ``args.test_interval`` epochs and re-plots
      the validation loss to ``<loss_dir>/valid_loss.png``;
    - dumps all recorded curves to an ``.npz`` every ``args.loss_interval``
      epochs.

    A final ``.npz`` containing the complete history is written after the
    last epoch.

    Args:
        model: object exposing ``load``, ``save`` and an ``optimizer``
            attribute. It is trained in place.
    """
    # Logging: standard format is "timestamp + message", to stdout and a file.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    # NOTE(review): a new FileHandler is attached to the root logger on every
    # call, so calling train_wrapper twice in one process duplicates log lines.
    fh = logging.FileHandler(os.path.join(args.loss_dir, 'train.log'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # record all the print content to txt file
    logger.make_print_to_file(args.loss_dir)

    if args.pretrained_model:
        model.load(args.pretrained_model)

    # Load data. Every sample has shape (total_length, 1, img_width, img_width),
    # e.g. (20, 1, 140, 140).
    sample_shape = (args.total_length, 1, args.img_width, args.img_width)
    train_inputs = radar_dataloader(args.train_data_paths,
                                    sample_shape=sample_shape,
                                    input_len=args.input_length)
    test_inputs = radar_dataloader(args.valid_data_paths,
                                   sample_shape=sample_shape,
                                   input_len=args.input_length)
    train_loaders = torch.utils.data.DataLoader(train_inputs,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=0,
                                                pin_memory=True)
    test_loaders = torch.utils.data.DataLoader(test_inputs,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=True)

    # schedule sampling
    eta = args.sampling_start_value

    # Recorded curves: training loss has one entry per batch; the test_*
    # entries get one entry per validation epoch. The key order is the order
    # of the arrays written to the .npz files.
    history = {
        'train_loss': [],
        'test_iter': [],
        'test_loss': [],
        'test_ssim': [],
        'test_psnr': [],
        'test_fmae': [],
        'test_sharp': [],
    }

    llr = args.lr
    # Set the initial learning rate on the optimizer.
    model.optimizer = adjust_learning_rate(model.optimizer, llr)

    # Epochs 0..max_epoch inclusive (e.g. range(0, 61)).
    for epoch in tqdm(range(0, args.max_epoch + 1)):
        losses = AverageMeter()

        # Every adjust_interval epochs, scale the learning rate by
        # adjust_rate (e.g. lr *= 0.5).
        if epoch % args.adjust_interval == 0 and epoch > 0:
            llr = llr * args.adjust_rate
            model.optimizer = adjust_learning_rate(model.optimizer, llr)

        for ind, ims in enumerate(train_loaders):
            eta, real_input_flag = schedule_sampling(eta, epoch)
            # e.g. ims -> (4, 20, 140 // 4, 140 // 4, 4 * 4 * 1) for patch_size 4
            ims = preprocess.reshape_patch(ims, args.patch_size)
            tr_loss = trainer.train(model, ims, real_input_flag, args, epoch)
            history['train_loss'].append(tr_loss)

            losses.update(tr_loss)

            if ind % args.display_interval == 0:
                logging.info('[{0}][{1}]\t'
                             'lr: {lr:.5f}\t'
                             'loss: {loss.val:.6f} ({loss.avg:.6f})'.format(
                                 epoch,
                                 ind,
                                 lr=model.optimizer.param_groups[-1]['lr'],
                                 loss=losses))

            torch.cuda.empty_cache()

        # plot figure to observe the losses
        _plot_loss_curve(history['train_loss'], "this is losses of training",
                         args.loss_dir + '/train_loss.png')

        # Save a model snapshot every snapshot_interval epochs.
        if epoch % args.snapshot_interval == 0 and epoch > 0:
            model.save(epoch)

        if epoch % args.test_interval == 0 and epoch > 0:
            with torch.no_grad():
                avg_mse, ssim, psnr, fmae, sharp = trainer.test(
                    model, test_loaders, args, epoch)
            history['test_iter'].append(epoch)
            history['test_loss'].append(avg_mse)
            history['test_ssim'].append(ssim)
            history['test_psnr'].append(psnr)
            history['test_fmae'].append(fmae)
            history['test_sharp'].append(sharp)

            # plot figure to observe the losses
            _plot_loss_curve(history['test_loss'],
                             "this is losses of validation",
                             args.loss_dir + '/valid_loss.png')

        # NOTE(review): the generated file names contain ':' ("date:"), which
        # is not a valid filename character on Windows.
        if epoch % args.loss_interval == 0 and epoch > 0:
            fileName = "/loss epoch{}".format(
                epoch) + datetime.datetime.now().strftime('date:' + '%Y_%m_%d')
            _save_loss_history(args.loss_dir + fileName, history)

    fileName = "/loss all " + datetime.datetime.now().strftime('date:' +
                                                               '%Y_%m_%d')
    _save_loss_history(args.loss_dir + fileName, history)
Example #19
0
def main(argv=None):
    """Predict the test windows of every daily file and write them to disk.

    Side effects:
    - appends ``dataset_name + str(seq_length) + num_hidden`` to
      ``FLAGS.save_dir``, and sets ``FLAGS.best_model`` and
      ``FLAGS.pretrained_model`` from it;
    - creates the output directory tree under
      ``./Results/predrnn/t{test_time}_{mode}/``;
    - writes one prediction file per input file (values scaled to uint8).

    Squared-error sums against the ground truth are accumulated. Their
    reporting code, like the rest of the evaluation below, is currently
    commented out.

    Args:
        argv: accepted but not used by this function.
    """

    # FLAGS.save_dir += FLAGS.dataset_name
    # FLAGS.gen_frm_dir += FLAGS.dataset_name
    # if tf.io.gfile.exists(FLAGS.save_dir):
    #     tf.io.gfile.rmtree(FLAGS.save_dir)
    # tf.io.gfile.makedirs(FLAGS.save_dir)
    # if tf.io.gfile.exists(FLAGS.gen_frm_dir):
    #     tf.io.gfile.rmtree(FLAGS.gen_frm_dir)
    # tf.io.gfile.makedirs(FLAGS.gen_frm_dir)

    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length) + FLAGS.num_hidden
    print(FLAGS.save_dir)
    # FLAGS.best_model = FLAGS.save_dir + '/best.ckpt'
    FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}.ckpt'
    # FLAGS.best_model = FLAGS.save_dir + f'/best_channels{FLAGS.img_channel}_weighted.ckpt'
    # FLAGS.save_dir += FLAGS.dataset_name
    FLAGS.pretrained_model = FLAGS.save_dir

    # node_pos is only consumed by the (commented-out) node evaluation below.
    process_data_dir = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                    'process_0.5')
    node_pos_file_2in1 = os.path.join(process_data_dir, 'node_pos_0.5.npy')
    node_pos = np.load(node_pos_file_2in1)

    test_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_' + FLAGS.mode)
    sub_files = preprocess.list_filenames(test_data_paths, [])

    output_path = f'./Results/predrnn/t{FLAGS.test_time}_{FLAGS.mode}/'
    # output_path = f'./Results/predrnn/t14/'
    preprocess.create_directory_structure(output_path)
    # The following indices are the start indices of the 5 sequences to predict
    # in the 288 time bins (0 to 287) of each daily test file. They are time
    # zone dependent: Berlin lies in UTC+2 whereas Istanbul and Moscow lie in
    # UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    print("Initializing models", flush=True)
    model = Model()

    step = 6
    # Running sums of squared error: overall and per channel (0, 1, 2).
    se_total = 0.
    se_1 = 0.
    se_2 = 0.
    se_3 = 0.
    gt_list = []
    pred_list = []
    mavg_list = []
    for f in sub_files:
        with h5py.File(os.path.join(test_data_paths, f), 'r') as h5_file:
            data = h5_file['array'][()]
            # Query the Moving Average Data (the `step` frames before each
            # start index); only used by the disabled moving-average baseline.
            prev_data = [data[y - step:y] for y in indicies]
            prev_data = np.stack(prev_data, axis=0)
            # type casting
            # prev_data = prev_data.astype(np.float32) / 255.0
            # mavg_pred = cast_moving_avg(prev_data)
            # mavg_list.append(mavg_pred)

            # Get the relevant data pieces: input_length frames before each
            # start index plus the (seq_length - input_length) frames to predict.
            data = [
                data[y - FLAGS.input_length:y + FLAGS.seq_length -
                     FLAGS.input_length] for y in indicies
            ]
            data = np.stack(data, axis=0)
            # select the data channel as wished
            data = data[..., :FLAGS.img_channel]

            # all validation data is applied
            # data = np.reshape(data,(-1, FLAGS.seq_length,
            #                     FLAGS.img_height*FLAGS.patch_size_height, FLAGS.img_width*FLAGS.patch_size_width, 3))
            # Scale to [0, 1] floats, then split into patches for the model.
            test_dat = data.astype(np.float32) / 255.0
            test_dat = preprocess.reshape_patch(test_dat,
                                                FLAGS.patch_size_width,
                                                FLAGS.patch_size_height)
            batch_size = data.shape[0]
            # All-zero mask over the predicted steps. NOTE(review): presumably
            # this makes the model feed back its own predictions at every
            # output step; confirm against Model.test.
            mask_true = np.zeros(
                (batch_size, FLAGS.seq_length - FLAGS.input_length - 1,
                 FLAGS.img_height, FLAGS.img_width, FLAGS.patch_size_height *
                 FLAGS.patch_size_width * FLAGS.img_channel))
            img_gen = model.test(test_dat, mask_true, batch_size)
            # concat outputs of different gpus along batch
            # img_gen = np.concatenate(img_gen)
            # Take the first element of the result (presumably the single
            # GPU's output), then clip to [0, 1] and undo the patching.
            img_gen = img_gen[0]
            img_gen = np.maximum(img_gen, 0)
            img_gen = np.minimum(img_gen, 1)
            img_gen = preprocess.reshape_patch_back(img_gen,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)
            img_gt = data[:, FLAGS.input_length:, ...].astype(
                np.float32) / 255.0

            gt_list.append(img_gt)
            pred_list.append(img_gen)

            # Compute the squared error once, then reuse it for the total and
            # per-channel sums.
            sq_err = (img_gt - img_gen)**2
            se_total += np.sum(sq_err)

            # `data` was sliced to FLAGS.img_channel channels above, so only
            # accumulate channels that actually exist. Indexing channel 1
            # unconditionally raised IndexError when img_channel == 1.
            num_channels = sq_err.shape[-1]
            se_1 += np.sum(sq_err[..., 0])
            if num_channels > 1:
                se_2 += np.sum(sq_err[..., 1])
            if num_channels > 2:
                se_3 += np.sum(sq_err[..., 2])

            img_gen = np.uint8(img_gen * 255)
            outfile = os.path.join(output_path, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_test', f)
            preprocess.write_data(img_gen, outfile)

    # mse = se_total / (len(indicies) * len(sub_files) * 495 * 436 * 3 * 3)
    #
    # mse1 = se_1 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # mse2 = se_2 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # # mse3 = se_3 / (len(indicies) * len(sub_files) * 495 * 436 * 3)
    # print(FLAGS.dataset_name)
    # print("MSE: ", mse)
    # print("MSE_vol: ", mse1)
    # print("MSE_sp: ", mse2)
    # # print("MSE_hd: ", mse3)
    #
    # pred_list = np.stack(pred_list, axis=0)
    # gt_list = np.stack(gt_list, axis=0)
    # mavg_list = np.stack(mavg_list, axis=0)
    #
    # array_mse = masked_mse_np(mavg_list, gt_list, np.nan)
    # print(f'MAVG {step} MSE: ', array_mse)
    #
    # # adapt pred on non_zero mavg pred only
    # pred_list_copy = np.zeros_like(pred_list)
    # pred_list_copy[mavg_list > 0] = pred_list[mavg_list > 0]
    #
    # array_mse = masked_mse_np(pred_list_copy, gt_list, np.nan)
    # print(f'PRED+MAVG {step} MSE: ', array_mse)
    #
    # # Evaluate on nodes
    # # Check MSE on node_pos
    # img_gt_node = gt_list[:, :, :, node_pos[:, 0], node_pos[:, 1], :].astype(np.float32)
    # img_gen_node = pred_list[:, :, :, node_pos[:, 0], node_pos[:, 1], :].astype(np.float32)
    # mse_node_all = masked_mse_np(img_gen_node, img_gt_node, np.nan)
    # mse_node_volume = masked_mse_np(img_gen_node[..., 0], img_gt_node[..., 0], np.nan)
    # mse_node_speed = masked_mse_np(img_gen_node[..., 1], img_gt_node[..., 1], np.nan)
    # mse_node_direction = masked_mse_np(img_gen_node[..., 2], img_gt_node[..., 2], np.nan)
    # print("Results on Node Pos: ")
    # print("MSE: ", mse_node_all)
    # print("Volume mse: ", mse_node_volume)
    # print("Speed mse: ", mse_node_speed)
    # print("Direction mse: ", mse_node_direction)
    #
    # print("Evaluating on Condensed Graph....")
    # seq_length = np.shape(gt_list)[2]
    # img_height = np.shape(gt_list)[3]
    # img_width = np.shape(gt_list)[4]
    # num_channels = np.shape(gt_list)[5]
    # gt_list = np.reshape(gt_list, [-1, seq_length,
    #                             int(img_height / FLAGS.patch_size_height), FLAGS.patch_size_height,
    #                             int(img_width / FLAGS.patch_size_width), FLAGS.patch_size_width,
    #                             num_channels])
    # gt_list = np.transpose(gt_list, [0, 1, 2, 4, 3, 5, 6])
    #
    # pred_list = np.reshape(pred_list, [-1, seq_length,
    #                                int(img_height / FLAGS.patch_size_height), FLAGS.patch_size_height,
    #                                int(img_width / FLAGS.patch_size_width), FLAGS.patch_size_width,
    #                                num_channels])
    # pred_list = np.transpose(pred_list, [0, 1, 2, 4, 3, 5, 6])
    #
    # node_pos = preprocess.construct_road_network_from_grid_condense(FLAGS.patch_size_height, FLAGS.patch_size_width,
    #                                                                 test_data_paths)
    #
    # img_gt_node = gt_list[:, :, node_pos[:, 0], node_pos[:, 1], ...].astype(np.float32)
    # img_gen_node = pred_list[:, :, node_pos[:, 0], node_pos[:, 1], ...].astype(np.float32)
    # mse_node_all = masked_mse_np(img_gen_node, img_gt_node, np.nan)
    # mse_node_volume = masked_mse_np(img_gen_node[..., 0], img_gt_node[..., 0], np.nan)
    # mse_node_speed = masked_mse_np(img_gen_node[..., 1], img_gt_node[..., 1], np.nan)
    # mse_node_direction = masked_mse_np(img_gen_node[..., 2], img_gt_node[..., 2], np.nan)
    # print("MSE: ", mse_node_all)
    # print("Volume mse: ", mse_node_volume)
    # print("Speed mse: ", mse_node_speed)
    # print("Direction mse: ", mse_node_direction)

    print("Finished...")
Example #20
0
def main(argv=None):

    # FLAGS.save_dir += FLAGS.dataset_name + str(FLAGS.seq_length) + FLAGS.num_hidden + 'squash'
    heading_dict = {1: 1, 2: 85, 3: 170, 4: 255, 0: 0}
    heading = FLAGS.heading
    loss_func = FLAGS.loss_func
    FLAGS.save_dir += FLAGS.dataset_name + str(
        FLAGS.seq_length
    ) + FLAGS.num_hidden + 'squash' + FLAGS.loss_func + str(heading)
    FLAGS.pretrained_model = FLAGS.save_dir

    test_data_paths = os.path.join(FLAGS.valid_data_paths, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_' + FLAGS.mode)
    sub_files = preprocess.list_filenames(test_data_paths, [])

    output_path = f'./Results/predrnn/t{FLAGS.test_time}_{FLAGS.mode}/{FLAGS.loss_func}'
    # output_path = f'./Results/predrnn/t14/'
    preprocess.create_directory_structure(output_path)

    # The following indicies are the start indicies of the 3 images to predict in the 288 time bins (0 to 287)
    # in each daily test file. These are time zone dependent. Berlin lies in UTC+2 whereas Istanbul and Moscow
    # lie in UTC+3.
    utcPlus2 = [30, 69, 126, 186, 234]
    utcPlus3 = [57, 114, 174, 222, 258]
    heading_table = np.array([[0, 0], [-1, 1], [1, 1], [-1, -1], [1, -1]],
                             dtype=np.float32)

    indicies = utcPlus3
    if FLAGS.dataset_name == 'Berlin':
        indicies = utcPlus2

    # dims = train_input_handle.dims
    print("Initializing models", flush=True)
    model = Model()

    avg_mse = 0
    gt_list = []
    pred_list = []
    pred_list_all = []
    pred_vec = []
    move_avg = []

    for f in sub_files:
        with h5py.File(os.path.join(test_data_paths, f), 'r') as h5_file:
            data = h5_file['array'][()]
            # get relevant training data pieces
            data = [
                data[y - FLAGS.input_length:y + FLAGS.seq_length -
                     FLAGS.input_length] for y in indicies
            ]
            test_ims = np.stack(data, axis=0)
            batch_size = len(indicies)
            gt_list.append(test_ims[:, FLAGS.input_length:, :, :, 1:])

            tem_data = test_ims.copy()
            heading_image = test_ims[:, :, :, :, 2]
            heading_image = (heading_image // 85).astype(np.int8) + 1
            heading_image[tem_data[:, :, :, :, 2] == 0] = 0

            # convert the data into speed vectors
            heading_selected = np.zeros_like(heading_image, np.int8)
            heading_selected[heading_image == heading] = heading
            heading_image = heading_selected
            heading_image = heading_table[heading_image]
            speed_on_axis = np.expand_dims(
                test_ims[:, :, :, :, 1].astype(np.float32) / 255.0 /
                np.sqrt(2),
                axis=-1)
            test_ims = speed_on_axis * heading_image

            test_dat = preprocess.reshape_patch(test_ims,
                                                FLAGS.patch_size_width,
                                                FLAGS.patch_size_height)

            mask_true = np.zeros(
                (batch_size, FLAGS.seq_length - FLAGS.input_length - 1,
                 FLAGS.img_height, FLAGS.img_width, FLAGS.patch_size_height *
                 FLAGS.patch_size_width * FLAGS.img_channel))

            img_gen = model.test(test_dat, mask_true, batch_size)

            # concat outputs of different gpus along batch
            img_gen = np.concatenate(img_gen)
            img_gen = preprocess.reshape_patch_back(img_gen,
                                                    FLAGS.patch_size_width,
                                                    FLAGS.patch_size_height)

            outfile = os.path.join(output_path, FLAGS.dataset_name,
                                   FLAGS.dataset_name + '_test',
                                   f'{heading}' + f)
            preprocess.write_data(img_gen, outfile)