Example #1
def test(args):
    assert args.resume is True
    if cfg.MODEL.TEST.FINETUNE:
        assert cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS is True

    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    generator_net, loss_net = construct_modules(args)

    if args.dataset == "test":
        pd_path = cfg.HKO_PD.RAINY_TEST
    elif args.dataset == "valid":
        pd_path = cfg.HKO_PD.RAINY_VALID
    else:
        raise NotImplementedError

    hko_benchmark(
        ctx=args.ctx[0],
        generator_net=generator_net,
        loss_net=loss_net,
        sample_num=1,
        save_dir=os.path.join(base_dir, "iter{}_{}_finetune{}".format(
            cfg.MODEL.LOAD_ITER + 1, args.dataset, cfg.MODEL.TEST.FINETUNE)),
        finetune=cfg.MODEL.TEST.FINETUNE,
        mode=cfg.MODEL.TEST.MODE,
        pd_path=pd_path)
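
How this entry point is wired up is not shown in the snippet; a minimal sketch, assuming an argparse setup along the lines of Example #8 (the flag names here are illustrative, only the attributes test() actually reads are grounded in the source):

import argparse
import mxnet as mx

def parse_test_args():
    # Hypothetical parser exposing the attributes test() reads:
    # resume, save_dir, dataset, ctx.
    parser = argparse.ArgumentParser(description='Test a trained HKO-7 model')
    parser.add_argument('--save_dir', type=str, required=True)
    parser.add_argument('--dataset', type=str, default='test', choices=['test', 'valid'])
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()
    args.ctx = [mx.gpu(args.gpu)]  # test() indexes args.ctx[0]
    return args

if __name__ == '__main__':
    test(parse_test_args())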
Example #2
def test(args):
    cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS = False
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_movingmnist_cfg(base_dir)
    batch_size = 4
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    mnist_iter.load(file=cfg.MOVINGMNIST.TEST_FILE)
    mnist_rnn = MovingMNISTFactory(batch_size=batch_size,
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    # A trained model is required for testing: LOAD_DIR must be set
    assert len(cfg.MODEL.LOAD_DIR) > 0
    load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                   load_iter=cfg.MODEL.LOAD_ITER,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net)
    overall_mse = 0
    for iter_id in range(10000 // batch_size):
        frame_dat, _ = mnist_iter.sample(batch_size=batch_size,
                                         seqlen=cfg.MOVINGMNIST.IN_LEN +
                                         cfg.MOVINGMNIST.OUT_LEN,
                                         random=False)
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(frame_dat[cfg.MOVINGMNIST.IN_LEN:(
            cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                ctx=args.ctx[0]) / 255.0
        pred_nd = mnist_get_prediction(data_nd=data_nd,
                                       states=states,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
        overall_mse += mx.nd.mean(mx.nd.square(pred_nd - target_nd)).asscalar()
        print(iter_id, overall_mse / (iter_id + 1))
    avg_mse = overall_mse / (10000 // batch_size)
    with open(os.path.join(base_dir, 'result.txt'), 'w') as f:
        f.write(str(avg_mse))
    print(base_dir, avg_mse)
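
The value printed inside the loop is a running average: the cumulative per-batch MSE divided by the number of batches seen so far. The same bookkeeping as a self-contained numpy sketch:

import numpy as np

def running_mse(pred_batches, target_batches):
    # Accumulate per-batch mean squared error and print the running average,
    # mirroring the overall_mse bookkeeping in the loop above.
    overall = 0.0
    for i, (pred, gt) in enumerate(zip(pred_batches, target_batches)):
        overall += float(np.mean(np.square(pred - gt)))
        print(i, overall / (i + 1))
    return overall / max(len(pred_batches), 1)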
Example #3
def analysis(args):
    cfg.MODEL.TRAJRNN.SAVE_MID_RESULTS = True
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_movingmnist_cfg(base_dir)
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    mnist_rnn = MovingMNISTFactory(batch_size=1,
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    # A trained model is required for analysis: LOAD_DIR must be set
    assert len(cfg.MODEL.LOAD_DIR) > 0
    load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                   load_iter=cfg.MODEL.LOAD_ITER,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net)
    for iter_id in range(1):  # a single batch suffices for the qualitative GIFs below
        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                         seqlen=cfg.MOVINGMNIST.IN_LEN +
                                         cfg.MOVINGMNIST.OUT_LEN)
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(frame_dat[cfg.MOVINGMNIST.IN_LEN:(
            cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                ctx=args.ctx[0]) / 255.0
        pred_nd = mnist_get_prediction(data_nd=data_nd,
                                       states=states,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
        save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "pred.gif"))
        save_gif(data_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "in.gif"))
        save_gif(target_nd.asnumpy()[:, 0, 0, :, :],
                 os.path.join(base_dir, "gt.gif"))
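
save_gif is a project helper whose implementation is not shown; a rough stand-in using imageio (an assumption, not the project's actual code), matching the (seq_len, H, W) float slices passed above:

import imageio
import numpy as np

def save_gif_sketch(frames, save_path, frame_duration=0.2):
    # frames: (seq_len, H, W) float array in [0, 1]; convert to uint8
    # and write an animated GIF with the given per-frame duration in seconds.
    frames_u8 = (np.clip(frames, 0.0, 1.0) * 255.0).astype(np.uint8)
    imageio.mimsave(save_path, list(frames_u8), duration=frame_duration)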
Example #4
def train(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_icdm_cfg(base_dir)
    icdm_iter = ICDMIterator(root_dir=r'C:\Users\jing\projects\nowcasting\HKO-7\icdm_data\SRAD2018_TRAIN_010')
    icdm_factory = ICDMFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
                               in_seq_len=cfg.ICDM.IN_LEN,
                               out_seq_len=cfg.ICDM.OUT_LEN)

    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=icdm_factory,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Begin to load the model if load_dir is not empty
    # if len(cfg.MODEL.LOAD_DIR) > 0:
    #     load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR, load_iter=cfg.MODEL.LOAD_ITER,
    #                       encoder_net=encoder_net, forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=icdm_factory, ctx=args.ctx[0])
    states.reset_all()
    for info in icdm_factory.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        input_batch, output_batch = icdm_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                                     in_len=cfg.ICDM.IN_LEN,
                                                     out_len=cfg.ICDM.OUT_LEN)
        # data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...], ctx=args.ctx[0]) / 255.0
        data_nd = mx.nd.array(input_batch, ctx=args.ctx[0], dtype=np.float32)
        target_nd = mx.nd.array(output_batch, ctx=args.ctx[0], dtype=np.float32)
        # target_nd = mx.nd.array(
        #     frame_dat[cfg.MODEL.IN_LEN:(cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
        #     ctx=args.ctx[0]) / 255.0
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net, forecaster_net=forecaster_net,
                   loss_net=loss_net, init_states=states,
                   data_nd=data_nd, gt_nd=target_nd, mask_nd=None,
                   iter_id=iter_id)
        print('step:', iter_id)
        if (iter_id + 1) % 100 == 0:
            pass  # placeholder for periodic validation or visualization
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id)
        iter_id += 1
Example #5
def run(pd_path=cfg.HKO_PD.RAINY_TEST, mode="fixed"):
    base_dir = os.path.join('hko7_benchmark', 'last_frame')
    logging_config(base_dir)
    batch_size = 1
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=base_dir, mode=mode)
    while not env.done:
        in_frame_dat, in_datetime_clips, out_datetime_clips,\
        begin_new_episode, need_upload_prediction =\
            env.get_observation(batch_size=batch_size)
        if need_upload_prediction:
            prediction = np.zeros(shape=(cfg.HKO.BENCHMARK.OUT_LEN, ) +
                                  in_frame_dat.shape[1:],
                                  dtype=in_frame_dat.dtype)
            prediction[:] = in_frame_dat[-1, ...]
            env.upload_prediction(prediction=prediction)
            env.print_stat_readable()
    env.save_eval()
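
This baseline is pure persistence: prediction[:] = in_frame_dat[-1, ...] broadcasts the last observed frame across all cfg.HKO.BENCHMARK.OUT_LEN lead times. The broadcast in isolation, with toy sizes:

import numpy as np

in_frame_dat = np.random.rand(5, 1, 1, 4, 4).astype(np.float32)  # (in_len, bs, ch, H, W), toy spatial size
out_len = 20
prediction = np.zeros((out_len,) + in_frame_dat.shape[1:], dtype=in_frame_dat.dtype)
prediction[:] = in_frame_dat[-1, ...]   # every lead time repeats the last input frame
assert all(np.array_equal(prediction[0], prediction[t]) for t in range(out_len))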
Example #6
def test(args):
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir, name='testing')

    # Get modules
    generator_net, loss_net = construct_modules(args)

    # Prepare data
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))
    num_samples, seqlen = mnist_iter.load(file=cfg.MOVINGMNIST.TEST_FILE)

    overall_mse = 0
    for iter_id in range(num_samples // cfg.MODEL.TRAIN.BATCH_SIZE):
        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                         seqlen=seqlen,
                                         random=False)

        context_nd = mx.nd.array(frame_dat[:cfg.MOVINGMNIST.IN_LEN],
                                 ctx=args.ctx[0]) / 255.0
        gt_nd = mx.nd.array(frame_dat[cfg.MOVINGMNIST.IN_LEN:],
                            ctx=args.ctx[0]) / 255.0

        # Transform data to NCDHW shape needed for 3D Convolution encoder
        context_nd = mx.nd.transpose(context_nd, axes=(1, 2, 0, 3, 4))
        gt_nd = mx.nd.transpose(gt_nd, axes=(1, 2, 0, 3, 4))

        pred_nd = test_step(generator_net, context_nd)
        overall_mse += mx.nd.mean(mx.nd.square(pred_nd - gt_nd)).asscalar()
        print(iter_id, overall_mse / (iter_id + 1))

    avg_mse = overall_mse / (num_samples // cfg.MODEL.TRAIN.BATCH_SIZE)
    with open(os.path.join(base_dir, 'result.txt'), 'w') as f:
        f.write(str(avg_mse))
    print(base_dir, avg_mse)
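
The two transposes in Example #6 convert the iterator's (seq_len, batch, channel, H, W) layout into the NCDHW layout a 3D-convolution encoder consumes; a quick shape check in numpy:

import numpy as np

tnchw = np.zeros((10, 4, 1, 64, 64))          # (seq_len, batch, channel, H, W) from the iterator
ncdhw = np.transpose(tnchw, (1, 2, 0, 3, 4))  # -> (batch, channel, depth=time, H, W)
print(ncdhw.shape)                            # (4, 1, 10, 64, 64)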
Example #7
def test_sst(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    assert len(cfg.MODEL.LOAD_DIR) > 0
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    sst_nowcasting_online = SSTNowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(
        factory=sst_nowcasting_online, context=args.ctx, for_finetune=True)
    t_encoder_net.summary()
    t_forecaster_net.summary()
    t_loss_net.summary()
    load_encoder_forecaster_params(
        load_dir=cfg.MODEL.LOAD_DIR,
        load_iter=cfg.MODEL.LOAD_ITER,
        encoder_net=t_encoder_net,
        forecaster_net=t_forecaster_net,
    )
    if args.dataset == "test":
        pd_path = cfg.SST_PD.RAINY_TEST
    elif args.dataset == "valid":
        pd_path = cfg.SST_PD.RAINY_VALID
    else:
        raise NotImplementedError
    run_benchmark(
        sst_factory=sst_nowcasting_online,
        context=args.ctx[0],
        encoder_net=t_encoder_net,
        forecaster_net=t_forecaster_net,
        loss_net=t_loss_net,
        save_dir=os.path.join(
            base_dir,
            "iter%d_%s_finetune%d" %
            (cfg.MODEL.LOAD_ITER + 1, args.dataset, cfg.MODEL.TEST.FINETUNE),
        ),
        finetune=cfg.MODEL.TEST.FINETUNE,
        mode=cfg.MODEL.TEST.MODE,
        pd_path=pd_path,
    )
Example #8
def argument_parser():
    parser = argparse.ArgumentParser(
        description='Deconvolution baseline for HKO')

    cfg.DATASET = "HKO"

    mode_args(parser)
    training_args(parser)
    dataset_args(parser)
    model_args(parser)

    args = parser.parse_args()

    parse_mode_args(args)
    parse_training_args(args)
    parse_model_args(args)

    base_dir = get_base_dir(args)
    logging_config(folder=base_dir, name="training")
    save_cfg(base_dir, source=cfg.MODEL)

    logging.info(args)
    return args
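
The snippet does not show the entry point; a plausible wiring, assuming train() consumes the parsed args as in the other examples:

if __name__ == '__main__':
    args = argument_parser()  # also configures logging and dumps cfg.MODEL to base_dir
    train(args)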
Example #9
def train(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    if cfg.MODEL.TRAIN.TBPTT:
        # Create a set of sequent iterators with different starting point
        train_hko_iters = []
        train_hko_iter_restart = []
        for _ in range(cfg.MODEL.TRAIN.BATCH_SIZE):
            ele_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                   sample_mode="sequent",
                                   seq_len=cfg.MODEL.IN_LEN +
                                   cfg.MODEL.OUT_LEN,
                                   stride=cfg.MODEL.IN_LEN)
            ele_iter.random_reset()
            train_hko_iter_restart.append(True)
            train_hko_iters.append(ele_iter)
    else:
        train_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                     sample_mode="random",
                                     seq_len=cfg.MODEL.IN_LEN +
                                     cfg.MODEL.OUT_LEN)

    hko_nowcasting = HKONowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    hko_nowcasting_online = HKONowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=hko_nowcasting, ctx=args.ctx[0])
    for info in hko_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    for info in hko_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        if not cfg.MODEL.TRAIN.TBPTT:
            # Not using TBPTT: sample a random minibatch directly
            frame_dat, mask_dat, datetime_clips, _ = \
                train_hko_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
            states.reset_all()
        else:
            # Using TBPTT: sample minibatches from the per-slot sequent iterators
            frame_dat_l = []
            mask_dat_l = []
            for i, ele_iter in enumerate(train_hko_iters):
                if ele_iter.use_up:
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                    train_hko_iter_restart[i] = True
                if not train_hko_iter_restart[i] and ele_iter.check_new_start():
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                frame_dat, mask_dat, datetime_clips, _ = \
                    ele_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
                train_hko_iter_restart[i] = False
                frame_dat_l.append(frame_dat)
                mask_dat_l.append(mask_dat)
            frame_dat = np.concatenate(frame_dat_l, axis=1)
            mask_dat = np.concatenate(mask_dat_l, axis=1)
        data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(frame_dat[cfg.MODEL.IN_LEN:(
            cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), ...],
                                ctx=args.ctx[0]) / 255.0
        mask_nd = mx.nd.array(mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                                         cfg.MODEL.OUT_LEN),
                                       ...],
                              ctx=args.ctx[0])
        states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                               encoder_net=encoder_net,
                               forecaster_net=forecaster_net,
                               loss_net=loss_net,
                               init_states=states,
                               data_nd=data_nd,
                               gt_nd=target_nd,
                               mask_nd=mask_nd,
                               iter_id=iter_id)
        if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
            run_benchmark(hko_factory=hko_nowcasting_online,
                          context=args.ctx[0],
                          encoder_net=t_encoder_net,
                          forecaster_net=t_forecaster_net,
                          loss_net=t_loss_net,
                          save_dir=os.path.join(base_dir, "iter%d_valid" %
                                                (iter_id + 1)),
                          mode=test_mode,
                          pd_path=cfg.HKO_PD.RAINY_VALID)
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(prefix=os.path.join(
                base_dir, "encoder_net"),
                                        epoch=iter_id)
            forecaster_net.save_checkpoint(prefix=os.path.join(
                base_dir, "forecaster_net"),
                                           epoch=iter_id)
        iter_id += 1
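
The TBPTT branch is the subtle part of Example #9: each batch slot owns its own sequent iterator, and only that slot's recurrent state is cleared when its sequence ends, so gradients are truncated at episode boundaries. The same control flow isolated as a sketch (slot_iters and states are assumed to follow the HKOIterator and EncoderForecasterStates interfaces used above):

def tbptt_sample(slot_iters, restarted, states, sample_chunk):
    # For each batch slot: reset only that slot's hidden state when its
    # sequence is used up or a new episode starts, then draw the next chunk.
    chunks = []
    for i, it in enumerate(slot_iters):
        if it.use_up:
            states.reset_batch(batch_id=i)
            it.random_reset()
            restarted[i] = True
        if not restarted[i] and it.check_new_start():
            states.reset_batch(batch_id=i)
            it.random_reset()
        chunks.append(sample_chunk(it))
        restarted[i] = False
    return chunks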
Example #10
def run(pd_path=cfg.HKO_PD.RAINY_TEST,
        mode="fixed",
        interp_type="bilinear",
        nonlinear_transform=True):
    transformer = NonLinearRoverTransform()
    flow_factory = VarFlowFactory(max_level=6,
                                  start_level=0,
                                  n1=2,
                                  n2=2,
                                  rho=1.5,
                                  alpha=2000,
                                  sigma=4.5)
    assert interp_type == "bilinear", "Nearest interpolation is implemented in CPU and is too slow." \
                                      " We only support bilinear interpolation for rover."
    if nonlinear_transform:
        base_dir = os.path.join('hko7_benchmark', 'rover-nonlinear')
    else:
        base_dir = os.path.join('hko7_benchmark', 'rover-linear')
    logging_config(base_dir)
    batch_size = 1
    env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=base_dir, mode=mode)
    counter = 0
    while not env.done:
        in_frame_dat, in_datetime_clips, out_datetime_clips, \
        begin_new_episode, need_upload_prediction = \
            env.get_observation(batch_size=batch_size)
        if need_upload_prediction:
            counter += 1
            prediction = np.zeros(shape=(cfg.HKO.BENCHMARK.OUT_LEN, ) +
                                  in_frame_dat.shape[1:],
                                  dtype=np.float32)
            I1 = in_frame_dat[-2, :, 0, :, :]
            I2 = in_frame_dat[-1, :, 0, :, :]
            mask_I1 = precompute_mask(I1)
            mask_I2 = precompute_mask(I2)
            I1 = I1 * mask_I1
            I2 = I2 * mask_I2
            if nonlinear_transform:
                I1 = transformer.transform(I1)
                I2 = transformer.transform(I2)
            flow = flow_factory.batch_calc_flow(I1=I1, I2=I2)
            if interp_type == "bilinear":
                init_im = nd.array(I2.reshape(
                    (I2.shape[0], 1, I2.shape[1], I2.shape[2])),
                                   ctx=mx.gpu())
                nd_flow = nd.array(np.concatenate(
                    (flow[:, :1, :, :], -flow[:, 1:, :, :]), axis=1),
                                   ctx=mx.gpu())
                nd_pred_im = nd.zeros(shape=prediction.shape)
                for i in range(cfg.HKO.BENCHMARK.OUT_LEN):
                    new_im = nd_advection(init_im, flow=nd_flow)
                    nd_pred_im[i][:] = new_im
                    init_im[:] = new_im
                prediction = nd_pred_im.asnumpy()
            elif interp_type == "nearest":
                init_im = I2.reshape(
                    (I2.shape[0], 1, I2.shape[1], I2.shape[2]))
                for i in range(cfg.HKO.BENCHMARK.OUT_LEN):
                    new_im = nearest_neighbor_advection(init_im, flow)
                    prediction[i, ...] = new_im
                    init_im = new_im
            if nonlinear_transform:
                prediction = transformer.rev_transform(prediction)
            env.upload_prediction(prediction=prediction)
            if counter % 10 == 0:
                save_hko_gif(in_frame_dat[:, 0, 0, :, :],
                             save_path=os.path.join(base_dir, 'in.gif'))
                save_hko_gif(prediction[:, 0, 0, :, :],
                             save_path=os.path.join(base_dir, 'pred.gif'))
                env.print_stat_readable()
                # import matplotlib.pyplot as plt
                # Q = plt.quiver(flow[1, 0, ::10, ::10], flow[1, 1, ::10, ::10])
                # plt.gca().invert_yaxis()
                # plt.show()
                # ch = raw_input()
    env.save_eval()
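
The rover baseline estimates a single optical-flow field from the last two frames and applies it repeatedly, each advected frame feeding the next step. The iterate-the-warp pattern in isolation (warp is a hypothetical stand-in for nd_advection / nearest_neighbor_advection):

import numpy as np

def extrapolate(last_frame, flow, steps, warp):
    # Repeatedly advect the most recent frame with a frozen flow field,
    # as in the bilinear branch above.
    preds = np.empty((steps,) + last_frame.shape, dtype=last_frame.dtype)
    im = last_frame
    for t in range(steps):
        im = warp(im, flow)
        preds[t] = im
    return preds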
Example #11
def train(args):
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    use_gan = (cfg.MODEL.PROBLEM_FORM == 'regression') and (cfg.MODEL.GAN_G_LAMBDA > 0.0)
    if cfg.MODEL.FRAME_SKIP_OUT > 1:
        # here we assume the target spans the last 30 frames
        iter_outlen = cfg.MODEL.OUT_LEN
    else:
        iter_outlen = 30
    model_outlen = cfg.MODEL.OUT_LEN  # training on all 30 frames at once costs too much memory

    train_szo_iter = SZOIterator(rec_paths=cfg.SZO_TRAIN_DATA_PATHS,
                                 in_len=cfg.MODEL.IN_LEN,
                                 out_len=iter_outlen,  # the iterator must provide the full output sequence as the target
                                 batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                 frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                 frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                                 ctx=args.ctx)
    valid_szo_iter = SZOIterator(rec_paths=cfg.SZO_TEST_DATA_PATHS,
                                 in_len=cfg.MODEL.IN_LEN,
                                 out_len=iter_outlen,
                                 batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                 frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                 frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                                 ctx=args.ctx)

    szo_nowcasting = SZONowcastingFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
                                          ctx_num=len(args.ctx),
                                          in_seq_len=cfg.MODEL.IN_LEN,
                                          out_seq_len=model_outlen,  # the model still generates cfg.MODEL.OUT_LEN outputs at a time
                                          frame_stack=cfg.MODEL.FRAME_STACK)
    # discrim_net, loss_D_net will be None if use_gan = False
    encoder_net, forecaster_net, loss_net, discrim_net, loss_D_net = \
        encoder_forecaster_build_networks(
            factory=szo_nowcasting,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    if use_gan:
        discrim_net.summary()
        loss_D_net.summary()
        loss_types = ('mse', 'gdl', 'gan', 'dis')
    else:
        if cfg.MODEL.PROBLEM_FORM == 'regression':
            loss_types = ('mse', 'gdl')
        elif cfg.MODEL.PROBLEM_FORM == 'classification':
            loss_types = ('ce',)
        else:
            raise NotImplementedError
    # try to load checkpoint
    if args.resume:
        start_iter_id = latest_iter_id(base_dir)
        encoder_net.load_params(os.path.join(base_dir, 'encoder_net-%04d.params' % start_iter_id))
        forecaster_net.load_params(os.path.join(base_dir, 'forecaster_net-%04d.params' % start_iter_id))
        synchronize_kvstore(encoder_net)
        synchronize_kvstore(forecaster_net)
        if not args.resume_param_only:
            encoder_net.load_optimizer_states(os.path.join(base_dir, 'encoder_net-%04d.states' % start_iter_id))
            forecaster_net.load_optimizer_states(os.path.join(base_dir, 'forecaster_net-%04d.states' % start_iter_id))
        if use_gan:
            discrim_net.load_params(os.path.join(base_dir, 'discrim_net-%04d.params' % start_iter_id))
            synchronize_kvstore(discrim_net)
            if not args.resume_param_only:
                discrim_net.load_optimizer_states(os.path.join(base_dir, 'discrim_net-%04d.states' % start_iter_id))
    else:
        start_iter_id = -1

    if args.resume and (not args.resume_param_only):
        with open(os.path.join(base_dir, 'train_loss_dicts.pkl'), 'rb') as f:
            train_loss_dicts = pickle.load(f)
        with open(os.path.join(base_dir, 'valid_loss_dicts.pkl'), 'rb') as f:
            valid_loss_dicts = pickle.load(f)
        for dicts in (train_loss_dicts, valid_loss_dicts):
            keys_to_delete = []
            key_len = 0
            for k in dicts.keys():
                key_len = len(dicts[k])
                if k not in loss_types:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del dicts[k]
            for k in loss_types:
                if k not in dicts.keys():
                    dicts[k] = [0] * key_len
    else:
        train_loss_dicts = {}
        valid_loss_dicts = {}
        for dicts in (train_loss_dicts, valid_loss_dicts):
            for typ in loss_types:
                dicts[typ] = []

    states = EncoderForecasterStates(factory=szo_nowcasting, ctx=args.ctx[0])
    for info in szo_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, "Layout=%s is not supported!" %info["__layout__"]
    for info in szo_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, "Layout=%s is not supported!" % info["__layout__"]

    cumulative_loss = {}
    for k in train_loss_dicts.keys():
        cumulative_loss[k] = 0.0

    iter_id = start_iter_id + 1
    fix_shift = bool(cfg.MODEL.OPTFLOW_AS_INPUT)
    chan = cfg.SZO.DATA.IMAGE_CHANNEL if cfg.MODEL.OPTFLOW_AS_INPUT else 0
    buffers = {}
    buffers['fake'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    buffers['true'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
        data_nd = frame_dat[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0  # scale to [0,1]
        # keep only the grayscale channel
        target_nd = frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), :, chan, :, :] / 255.0
        target_nd = target_nd.expand_dims(axis=2)
        states.reset_all()
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            mask_nd = target_nd < 1.0
        else:
            mask_nd = mx.nd.ones_like(target_nd)
        
        states, loss_dict, pred_nd, buffers = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                               encoder_net=encoder_net, forecaster_net=forecaster_net,
                               loss_net=loss_net, discrim_net=discrim_net, 
                               loss_D_net=loss_D_net, init_states=states,
                               data_nd=data_nd, gt_nd=target_nd, mask_nd=mask_nd,
                               factory=szo_nowcasting, iter_id=iter_id, buffers=buffers)
        for k in cumulative_loss.keys():
            loss = loss_dict[k+'_output']
            cumulative_loss[k] += loss

        if (iter_id+1) % cfg.MODEL.VALID_ITER == 0:
            for i in range(cfg.MODEL.VALID_LOOP):
                states.reset_all()            
                frame_dat_v = valid_szo_iter.sample(fix_shift=fix_shift)
                data_nd_v = frame_dat_v[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0
                gt_nd_v = frame_dat_v[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN+cfg.MODEL.OUT_LEN),:,chan,:,:] / 255.0
                gt_nd_v = gt_nd_v.expand_dims(axis=2)
                if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
                    mask_nd_v = gt_nd_v < 1.0
                else:
                    mask_nd_v = mx.nd.ones_like(gt_nd_v)
                states, new_valid_loss_dicts, pred_nd_v = valid_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                encoder_net=encoder_net, forecaster_net=forecaster_net,
                                loss_net=loss_net, discrim_net=discrim_net,
                                loss_D_net=loss_D_net, init_states=states,
                                data_nd=data_nd_v, gt_nd=gt_nd_v, mask_nd=mask_nd_v,
                                valid_loss_dicts=valid_loss_dicts, factory=szo_nowcasting, iter_id=iter_id)
                if i == 0:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k].append(new_valid_loss_dicts[k])
                else:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k][-1] += new_valid_loss_dicts[k]
                
            for k in valid_loss_dicts.keys():
                valid_loss_dicts[k][-1] /= cfg.MODEL.VALID_LOOP
                plot_loss_curve(os.path.join(base_dir, 'valid_'+k+'_loss'), valid_loss_dicts[k])
            
        if (iter_id+1) % cfg.MODEL.DRAW_EVERY == 0:
            for k in train_loss_dicts.keys():
                avg_loss = cumulative_loss[k] / cfg.MODEL.DRAW_EVERY
                if k in ('gan', 'dis'):
                    avg_loss = min(avg_loss, 1.0)
                avg_loss = max(avg_loss, 0.0)
                train_loss_dicts[k].append(avg_loss)
                cumulative_loss[k] = 0.0
                plot_loss_curve(os.path.join(base_dir, 'train_'+k+'_loss'), train_loss_dicts[k])

        if (iter_id+1) % cfg.MODEL.DISPLAY_EVERY == 0:
            new_frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
            data_nd_d = new_frame_dat[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0
            target_nd_d = new_frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN),:,chan,:,:] / 255.0
            target_nd_d = target_nd_d.expand_dims(axis=2)
            states.reset_all()
            pred_nd_d = get_prediction(data_nd_d, states, encoder_net, forecaster_net)
            if cfg.MODEL.PROBLEM_FORM == 'classification':
                pred_nd_d = from_class_to_image(pred_nd_d)
            display_path1 = os.path.join(base_dir, 'display_'+str(iter_id))
            display_path2 = os.path.join(base_dir, 'display_'+str(iter_id)+'_')
            if not os.path.exists(display_path1):
                os.mkdir(display_path1)
            if not os.path.exists(display_path2):
                os.mkdir(display_path2)

            data_nd_d = (data_nd_d*255.0).clip(0, 255.0)
            target_nd_d = (target_nd_d*255.0).clip(0, 255.0)
            pred_nd_d = (pred_nd_d*255.0).clip(0, 255.0)
            save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path1, default_as_0=True)
            save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path2, default_as_0=False)
        if (iter_id+1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net",),
                epoch=iter_id,
                save_optimizer_states=True)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net",),
                epoch=iter_id,
                save_optimizer_states=True)
            if use_gan:
                discrim_net.save_checkpoint(
                    prefix=os.path.join(base_dir, "discrim_net",),
                    epoch=iter_id,
                    save_optimizer_states=True)
            path1 = os.path.join(base_dir, 'train_loss_dicts.pkl')
            path2 = os.path.join(base_dir, 'valid_loss_dicts.pkl')
            with open(path1, 'wb') as f:
                pickle.dump(train_loss_dicts, f)
            with open(path2, 'wb') as f:
                pickle.dump(valid_loss_dicts, f)
        iter_id += 1
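
The 'fake'/'true' deques in Example #11 act as bounded replay buffers for the GAN branch: up to GEN_BUFFER_LEN recent generator batches are kept and passed into train_step. A minimal sketch of that pattern (the draw-at-random policy is an assumption; the source only shows the buffers being created and threaded through):

from collections import deque
import random

gen_buffer = deque(maxlen=8)  # cfg.MODEL.TRAIN.GEN_BUFFER_LEN in the original

def push_and_draw(batch):
    # Store the newest generator batch, then train the discriminator on a
    # randomly drawn buffered batch rather than only the freshest one.
    gen_buffer.append(batch)
    return random.choice(gen_buffer)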
Example #12
def main(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="train")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    metadata_file = os.path.join(
        args.data_dir,
        'hdf_metadata.csv') if args.data_csv is None else args.data_csv

    all_data = h5py.File(os.path.join(args.data_dir, 'hdf_archives',
                                      'all_data.hdf5'),
                         'r',
                         libver='latest')
    outlier_mask = cv2.imread(os.path.join(args.data_dir, 'mask.png'), 0)

    metadata = pd.read_csv(metadata_file, index_col='id')
    metadata['start_datetime'] = pd.to_datetime(metadata['start_datetime'])
    metadata['end_datetime'] = pd.to_datetime(metadata['end_datetime'])
    if args.date_start is not None:
        metadata = metadata.loc[metadata['start_datetime'] >= args.date_start]
    if args.date_end is not None:
        metadata = metadata.loc[metadata['start_datetime'] < args.date_end]
    shuffled_meta = metadata.sample(frac=1)  # shuffle all rows before splitting
    split_idx = int(len(shuffled_meta) * 0.95)
    train_meta = shuffled_meta.iloc[:split_idx]
    test_meta = shuffled_meta.iloc[split_idx:]

    logging.info("Initializing data iterator with filter threshold %s" %
                 cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)
    train_model_iter = infinite_batcher(
        all_data,
        train_meta,
        outlier_mask,
        shuffle=False,
        filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)

    model_nowcasting = RadarNowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    model_nowcasting_online = RadarNowcastingFactory(
        batch_size=1,
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=model_nowcasting, ctx=args.ctx[0])
    for info in model_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    for info in model_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        # sample a random minibatch
        try:
            frame_dat, _, mask_dat = next(train_model_iter)
        except StopIteration:
            break
        else:
            states.reset_all()
            data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                                  ctx=args.ctx[0])
            target_nd = mx.nd.array(
                frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                            cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            mask_nd = mx.nd.array(
                mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                           cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net,
                                   loss_net=loss_net,
                                   init_states=states,
                                   data_nd=data_nd,
                                   gt_nd=target_nd,
                                   mask_nd=mask_nd,
                                   iter_id=iter_id)
            if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
                encoder_net.save_checkpoint(prefix=os.path.join(
                    base_dir, "encoder_net"),
                                            epoch=iter_id)
                forecaster_net.save_checkpoint(prefix=os.path.join(
                    base_dir, "forecaster_net"),
                                               epoch=iter_id)
            if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
                test_model_iter = HDFIterator(
                    all_data,
                    test_meta,
                    outlier_mask,
                    batch_size=1,
                    shuffle=False,
                    filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD
                )
                run_benchmark(model_factory=model_nowcasting_online,
                              context=args.ctx[0],
                              encoder_net=t_encoder_net,
                              forecaster_net=t_forecaster_net,
                              save_dir=os.path.join(
                                  base_dir, "iter%d_valid" % (iter_id + 1)),
                              mode=test_mode,
                              batcher=test_model_iter)
            iter_id += 1
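
The 95/5 split in Example #12 relies on DataFrame.sample(frac=1) returning a shuffled copy of all rows; the idiom in isolation:

import pandas as pd

meta = pd.DataFrame({'id': range(100)})
shuffled = meta.sample(frac=1, random_state=0)  # frac=1 keeps every row, in random order
split_idx = int(len(shuffled) * 0.95)
train_meta = shuffled.iloc[:split_idx]
test_meta = shuffled.iloc[split_idx:]
print(len(train_meta), len(test_meta))          # 95 5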
Example #13
def main(args):
    save_dir = args.save_dir
    batch_size = args.batch_size
    logging_config(folder=save_dir, name="predict")
    predictor = NowcastingPredictor(args.model_dir, args.model_iter,
                                    args.model_cfg, batch_size, args.ctx)
    pred_saver = DateWriter(save_dir)
    in_saver = DateWriter(save_dir, fname_format="%Y%m%d%H%M_in_{sl}.npz")
    gt_saver = DateWriter(save_dir, fname_format="%Y%m%d%H%M_gt_{sl}.npz")
    metadata_file = os.path.join(
        args.data_dir,
        'hdf_metadata.csv') if args.data_csv is None else args.data_csv

    metadata = pd.read_csv(metadata_file, index_col='id')
    metadata['start_datetime'] = pd.to_datetime(metadata['start_datetime'])
    metadata['end_datetime'] = pd.to_datetime(metadata['end_datetime'])
    if args.date_start is not None:
        metadata = metadata.loc[metadata['start_datetime'] >= args.date_start]
    if args.date_end is not None:
        metadata = metadata.loc[metadata['start_datetime'] < args.date_end]
    all_data = h5py.File(os.path.join(args.data_dir, 'hdf_archives',
                                      'all_data.hdf5'),
                         'r',
                         libver='latest')
    outlier_mask = cv2.imread(os.path.join(args.data_dir, 'mask.png'), 0)

    radar_mask = radar_circle_mask()
    batcher = HDFIterator(all_data,
                          metadata,
                          outlier_mask,
                          batch_size=batch_size,
                          shuffle=False,
                          filter_threshold=0,
                          sort_by='id',
                          ascending=True,
                          return_mask=False)

    mask_out = np.tile(radar_mask, (cfg.MODEL.OUT_LEN, batch_size, 1, 1))
    mask_in = np.tile(radar_mask, (cfg.MODEL.IN_LEN, batch_size, 1, 1))
    # index_df = pd.DataFrame(columns=['filename', 'index', 'timestamp'])

    j = 0
    while True:
        try:
            frame_dat, datetime_clip = next(batcher)
            # datetime_clip = datetime_clip[0]
            logging.info("Iteration {}: [{}] {}x{} clips".format(
                j, ", ".join([str(x) for x in datetime_clip]),
                len(datetime_clip), len(frame_dat)))
        except StopIteration:
            break

        # (seq_len, bs, ch, h, w)
        in_frame = frame_dat[:cfg.MODEL.IN_LEN, ...]
        out_frame = frame_dat[cfg.MODEL.IN_LEN:, ...]
        pred_frame = predictor.predict(in_frame)

        pred_frame[pred_frame < 0.001] = 0
        out_frame[out_frame < 0.001] = 0
        in_frame[in_frame < 0.001] = 0

        # (seq_len, bs, h, w)
        pred_frame = np.around(np.squeeze(pred_frame),
                               decimals=3).astype(np.float32)
        out_frame = np.around(np.squeeze(out_frame),
                              decimals=3).astype(np.float32)
        in_frame = np.around(np.squeeze(in_frame),
                             decimals=3).astype(np.float32)

        pred_frame[~mask_out] = np.nan
        out_frame[~mask_out] = np.nan
        in_frame[~mask_in] = np.nan

        # (bs, seq_len, h, w)
        pred_frame = pred_frame.swapaxes(0, 1)
        out_frame = out_frame.swapaxes(0, 1)
        in_frame = in_frame.swapaxes(0, 1)

        for i, dc in enumerate(datetime_clip):
            dtclip = dc + timedelta(minutes=5 * (cfg.MODEL.IN_LEN - 1))
            pred_saver.push(pred_frame[i], dtclip)
            in_saver.push(in_frame[i], dtclip)
            gt_saver.push(out_frame[i], dtclip)
            # index_df.loc[j+i] = {'filename': "{:06d}".format(chunk), 'index': (j + i) % split, 'timestamp': dtclip}

        j += batch_size

    pred_saver.close()
    in_saver.close()
    gt_saver.close()
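
Each saved clip in Example #13 is keyed by the timestamp of the last input frame, which assumes 5-minute frame spacing; the offset arithmetic in isolation:

from datetime import datetime, timedelta

clip_start = datetime(2019, 6, 1, 12, 0)  # illustrative timestamp
in_len = 5                                # stands in for cfg.MODEL.IN_LEN
last_input_time = clip_start + timedelta(minutes=5 * (in_len - 1))
print(last_input_time)                    # 2019-06-01 12:20:00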
Example #14
def train(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))

    mnist_rnn = MovingMNISTFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE //
                                   len(args.ctx),
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)

    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Resume last checkpoint
    if args.resume:
        encoder_net.load(prefix=os.path.join(base_dir, 'encoder_net'),
                         epoch=latest_iter_id(base_dir),
                         load_optimizer_states=True,
                         data_names=[
                             'data', 'ebrnn1_begin_state_h',
                             'ebrnn2_begin_state_h', 'ebrnn3_begin_state_h'
                         ],
                         label_names=[])  # change it next time
        forecaster_net.load(prefix=os.path.join(base_dir, 'forecaster_net'),
                            epoch=latest_iter_id(base_dir),
                            load_optimizer_states=True,
                            data_names=[
                                'fbrnn1_begin_state_h', 'fbrnn2_begin_state_h',
                                'fbrnn3_begin_state_h'
                            ],
                            label_names=[])  # change it next time
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR,
                          load_iter=cfg.MODEL.LOAD_ITER,
                          encoder_net=encoder_net,
                          forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    states.reset_all()
    for info in mnist_rnn.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                         seqlen=cfg.MOVINGMNIST.IN_LEN +
                                         cfg.MOVINGMNIST.OUT_LEN)
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(frame_dat[cfg.MOVINGMNIST.IN_LEN:(
            cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                ctx=args.ctx[0]) / 255.0
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net,
                   forecaster_net=forecaster_net,
                   loss_net=loss_net,
                   init_states=states,
                   data_nd=data_nd,
                   gt_nd=target_nd,
                   mask_nd=None,
                   iter_id=iter_id)
        if (iter_id + 1) % 100 == 0:
            new_frame_dat, _ = mnist_iter.sample(
                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
            data_nd = mx.nd.array(new_frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                                  ctx=args.ctx[0]) / 255.0
            target_nd = mx.nd.array(new_frame_dat[cfg.MOVINGMNIST.IN_LEN:(
                cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                    ctx=args.ctx[0]) / 255.0
            pred_nd = mnist_get_prediction(data_nd=data_nd,
                                           states=states,
                                           encoder_net=encoder_net,
                                           forecaster_net=forecaster_net)
            save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "pred.gif"))
            save_gif(data_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "in.gif"))
            save_gif(target_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "gt.gif"))
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(prefix=os.path.join(
                base_dir, "encoder_net"),
                                        epoch=iter_id,
                                        save_optimizer_states=True)
            forecaster_net.save_checkpoint(prefix=os.path.join(
                base_dir, "forecaster_net"),
                                           epoch=iter_id,
                                           save_optimizer_states=True)
        iter_id += 1