Example #1
def test(args):
    assert args.resume is True
    if cfg.MODEL.TEST.FINETUNE:
        assert cfg.MODEL.DECONVBASELINE.BN_GLOBAL_STATS is True

    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    generator_net, loss_net = construct_modules(args)

    if args.dataset == "test":
        pd_path = cfg.HKO_PD.RAINY_TEST
    elif args.dataset == "valid":
        pd_path = cfg.HKO_PD.RAINY_VALID
    else:
        raise NotImplementedError

    hko_benchmark(
        ctx=args.ctx[0],
        generator_net=generator_net,
        loss_net=loss_net,
        sample_num=1,
        save_dir=os.path.join(base_dir, "iter{}_{}_finetune{}".format(
            cfg.MODEL.LOAD_ITER + 1, args.dataset, cfg.MODEL.TEST.FINETUNE)),
        finetune=cfg.MODEL.TEST.FINETUNE,
        mode=cfg.MODEL.TEST.MODE,
        pd_path=pd_path)
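For reference, a minimal sketch of how test() might be driven; the Namespace fields are inferred from how args is used above, and all values are purely illustrative:

# Hypothetical invocation of test(); field names inferred from usage above.
import argparse
import mxnet as mx

args = argparse.Namespace(
    resume=True,                # test() asserts this flag is set
    save_dir="./hko_test_run",  # becomes base_dir for logs and saved cfg
    dataset="valid",            # "test" or "valid"; anything else raises
    ctx=[mx.gpu(0)],            # only args.ctx[0] is used here
)
test(args)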
Example #2
def test_sst(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    assert len(cfg.MODEL.LOAD_DIR) > 0
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    sst_nowcasting_online = SSTNowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(
        factory=sst_nowcasting_online, context=args.ctx, for_finetune=True)
    t_encoder_net.summary()
    t_forecaster_net.summary()
    t_loss_net.summary()
    load_encoder_forecaster_params(
        load_dir=cfg.MODEL.LOAD_DIR,
        load_iter=cfg.MODEL.LOAD_ITER,
        encoder_net=t_encoder_net,
        forecaster_net=t_forecaster_net,
    )
    if args.dataset == "test":
        pd_path = cfg.SST_PD.RAINY_TEST
    elif args.dataset == "valid":
        pd_path = cfg.SST_PD.RAINY_VALID
    else:
        raise NotImplementedError
    run_benchmark(
        sst_factory=sst_nowcasting_online,
        context=args.ctx[0],
        encoder_net=t_encoder_net,
        forecaster_net=t_forecaster_net,
        loss_net=t_loss_net,
        save_dir=os.path.join(
            base_dir,
            "iter%d_%s_finetune%d" %
            (cfg.MODEL.LOAD_ITER + 1, args.dataset, cfg.MODEL.TEST.FINETUNE),
        ),
        finetune=cfg.MODEL.TEST.FINETUNE,
        mode=cfg.MODEL.TEST.MODE,
        pd_path=pd_path,
    )
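load_encoder_forecaster_params is called here but not shown. A minimal sketch, assuming the "<name>-%04d.params" checkpoint naming visible in Example #6 and MXNet's Module.load_params; the real helper may differ:

import os

def load_encoder_forecaster_params(load_dir, load_iter,
                                   encoder_net, forecaster_net):
    # Sketch only: mirrors the "<name>-%04d.params" naming used when
    # checkpoints are saved in Example #6; the real helper may differ.
    encoder_net.load_params(
        os.path.join(load_dir, "encoder_net-%04d.params" % load_iter))
    forecaster_net.load_params(
        os.path.join(load_dir, "forecaster_net-%04d.params" % load_iter))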
Example #3
def argument_parser():
    parser = argparse.ArgumentParser(
        description='Deconvolution baseline for HKO')

    cfg.DATASET = "HKO"

    mode_args(parser)
    training_args(parser)
    dataset_args(parser)
    model_args(parser)

    args = parser.parse_args()

    parse_mode_args(args)
    parse_training_args(args)
    parse_model_args(args)

    base_dir = get_base_dir(args)
    logging_config(folder=base_dir, name="training")
    save_cfg(base_dir, source=cfg.MODEL)

    logging.info(args)
    return args
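The mode_args/training_args/dataset_args/model_args helpers are not shown; they follow the usual pattern of registering a group of flags on a shared parser. A hedged sketch of that pattern, with entirely hypothetical flag names:

import argparse

def mode_args(parser):
    # Hypothetical flags; the real helper's options are not shown here.
    group = parser.add_argument_group("mode")
    group.add_argument("--resume", action="store_true",
                       help="resume from the latest checkpoint")
    group.add_argument("--ctx", type=str, default="gpu0",
                       help="device list, e.g. 'gpu0' or 'gpu0,gpu1'")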
Example #4
def train(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    if cfg.MODEL.TRAIN.TBPTT:
        # Create a set of sequent iterators with different starting points
        train_hko_iters = []
        train_hko_iter_restart = []
        for _ in range(cfg.MODEL.TRAIN.BATCH_SIZE):
            ele_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                   sample_mode="sequent",
                                   seq_len=cfg.MODEL.IN_LEN +
                                   cfg.MODEL.OUT_LEN,
                                   stride=cfg.MODEL.IN_LEN)
            ele_iter.random_reset()
            train_hko_iter_restart.append(True)
            train_hko_iters.append(ele_iter)
    else:
        train_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                     sample_mode="random",
                                     seq_len=cfg.MODEL.IN_LEN +
                                     cfg.MODEL.OUT_LEN)

    hko_nowcasting = HKONowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    hko_nowcasting_online = HKONowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Load pretrained parameters if LOAD_DIR is non-empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=hko_nowcasting, ctx=args.ctx[0])
    for info in hko_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    for info in hko_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        if not cfg.MODEL.TRAIN.TBPTT:
            # Not using TBPTT: sample a random minibatch directly
            frame_dat, mask_dat, datetime_clips, _ = \
                train_hko_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
            states.reset_all()
        else:
            # Using TBPTT: sample one batch element from each sequent iterator
            frame_dat_l = []
            mask_dat_l = []
            for i, ele_iter in enumerate(train_hko_iters):
                if ele_iter.use_up:
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                    train_hko_iter_restart[i] = True
                if not train_hko_iter_restart[i] and ele_iter.check_new_start():
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                frame_dat, mask_dat, datetime_clips, _ = \
                    ele_iter.sample(batch_size=1)  # one batch element per iterator
                train_hko_iter_restart[i] = False
                frame_dat_l.append(frame_dat)
                mask_dat_l.append(mask_dat)
            frame_dat = np.concatenate(frame_dat_l, axis=1)
            mask_dat = np.concatenate(mask_dat_l, axis=1)
        # frames arrive as uint8 in [0, 255]; normalize to [0, 1]
        data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(
            frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), ...],
            ctx=args.ctx[0]) / 255.0
        mask_nd = mx.nd.array(
            mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), ...],
            ctx=args.ctx[0])
        states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                               encoder_net=encoder_net,
                               forecaster_net=forecaster_net,
                               loss_net=loss_net,
                               init_states=states,
                               data_nd=data_nd,
                               gt_nd=target_nd,
                               mask_nd=mask_nd,
                               iter_id=iter_id)
        if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
            run_benchmark(hko_factory=hko_nowcasting_online,
                          context=args.ctx[0],
                          encoder_net=t_encoder_net,
                          forecaster_net=t_forecaster_net,
                          loss_net=t_loss_net,
                          save_dir=os.path.join(base_dir, "iter%d_valid" %
                                                (iter_id + 1)),
                          mode=test_mode,
                          pd_path=cfg.HKO_PD.RAINY_VALID)
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id)
        iter_id += 1
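The TBPTT branch above keeps one sequent iterator per batch element and resets the matching slice of the recurrent state whenever that iterator is exhausted or starts a new sub-sequence. A toy sketch of that bookkeeping, using a stand-in iterator (the real HKOIterator API is only partially visible above):

class ToySequentIterator:
    """Stand-in for a sequent iterator walking one long sequence."""
    def __init__(self, length, stride):
        self.length, self.stride, self.pos = length, stride, 0

    @property
    def use_up(self):
        return self.pos >= self.length

    def random_reset(self):
        self.pos = 0

    def sample(self):
        chunk = (self.pos, min(self.pos + self.stride, self.length))
        self.pos += self.stride
        return chunk

iters = [ToySequentIterator(length=10, stride=4) for _ in range(2)]
for step in range(6):
    for i, it in enumerate(iters):
        if it.use_up:
            # A fresh sub-sequence begins here, so in the real loop the
            # hidden state for batch slot i is cleared via
            # states.reset_batch(batch_id=i) before sampling again.
            it.random_reset()
        print(step, i, it.sample())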
Example #5
def save_movingmnist_cfg(dir_path):
    tmp_cfg = edict()
    tmp_cfg.MOVINGMNIST = cfg.MOVINGMNIST
    tmp_cfg.MODEL = cfg.MODEL
    save_cfg(dir_path=dir_path, source=tmp_cfg)
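save_cfg itself is not shown in these examples. A minimal sketch, assuming it serializes the edict to YAML under dir_path; the file name cfg.yml and the YAML format are assumptions:

import os
import yaml

def _to_plain(d):
    # easydict subclasses dict; convert recursively so yaml.safe_dump
    # can serialize it without a custom representer.
    return {k: _to_plain(v) if isinstance(v, dict) else v
            for k, v in d.items()}

def save_cfg(dir_path, source):
    # Sketch only: the real helper's file name and format may differ.
    with open(os.path.join(dir_path, "cfg.yml"), "w") as f:
        yaml.safe_dump(_to_plain(source), f, default_flow_style=False)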
Example #6
def train(args):
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    use_gan = (cfg.MODEL.PROBLEM_FORM == 'regression') and (cfg.MODEL.GAN_G_LAMBDA > 0.0)
    if cfg.MODEL.FRAME_SKIP_OUT > 1:
        # here we assume the target spans the last 30 frames
        iter_outlen = cfg.MODEL.OUT_LEN
    else:
        iter_outlen = 30
    model_outlen = cfg.MODEL.OUT_LEN  # training on 30 costs too much memory

    train_szo_iter = SZOIterator(rec_paths=cfg.SZO_TRAIN_DATA_PATHS,
                                in_len=cfg.MODEL.IN_LEN,
                                out_len=iter_outlen,  # iterator has to provide full output sequence as target if needed
                                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT, 
                                ctx=args.ctx)
    valid_szo_iter = SZOIterator(rec_paths=cfg.SZO_TEST_DATA_PATHS,
                                in_len=cfg.MODEL.IN_LEN,
                                out_len=iter_outlen,
                                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                                ctx=args.ctx)
        
    szo_nowcasting = SZONowcastingFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
                                          ctx_num=len(args.ctx),
                                          in_seq_len=cfg.MODEL.IN_LEN,
                                          out_seq_len=model_outlen,  # model still generate cfg.MODEL.OUT_LEN number of outputs at a time
                                          frame_stack=cfg.MODEL.FRAME_STACK)
    # discrim_net, loss_D_net will be None if use_gan = False
    encoder_net, forecaster_net, loss_net, discrim_net, loss_D_net = \
        encoder_forecaster_build_networks(
            factory=szo_nowcasting,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    if use_gan:
        discrim_net.summary()
        loss_D_net.summary()
    if use_gan:
        loss_types = ('mse','gdl','gan','dis')
    else:
        if cfg.MODEL.PROBLEM_FORM == 'regression':
            loss_types = ('mse', 'gdl')
        elif cfg.MODEL.PROBLEM_FORM == 'classification':
            loss_types = ('ce',)
        else:
            raise NotImplementedError
    # try to load checkpoint
    if args.resume:
        start_iter_id = latest_iter_id(base_dir)
        encoder_net.load_params(os.path.join(base_dir, 'encoder_net'+'-%04d.params'%(start_iter_id)))
        forecaster_net.load_params(os.path.join(base_dir, 'forecaster_net'+'-%04d.params'%(start_iter_id)))
        synchronize_kvstore(encoder_net)
        synchronize_kvstore(forecaster_net)
        if not args.resume_param_only:
            encoder_net.load_optimizer_states(os.path.join(base_dir, 'encoder_net'+'-%04d.states'%(start_iter_id)))
            forecaster_net.load_optimizer_states(os.path.join(base_dir, 'forecaster_net'+'-%04d.states'%(start_iter_id)))
        if use_gan:
            discrim_net.load_params(os.path.join(base_dir, 'discrim_net'+'-%04d.params'%(start_iter_id)))
            synchronize_kvstore(discrim_net)
            if not args.resume_param_only:
                discrim_net.load_optimizer_states(os.path.join(base_dir, 'discrim_net'+'-%04d.states'%(start_iter_id)))
    else:
        start_iter_id = -1

    if args.resume and (not args.resume_param_only):
        with open(os.path.join(base_dir, 'train_loss_dicts.pkl'), 'rb') as f:
            train_loss_dicts = pickle.load(f)
        with open(os.path.join(base_dir, 'valid_loss_dicts.pkl'), 'rb') as f:
            valid_loss_dicts = pickle.load(f)
        # Reconcile the loaded loss histories with the current loss_types:
        # drop stale keys and zero-pad any newly added ones to the same length.
        for dicts in (train_loss_dicts, valid_loss_dicts):
            keys_to_delete = []
            key_len = 0
            for k in dicts.keys():
                key_len = len(dicts[k])
                if k not in loss_types:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del dicts[k]
            for k in loss_types:
                if k not in dicts.keys():
                    dicts[k] = [0] * key_len
    else:
        train_loss_dicts = {}
        valid_loss_dicts = {}
        for dicts in (train_loss_dicts, valid_loss_dicts):
            for typ in loss_types:
                dicts[typ] = []

    states = EncoderForecasterStates(factory=szo_nowcasting, ctx=args.ctx[0])
    for info in szo_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    for info in szo_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]

    cumulative_loss = {}
    for k in train_loss_dicts.keys():
        cumulative_loss[k] = 0.0

    iter_id = start_iter_id + 1
    fix_shift = bool(cfg.MODEL.OPTFLOW_AS_INPUT)
    chan = cfg.SZO.DATA.IMAGE_CHANNEL if cfg.MODEL.OPTFLOW_AS_INPUT else 0
    buffers = {}
    buffers['fake'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    buffers['true'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
        data_nd = frame_dat[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0  # scale to [0,1]
        # keep only the grayscale channel (index chan)
        target_nd = frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN),:,chan,:,:] / 255.0
        target_nd = target_nd.expand_dims(axis=2)
        states.reset_all()
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            mask_nd = target_nd < 1.0
        else:
            mask_nd = mx.nd.ones_like(target_nd)
        
        states, loss_dict, pred_nd, buffers = train_step(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            encoder_net=encoder_net, forecaster_net=forecaster_net,
            loss_net=loss_net, discrim_net=discrim_net,
            loss_D_net=loss_D_net, init_states=states,
            data_nd=data_nd, gt_nd=target_nd, mask_nd=mask_nd,
            factory=szo_nowcasting, iter_id=iter_id, buffers=buffers)
        for k in cumulative_loss.keys():
            loss = loss_dict[k+'_output']
            cumulative_loss[k] += loss

        if (iter_id+1) % cfg.MODEL.VALID_ITER == 0:
            for i in range(cfg.MODEL.VALID_LOOP):
                states.reset_all()            
                frame_dat_v = valid_szo_iter.sample(fix_shift=fix_shift)
                data_nd_v = frame_dat_v[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0
                gt_nd_v = frame_dat_v[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN+cfg.MODEL.OUT_LEN),:,chan,:,:] / 255.0
                gt_nd_v = gt_nd_v.expand_dims(axis=2)
                if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
                    mask_nd_v = gt_nd_v < 1.0
                else:
                    mask_nd_v = mx.nd.ones_like(gt_nd_v)
                states, new_valid_loss_dicts, pred_nd_v = valid_step(
                    batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                    encoder_net=encoder_net, forecaster_net=forecaster_net,
                    loss_net=loss_net, discrim_net=discrim_net,
                    loss_D_net=loss_D_net, init_states=states,
                    data_nd=data_nd_v, gt_nd=gt_nd_v, mask_nd=mask_nd_v,
                    valid_loss_dicts=valid_loss_dicts, factory=szo_nowcasting,
                    iter_id=iter_id)
                if i == 0:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k].append(new_valid_loss_dicts[k])
                else:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k][-1] += new_valid_loss_dicts[k]
                
            for k in valid_loss_dicts.keys():
                valid_loss_dicts[k][-1] /= cfg.MODEL.VALID_LOOP
                plot_loss_curve(os.path.join(base_dir, 'valid_'+k+'_loss'), valid_loss_dicts[k])
            
        if (iter_id+1) % cfg.MODEL.DRAW_EVERY == 0:
            for k in train_loss_dicts.keys():
                avg_loss = cumulative_loss[k] / cfg.MODEL.DRAW_EVERY
                # clamp adversarial losses into [0, 1] before plotting
                if k in ('gan', 'dis'):
                    avg_loss = min(avg_loss, 1.0)
                avg_loss = max(avg_loss, 0)
                train_loss_dicts[k].append(avg_loss)
                cumulative_loss[k] = 0.0
                plot_loss_curve(os.path.join(base_dir, 'train_'+k+'_loss'), train_loss_dicts[k])

        if (iter_id+1) % cfg.MODEL.DISPLAY_EVERY == 0:
            new_frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
            data_nd_d = new_frame_dat[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0
            target_nd_d = new_frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN),:,chan,:,:] / 255.0
            target_nd_d = target_nd_d.expand_dims(axis=2)
            states.reset_all()
            pred_nd_d = get_prediction(data_nd_d, states, encoder_net, forecaster_net)
            if cfg.MODEL.PROBLEM_FORM == 'classification':
                pred_nd_d = from_class_to_image(pred_nd_d)
            display_path1 = os.path.join(base_dir, 'display_'+str(iter_id))
            display_path2 = os.path.join(base_dir, 'display_'+str(iter_id)+'_')
            if not os.path.exists(display_path1):
                os.mkdir(display_path1)
            if not os.path.exists(display_path2):
                os.mkdir(display_path2)

            data_nd_d = (data_nd_d*255.0).clip(0, 255.0)
            target_nd_d = (target_nd_d*255.0).clip(0, 255.0)
            pred_nd_d = (pred_nd_d*255.0).clip(0, 255.0)
            save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path1, default_as_0=True)
            save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path2, default_as_0=False)
        if (iter_id+1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id,
                save_optimizer_states=True)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id,
                save_optimizer_states=True)
            if use_gan:  # discrim_net is None unless the GAN branch is active
                discrim_net.save_checkpoint(
                    prefix=os.path.join(base_dir, "discrim_net"),
                    epoch=iter_id,
                    save_optimizer_states=True)
            path1 = os.path.join(base_dir, 'train_loss_dicts.pkl')
            path2 = os.path.join(base_dir, 'valid_loss_dicts.pkl')
            with open(path1, 'wb') as f:
                pickle.dump(train_loss_dicts, f)
            with open(path2, 'wb') as f:
                pickle.dump(valid_loss_dicts, f)
        iter_id += 1
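The buffers dict above holds bounded deques of generated ('fake') and real ('true') batches. The mixing itself happens inside train_step, which is not shown; a common variant, sketched below under that assumption, hands the discriminator either the newest generated batch or a randomly chosen older one to damp generator oscillation:

from collections import deque
import random

buffer_fake = deque([], maxlen=8)  # bounded history of generated batches

def pick_discriminator_batch(current_fake):
    # Keep the newest batch; half of the time, hand the discriminator an
    # older generated batch instead (a standard GAN replay-buffer trick).
    buffer_fake.append(current_fake)
    if len(buffer_fake) > 1 and random.random() < 0.5:
        return random.choice(list(buffer_fake)[:-1])
    return current_fake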
Example #7
def main(args):
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="train")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    metadata_file = os.path.join(
        args.data_dir,
        'hdf_metadata.csv') if args.data_csv is None else args.data_csv

    all_data = h5py.File(os.path.join(args.data_dir, 'hdf_archives',
                                      'all_data.hdf5'),
                         'r',
                         libver='latest')
    outlier_mask = cv2.imread(os.path.join(args.data_dir, 'mask.png'), 0)

    metadata = pd.read_csv(metadata_file, index_col='id')
    metadata['start_datetime'] = pd.to_datetime(metadata['start_datetime'])
    metadata['end_datetime'] = pd.to_datetime(metadata['end_datetime'])
    if args.date_start is not None:
        metadata = metadata.loc[metadata['start_datetime'] >= args.date_start]
    if args.date_end is not None:
        metadata = metadata.loc[metadata['start_datetime'] < args.date_end]
    # sample(frac=1) shuffles the rows before the 95/5 train/test split
    shuffled_meta = metadata.sample(frac=1)
    split_idx = int(len(shuffled_meta) * 0.95)
    train_meta = shuffled_meta.iloc[:split_idx]
    test_meta = shuffled_meta.iloc[split_idx:]

    logging.info("Initializing data iterator with filter threshold %s" %
                 cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)
    train_model_iter = infinite_batcher(
        all_data,
        train_meta,
        outlier_mask,
        shuffle=False,
        filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)

    model_nowcasting = RadarNowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    model_nowcasting_online = RadarNowcastingFactory(
        batch_size=1,
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Load pretrained parameters if LOAD_DIR is non-empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=model_nowcasting, ctx=args.ctx[0])
    for info in model_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    for info in model_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        # sample a random minibatch
        try:
            frame_dat, _, mask_dat = next(train_model_iter)
        except StopIteration:
            break
        else:
            states.reset_all()
            data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                                  ctx=args.ctx[0])
            target_nd = mx.nd.array(
                frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                            cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            mask_nd = mx.nd.array(
                mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                           cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net,
                                   loss_net=loss_net,
                                   init_states=states,
                                   data_nd=data_nd,
                                   gt_nd=target_nd,
                                   mask_nd=mask_nd,
                                   iter_id=iter_id)
            if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
                encoder_net.save_checkpoint(
                    prefix=os.path.join(base_dir, "encoder_net"),
                    epoch=iter_id)
                forecaster_net.save_checkpoint(
                    prefix=os.path.join(base_dir, "forecaster_net"),
                    epoch=iter_id)
            if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
                test_model_iter = HDFIterator(
                    all_data,
                    test_meta,
                    outlier_mask,
                    batch_size=1,
                    shuffle=False,
                    filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD
                )
                run_benchmark(model_factory=model_nowcasting_online,
                              context=args.ctx[0],
                              encoder_net=t_encoder_net,
                              forecaster_net=t_forecaster_net,
                              save_dir=os.path.join(
                                  base_dir, "iter%d_valid" % (iter_id + 1)),
                              mode=test_mode,
                              batcher=test_model_iter)
            iter_id += 1
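infinite_batcher is consumed with next() and must yield (frames, metadata, mask) triples indefinitely. A toy sketch of that generator contract, with a hypothetical per-clip HDF5 layout; the real batcher also filters clips by the rainfall threshold:

import random

def infinite_batcher(data, metadata, mask, shuffle=False,
                     filter_threshold=0.0):
    # Sketch only: endlessly yield (frames, meta_row, mask) triples.
    # data[str(clip_id)] is a hypothetical HDF5 layout, not the real one.
    ids = list(metadata.index)
    while True:
        if shuffle:
            random.shuffle(ids)
        for clip_id in ids:
            frames = data[str(clip_id)][...]
            yield frames, metadata.loc[clip_id], mask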
Example #8
def save_icdm_cfg(dir_path):
    tmp_cfg = edict()
    tmp_cfg.ICDM = cfg.ICDM
    tmp_cfg.MODEL = cfg.MODEL
    save_cfg(dir_path=dir_path, source=tmp_cfg)