Ejemplo n.º 1
0
def run_benchmark(hko_factory,
                  context,
                  encoder_net,
                  forecaster_net,
                  loss_net=None,
                  sample_num=1,
                  finetune=False,
                  mode="fixed",
                  save_dir="hko7_rnn",
                  pd_path=cfg.HKO_PD.RAINY_TEST):
    """Evaluate an encoder/forecaster pair on the HKO7 benchmark.

    Observations are drawn one at a time (``batch_size=1``) from an
    ``HKOBenchmarkEnv`` until it reports ``done``. Every observation is fed
    through ``encoder_net``; whenever the environment requests a prediction,
    ``forecaster_net`` is run, its first output is clipped to [0, 1] and
    uploaded to the environment. The evaluation is written out with
    ``env.save_eval()`` once the loop finishes.

    When ``finetune`` is true, observations and uploaded predictions are
    buffered. As soon as five observations are buffered, the oldest one is
    used as network input and the remaining ones (concatenated on axis 0) as
    ground truth. One ``train_step`` is executed if the balanced MSE of the
    oldest buffered prediction exceeds
    ``cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE``; the oldest buffered entries
    are then dropped.

    Parameters
    ----------
    hko_factory :
        Factory object handed to ``EncoderForecasterStates``.
    context : mx.ctx
        Device on which the NDArrays are created.
    encoder_net : MyModule
        Network run on every observation.
    forecaster_net : MyModule
        Network run whenever a prediction has to be uploaded.
    loss_net : MyModule
        Only forwarded to ``train_step`` during finetuning.
    sample_num : int
        Only reported in the start-up log message.
    finetune : bool
        Enables the buffered finetuning described above.
    mode : str
        Forwarded to ``HKOBenchmarkEnv``. With ``"fixed"`` the recurrent
        states are reset before every observation.
    save_dir : str
        Forwarded to ``HKOBenchmarkEnv``.
    pd_path : str
        Forwarded to ``HKOBenchmarkEnv``.

    Returns
    -------
    None
    """
    start_msg = ("Begin Evaluation, mode=%s, finetune=%s, sample_num=%d,"
                 " results will be saved to %s"
                 % (mode, str(finetune), sample_num, save_dir))
    logging.info(start_msg)
    if finetune:
        logging.info(str(cfg.MODEL.TEST.ONLINE))
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
    states = EncoderForecasterStates(factory=hko_factory, ctx=context)

    # Buffers used only by the finetuning path.
    frame_history = []
    # Encoder states are carried across observations (and snapshotted for
    # finetuning) unless TBPTT has been disabled in the config.
    keep_states = not cfg.MODEL.TEST.DISABLE_TBPTT
    if keep_states:
        state_history = []
    pred_history = []
    upload_count = 0
    finetune_step = 0

    while not env.done:
        if finetune and len(frame_history) >= 5:
            data_in = frame_history[0]
            data_gt = np.concatenate(frame_history[1:], axis=0)
            gt_mask = precompute_mask(data_gt)
            init_states = EncoderForecasterStates(factory=hko_factory,
                                                  ctx=context)
            if keep_states:
                # Start the training step from the encoder states that were
                # current when the oldest buffered observation arrived.
                restored = [nd.array(ele, ctx=context)
                            for ele in state_history[0]]
                init_states.update(states_nd=restored)
            weights = get_balancing_weights_numba(
                data=data_gt,
                mask=gt_mask,
                base_balancing_weights=cfg.HKO.EVALUATION.BALANCING_WEIGHTS,
                thresholds=env._all_eval._thresholds)
            squared_err = np.square(pred_history[0] - data_gt)
            mean_weighted_mse = \
                (weights * squared_err).sum(axis=(2, 3, 4)).mean()
            print("mean_weighted_mse = %g" % mean_weighted_mse)
            if mean_weighted_mse > cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE:
                _, loss_dict = train_step(
                    batch_size=1,
                    encoder_net=encoder_net,
                    forecaster_net=forecaster_net,
                    loss_net=loss_net,
                    init_states=init_states,
                    data_nd=nd.array(data_in, ctx=context),
                    gt_nd=nd.array(data_gt, ctx=context),
                    mask_nd=nd.array(gt_mask, ctx=context),
                    iter_id=finetune_step)
                finetune_step += 1
            # Slide the window: forget the oldest buffered entries.
            frame_history.pop(0)
            pred_history.pop(0)
            if keep_states:
                state_history.pop(0)

        if mode == "fixed" or not keep_states:
            states.reset_all()
        (in_frame_dat, in_datetime_clips, out_datetime_clips,
         begin_new_episode, need_upload_prediction) = \
            env.get_observation(batch_size=1)

        if finetune:
            if begin_new_episode:
                # A new episode invalidates everything buffered so far.
                frame_history = []
                pred_history = []
                if keep_states:
                    state_history = []
            frame_history.append(in_frame_dat)
            if keep_states:
                # Snapshot taken before the encoder consumes this frame.
                state_history.append(
                    [ele.asnumpy() for ele in states.get_encoder_states()])

        in_frame_nd = nd.array(in_frame_dat, ctx=context)
        encoder_batch = mx.io.DataBatch(
            data=[in_frame_nd] + states.get_encoder_states())
        encoder_net.forward(is_train=False, data_batch=encoder_batch)
        states.update(states_nd=encoder_net.get_outputs())

        if not need_upload_prediction:
            continue

        upload_count += 1
        direct_output = cfg.MODEL.OUT_TYPE == "direct"
        forecaster_inputs = states.get_forecaster_state()
        if not direct_output:
            # Non-direct forecasters additionally receive the last input frame.
            forecaster_inputs = forecaster_inputs + \
                [in_frame_nd[in_frame_nd.shape[0] - 1]]
        forecaster_net.forward(
            is_train=False,
            data_batch=mx.io.DataBatch(data=forecaster_inputs))
        pred_nd = forecaster_net.get_outputs()[0]
        if not direct_output:
            # Second output is fetched but not used further in this function.
            flow_nd = forecaster_net.get_outputs()[1]
        pred_nd = nd.clip(pred_nd, a_min=0, a_max=1)
        env.upload_prediction(prediction=pred_nd.asnumpy())
        if finetune:
            pred_history.append(pred_nd.asnumpy())
    env.save_eval()
Ejemplo n.º 2
0
def hko_benchmark(ctx,
                  generator_net,
                  loss_net,
                  sample_num,
                  finetune=False,
                  mode="fixed",
                  save_dir="hko7_rnn",
                  pd_path=cfg.HKO_PD.RAINY_TEST):
    """Run the HKO7 benchmark with a generator network.

    Observations are pulled from an ``HKOBenchmarkEnv`` until it reports
    ``done``. Each observation is forwarded through ``generator_net``; when
    the environment requests a prediction, the ``"pred_output"`` output is
    cut back to the original batch size, clipped to [0, 1] and uploaded.
    The evaluation is saved with ``env.save_eval()`` at the end.

    With ``finetune=True`` (online mode only) observations and uploaded
    predictions are buffered. Once five observations are buffered, the
    oldest is used as context and the following ones as ground truth
    (truncated to ``cfg.HKO.BENCHMARK.OUT_LEN`` frames). A ``train_step`` is
    run if the balanced MSE of the oldest buffered prediction exceeds
    ``cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE``.

    Args:
        ctx: MXNet context on which NDArrays are created.
        generator_net: Module providing ``forward``, ``get_outputs`` and
            ``output_names``; must expose an output named "pred_output".
        loss_net: Forwarded to ``train_step`` during finetuning.
        sample_num: Only reported in the start-up log message.
        finetune: Enable the buffered online finetuning described above.
        mode: "online" (one sequence per step, repeated to fill the batch)
            or "fixed" (up to ``cfg.MODEL.TRAIN.BATCH_SIZE`` sequences per
            step, padded if fewer are returned).
        save_dir: Forwarded to ``HKOBenchmarkEnv``.
        pd_path: Forwarded to ``HKOBenchmarkEnv``.

    Raises:
        AssertionError: If ``finetune`` is set and ``mode`` is not "online".
        NotImplementedError: If ``mode`` is neither "online" nor "fixed".
    """
    logging.info("Begin Evaluation, sample_num=%d,"
                 " results will be saved to %s" % (sample_num, save_dir))
    if finetune:
        logging.info(str(cfg.MODEL.TEST.ONLINE))
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)

    if finetune:
        assert mode == "online", "finetune is only supported in online mode"
        data_buffer = []
        stored_prediction = []
        finetune_iter = 0

    context_nd = None

    i = 0
    while not env.done:
        logging.info("Iter {} of evaluation.".format(i))
        i += 1
        if finetune:
            if len(data_buffer) >= 5:
                context_np = data_buffer[0]  # HKO.BENCHMARK.IN_LEN frames
                gt_np = np.concatenate(data_buffer[1:], axis=0)
                gt_np = gt_np[:cfg.HKO.BENCHMARK.OUT_LEN]

                mask_np = precompute_mask(gt_np)

                weights = get_balancing_weights_numba(
                    data=gt_np,
                    mask=mask_np,
                    base_balancing_weights=cfg.HKO.EVALUATION.
                    BALANCING_WEIGHTS,
                    thresholds=env._all_eval._thresholds)
                weighted_mse = (weights *
                                np.square(stored_prediction[0] - gt_np)).sum(
                                    axis=(2, 3, 4))
                mean_weighted_mse = weighted_mse.mean()
                print("mean_weighted_mse = %g" % mean_weighted_mse)

                # Only spend a training step when the oldest buffered
                # prediction was poor enough.
                if mean_weighted_mse > cfg.MODEL.TEST.ONLINE.FINETUNE_MIN_MSE:
                    # Same axis reordering as used for inference input below
                    # (presumably env layout -> generator layout; confirm
                    # against the generator's expected input).
                    context_nd = mx.nd.array(context_np, ctx=ctx)
                    context_nd = mx.nd.transpose(
                        context_nd, axes=(1, 2, 0, 3, 4))
                    gt_nd = mx.nd.array(gt_np, ctx=ctx)
                    gt_nd = mx.nd.transpose(gt_nd, axes=(1, 2, 0, 3, 4))
                    mask_nd = mx.nd.array(mask_np, ctx=ctx)
                    mask_nd = mx.nd.transpose(mask_nd, axes=(1, 2, 0, 3, 4))

                    train_step(
                        generator_net=generator_net,
                        loss_net=loss_net,
                        context_nd=context_nd,
                        gt_nd=gt_nd,
                        mask_nd=mask_nd)

                    finetune_iter += 1

                # Slide the window: forget the oldest observation/prediction.
                del data_buffer[0]
                del stored_prediction[0]

        if mode == "online":
            context_np, in_datetime_clips, out_datetime_clips,\
                begin_new_episode, need_upload_prediction = env.get_observation(
                    batch_size=1)
            # The single observation is repeated along axis 1 to reach
            # BATCH_SIZE; only the first `orig_size` predictions are kept.
            context_np = np.repeat(
                context_np, cfg.MODEL.TRAIN.BATCH_SIZE, axis=1)
            orig_size = 1

        elif mode == "fixed":
            context_np, in_datetime_clips, out_datetime_clips,\
                begin_new_episode, need_upload_prediction = env.get_observation(
                    batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
            context_nd = mx.nd.array(context_np, ctx=ctx)
            context_nd = mx.nd.transpose(context_nd, axes=(1, 2, 0, 3, 4))

            # Pad context_nd up to batch size if needed
            orig_size = context_nd.shape[0]
            while context_nd.shape[0] < cfg.MODEL.TRAIN.BATCH_SIZE:
                context_nd = mx.nd.concat(
                    context_nd, context_nd[0:1], num_args=2, dim=0)
        else:
            raise NotImplementedError

        if finetune:
            if begin_new_episode:
                # A new episode invalidates everything buffered so far.
                # Predictions must be cleared together with the data,
                # otherwise stale predictions from the previous episode
                # would later be scored against this episode's ground truth.
                data_buffer = [context_np]
                stored_prediction = []
            else:
                data_buffer.append(context_np)

        if mode != "fixed":
            context_nd = mx.nd.array(context_np, ctx=ctx)
            context_nd = mx.nd.transpose(context_nd, axes=(1, 2, 0, 3, 4))
        generator_net.forward(
            is_train=False, data_batch=mx.io.DataBatch(data=[context_nd]))

        if need_upload_prediction:
            generator_outputs = dict(
                zip(generator_net.output_names, generator_net.get_outputs()))
            pred_nd = generator_outputs["pred_output"]

            # Drop the entries that only exist because of padding/repeating.
            pred_nd = pred_nd[0:orig_size]

            pred_nd = mx.nd.clip(pred_nd, a_min=0, a_max=1)
            # Undo the input axis reordering before handing the result back.
            pred_nd = mx.nd.transpose(pred_nd, axes=(2, 0, 1, 3, 4))

            env.upload_prediction(prediction=pred_nd.asnumpy())

            if finetune:
                stored_prediction.append(pred_nd.asnumpy())

    env.save_eval()
Ejemplo n.º 3
0
def run(pd_path=cfg.HKO_PD.RAINY_TEST,
        mode="fixed",
        interp_type="bilinear",
        nonlinear_transform=True):
    """Run an optical-flow extrapolation baseline on the HKO7 benchmark.

    For every observation that requires a prediction, a flow field is
    computed from the last two input frames (channel 0, masked with
    ``precompute_mask``) using ``VarFlowFactory.batch_calc_flow``. The last
    frame is then advected ``cfg.HKO.BENCHMARK.OUT_LEN`` times with that
    flow, each step starting from the previous result, and the stacked
    frames are uploaded to the environment. Results and logs go to
    ``hko7_benchmark/rover-nonlinear`` or ``hko7_benchmark/rover-linear``.

    Parameters
    ----------
    pd_path : str
        Forwarded to ``HKOBenchmarkEnv``.
    mode : str
        Forwarded to ``HKOBenchmarkEnv``.
    interp_type : str
        Must be "bilinear"; anything else fails the assertion below.
    nonlinear_transform : bool
        If true, frames are passed through ``NonLinearRoverTransform``
        before the flow computation and the prediction is mapped back with
        ``rev_transform``.
    """
    transformer = NonLinearRoverTransform()
    flow_factory = VarFlowFactory(max_level=6,
                                  start_level=0,
                                  n1=2,
                                  n2=2,
                                  rho=1.5,
                                  alpha=2000,
                                  sigma=4.5)
    assert interp_type == "bilinear", "Nearest interpolation is implemented in CPU and is too slow." \
                                      " We only support bilinear interpolation for rover."
    variant = 'rover-nonlinear' if nonlinear_transform else 'rover-linear'
    base_dir = os.path.join('hko7_benchmark', variant)
    logging_config(base_dir)
    batch_size = 1
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=base_dir, mode=mode)
    counter = 0
    while not env.done:
        (in_frame_dat, in_datetime_clips, out_datetime_clips,
         begin_new_episode, need_upload_prediction) = \
            env.get_observation(batch_size=batch_size)
        if not need_upload_prediction:
            continue

        counter += 1
        out_len = cfg.HKO.BENCHMARK.OUT_LEN
        prediction = np.zeros(shape=(out_len, ) + in_frame_dat.shape[1:],
                              dtype=np.float32)

        # Last two input frames, channel 0. These are views into
        # in_frame_dat, so the masking below must not be done in place.
        prev_frame = in_frame_dat[-2, :, 0, :, :]
        last_frame = in_frame_dat[-1, :, 0, :, :]
        prev_frame = prev_frame * precompute_mask(prev_frame)
        last_frame = last_frame * precompute_mask(last_frame)
        if nonlinear_transform:
            prev_frame = transformer.transform(prev_frame)
            last_frame = transformer.transform(last_frame)
        flow = flow_factory.batch_calc_flow(I1=prev_frame, I2=last_frame)

        if interp_type == "bilinear":
            gpu_ctx = mx.gpu()
            seed_shape = (last_frame.shape[0], 1,
                          last_frame.shape[1], last_frame.shape[2])
            init_im = nd.array(last_frame.reshape(seed_shape), ctx=gpu_ctx)
            # The second flow component is negated before advection.
            signed_flow = np.concatenate(
                (flow[:, :1, :, :], -flow[:, 1:, :, :]), axis=1)
            nd_flow = nd.array(signed_flow, ctx=gpu_ctx)
            nd_pred_im = nd.zeros(shape=prediction.shape)
            for step in range(out_len):
                new_im = nd_advection(init_im, flow=nd_flow)
                nd_pred_im[step][:] = new_im
                # Next step advects the frame that was just produced.
                init_im[:] = new_im
            prediction = nd_pred_im.asnumpy()
        elif interp_type == "nearest":
            seed_shape = (last_frame.shape[0], 1,
                          last_frame.shape[1], last_frame.shape[2])
            init_im = last_frame.reshape(seed_shape)
            for step in range(out_len):
                new_im = nearest_neighbor_advection(init_im, flow)
                prediction[step, ...] = new_im
                init_im = new_im

        if nonlinear_transform:
            prediction = transformer.rev_transform(prediction)
        env.upload_prediction(prediction=prediction)

        if counter % 10 == 0:
            # Periodically dump the current input/prediction and statistics.
            save_hko_gif(in_frame_dat[:, 0, 0, :, :],
                         save_path=os.path.join(base_dir, 'in.gif'))
            save_hko_gif(prediction[:, 0, 0, :, :],
                         save_path=os.path.join(base_dir, 'pred.gif'))
            env.print_stat_readable()
            # Debugging hint: the flow field can be inspected here with a
            # matplotlib quiver plot of flow[1, 0] / flow[1, 1].
    env.save_eval()