Example #1
0
def train(args):
    """Train the ICDM encoder-forecaster model.

    Samples input/target sequences from the ICDM training set, runs one
    optimization step per iteration and checkpoints the encoder and
    forecaster every ``cfg.MODEL.SAVE_ITER`` iterations.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``save_dir`` (output directory) and ``ctx``
        (list of MXNet contexts).
    """
    # Frame stacking/skipping is not supported by this training loop.
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_icdm_cfg(base_dir)
    # NOTE(review): hard-coded absolute Windows path — should be moved into
    # the configuration (e.g. under cfg.ICDM) before this is reused.
    icdm_iter = ICDMIterator(root_dir=r'C:\Users\jing\projects\nowcasting\HKO-7\icdm_data\SRAD2018_TRAIN_010')
    # Per-device batch size: the global batch is split across the contexts.
    icdm_factory = ICDMFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
                               in_seq_len=cfg.ICDM.IN_LEN,
                               out_seq_len=cfg.ICDM.OUT_LEN)

    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=icdm_factory,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    states = EncoderForecasterStates(factory=icdm_factory, ctx=args.ctx[0])
    states.reset_all()
    # The state layout must be batch-major ('N' first).
    for info in icdm_factory.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        input_batch, output_batch = icdm_iter.sample(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            in_len=cfg.ICDM.IN_LEN, out_len=cfg.ICDM.OUT_LEN)
        data_nd = mx.nd.array(input_batch, ctx=args.ctx[0], dtype=np.float32)
        target_nd = mx.nd.array(output_batch, ctx=args.ctx[0], dtype=np.float32)
        # No mask is used for the ICDM data.
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net, forecaster_net=forecaster_net,
                   loss_net=loss_net, init_states=states,
                   data_nd=data_nd, gt_nd=target_nd, mask_nd=None,
                   iter_id=iter_id)
        print('step:', iter_id)
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id)
        iter_id += 1
Example #2
0
def test_sst(args):
    """Run the benchmark for a pretrained SST nowcasting model.

    Builds a batch-size-1 factory with fine-tunable networks, restores
    the checkpoint given by ``cfg.MODEL.LOAD_DIR``/``LOAD_ITER``, then
    benchmarks on the rainy test or validation split picked by
    ``args.dataset``.
    """
    # Frame stacking/skipping is unsupported; a checkpoint must exist.
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    assert len(cfg.MODEL.LOAD_DIR) > 0

    base_dir = args.save_dir
    logging_config(folder=base_dir, name="testing")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    sst_nowcasting_online = SSTNowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    nets = encoder_forecaster_build_networks(factory=sst_nowcasting_online,
                                             context=args.ctx,
                                             for_finetune=True)
    t_encoder_net, t_forecaster_net, t_loss_net = nets
    for net in nets:
        net.summary()

    load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                   load_iter=cfg.MODEL.LOAD_ITER,
                                   encoder_net=t_encoder_net,
                                   forecaster_net=t_forecaster_net)

    # Select the evaluation split.
    if args.dataset == "test":
        pd_path = cfg.SST_PD.RAINY_TEST
    elif args.dataset == "valid":
        pd_path = cfg.SST_PD.RAINY_VALID
    else:
        raise NotImplementedError

    result_dir = os.path.join(
        base_dir,
        "iter%d_%s_finetune%d" % (cfg.MODEL.LOAD_ITER + 1,
                                  args.dataset,
                                  cfg.MODEL.TEST.FINETUNE))
    run_benchmark(sst_factory=sst_nowcasting_online,
                  context=args.ctx[0],
                  encoder_net=t_encoder_net,
                  forecaster_net=t_forecaster_net,
                  loss_net=t_loss_net,
                  save_dir=result_dir,
                  finetune=cfg.MODEL.TEST.FINETUNE,
                  mode=cfg.MODEL.TEST.MODE,
                  pd_path=pd_path)
Example #3
0
def train(args):
    """Train the HKO encoder-forecaster model, optionally with TBPTT.

    When ``cfg.MODEL.TRAIN.TBPTT`` is set, one sequent iterator per batch
    element is maintained and the encoder/forecaster states persist across
    mini-batches (reset only when an element's iterator restarts).
    Otherwise, random mini-batches are drawn and the states are reset
    every iteration.  Validation runs every ``cfg.MODEL.VALID_ITER``
    iterations and checkpoints are written every ``cfg.MODEL.SAVE_ITER``
    iterations.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``ctx`` (list of MXNet contexts) plus whatever
        ``get_base_dir`` needs to locate the output directory.
    """
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    if cfg.MODEL.TRAIN.TBPTT:
        # Create a set of sequent iterators with different starting points,
        # one per element of the training batch.
        train_hko_iters = []
        train_hko_iter_restart = []
        for _ in range(cfg.MODEL.TRAIN.BATCH_SIZE):
            ele_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                   sample_mode="sequent",
                                   seq_len=cfg.MODEL.IN_LEN +
                                   cfg.MODEL.OUT_LEN,
                                   stride=cfg.MODEL.IN_LEN)
            ele_iter.random_reset()
            train_hko_iter_restart.append(True)
            train_hko_iters.append(ele_iter)
    else:
        train_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                     sample_mode="random",
                                     seq_len=cfg.MODEL.IN_LEN +
                                     cfg.MODEL.OUT_LEN)

    # The training factory splits the global batch across contexts; the
    # online (batch_size=1) factory shares parameters with it and is used
    # for the periodic validation benchmark.
    hko_nowcasting = HKONowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    hko_nowcasting_online = HKONowcastingFactory(batch_size=1,
                                                 in_seq_len=cfg.MODEL.IN_LEN,
                                                 out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting,
            context=args.ctx)
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=hko_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Begin to load the model if load_dir is not empty.
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=hko_nowcasting, ctx=args.ctx[0])
    # State layouts must be batch-major ('N' first) so that
    # states.reset_batch can address individual batch elements.
    for info in hko_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    for info in hko_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        if not cfg.MODEL.TRAIN.TBPTT:
            # Not using TBPTT: directly sample a random minibatch and
            # start every iteration from fresh states.
            frame_dat, mask_dat, datetime_clips, _ = \
                train_hko_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
            states.reset_all()
        else:
            # Using TBPTT: draw the next step from each per-element
            # iterator, resetting only the state of elements whose
            # iterator restarted.
            frame_dat_l = []
            mask_dat_l = []
            for i, ele_iter in enumerate(train_hko_iters):
                if ele_iter.use_up:
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                    train_hko_iter_restart[i] = True
                if not train_hko_iter_restart[i] and ele_iter.check_new_start():
                    states.reset_batch(batch_id=i)
                    ele_iter.random_reset()
                # NOTE(review): each per-element iterator is sampled with
                # the full training batch size — confirm this should not
                # be batch_size=1 per element.
                frame_dat, mask_dat, datetime_clips, _ = \
                    ele_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE)
                train_hko_iter_restart[i] = False
                frame_dat_l.append(frame_dat)
                mask_dat_l.append(mask_dat)
            # Concatenate the per-element samples along the batch axis.
            frame_dat = np.concatenate(frame_dat_l, axis=1)
            mask_dat = np.concatenate(mask_dat_l, axis=1)
        # Frames are uint8 pixel values; scale to [0, 1].
        data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        target_nd = mx.nd.array(
            frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), ...],
            ctx=args.ctx[0]) / 255.0
        mask_nd = mx.nd.array(
            mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), ...],
            ctx=args.ctx[0])
        states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                               encoder_net=encoder_net,
                               forecaster_net=forecaster_net,
                               loss_net=loss_net,
                               init_states=states,
                               data_nd=data_nd,
                               gt_nd=target_nd,
                               mask_nd=mask_nd,
                               iter_id=iter_id)
        if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
            run_benchmark(hko_factory=hko_nowcasting_online,
                          context=args.ctx[0],
                          encoder_net=t_encoder_net,
                          forecaster_net=t_forecaster_net,
                          loss_net=t_loss_net,
                          save_dir=os.path.join(base_dir,
                                                "iter%d_valid" % (iter_id + 1)),
                          mode=test_mode,
                          pd_path=cfg.HKO_PD.RAINY_VALID)
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id)
        iter_id += 1
Example #4
0
def test(args, batches, checkpoint_id=None, on_train=False):
    """Evaluate a trained SZO model over a number of sampled batches.

    Restores encoder/forecaster parameters from the checkpoint directory
    (the latest checkpoint unless ``checkpoint_id`` is given), runs
    ``batches`` prediction batches, accumulates skill scores with
    ``SZOEvaluation``, and writes a readable result file.

    Parameters
    ----------
    args : argparse.Namespace
        Provides ``ctx`` (list of MXNet contexts) plus whatever
        ``get_base_dir`` needs.
    batches : int
        Number of batches to evaluate.
    checkpoint_id : int or None
        Checkpoint iteration to load; ``None`` selects the latest.
    on_train : bool
        Evaluate on the training records instead of the test records.
    """
    # With no output frame skipping the sequence is 30 frames; with
    # skip 5 it is 6 frames covering the same time window.
    if cfg.MODEL.FRAME_SKIP_OUT == 1:
        iter_outlen = 30
        model_outlen = 30
    elif cfg.MODEL.FRAME_SKIP_OUT == 5:
        iter_outlen = 6
        model_outlen = 6
    else:
        raise NotImplementedError
    evaluator = SZOEvaluation(6, False)
    base_dir = get_base_dir(args)
    logging.basicConfig(level=logging.INFO)
    # BUG FIX: the on_train branch previously passed the builtin ``iter``
    # as out_len instead of ``iter_outlen``.
    rec_paths = cfg.SZO_TRAIN_DATA_PATHS if on_train else cfg.SZO_TEST_DATA_PATHS
    szo_iter = SZOIterator(rec_paths=rec_paths,
                           in_len=cfg.MODEL.IN_LEN,
                           out_len=iter_outlen,
                           batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                           frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                           frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                           ctx=args.ctx)

    szo_nowcasting = SZONowcastingFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
                                          ctx_num=len(args.ctx),
                                          in_seq_len=cfg.MODEL.IN_LEN,
                                          out_seq_len=model_outlen,
                                          frame_stack=cfg.MODEL.FRAME_STACK)
    encoder_net, forecaster_net, loss_net, discrim_net, loss_D_net = \
        encoder_forecaster_build_networks(
            factory=szo_nowcasting,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Resolve the checkpoint to load (latest unless explicitly given).
    if checkpoint_id is None:
        start_iter_id = latest_iter_id(base_dir)
    else:
        start_iter_id = checkpoint_id
    encoder_net.load_params(os.path.join(base_dir, 'encoder_net' + '-%04d.params' % (start_iter_id)))
    forecaster_net.load_params(os.path.join(base_dir, 'forecaster_net' + '-%04d.params' % (start_iter_id)))
    states = EncoderForecasterStates(factory=szo_nowcasting, ctx=args.ctx[0])
    fix_shift = bool(cfg.MODEL.OPTFLOW_AS_INPUT)
    # With optical-flow input the grey-scale frame lives at channel
    # cfg.SZO.DATA.IMAGE_CHANNEL — assumption from sibling code; confirm.
    chan = cfg.SZO.DATA.IMAGE_CHANNEL if cfg.MODEL.OPTFLOW_AS_INPUT else 0
    for i in range(batches):
        print('batch ', i)
        states.reset_all()
        frame_dat = szo_iter.sample(fix_shift=fix_shift)
        # Scale uint8 pixels to [0, 1]; targets keep only the grey channel.
        data_nd = frame_dat[0:cfg.MODEL.IN_LEN, :, :, :, :] / 255.0
        target_nd = frame_dat[cfg.MODEL.IN_LEN:, :, chan, :, :] / 255.0
        target_nd = target_nd.expand_dims(axis=2)
        pred_nd = get_prediction(data_nd, states, encoder_net, forecaster_net)
        if cfg.MODEL.PROBLEM_FORM == 'classification':
            pred_nd = from_class_to_image(pred_nd)
        # Generate the evaluation mask from target and prediction.
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            target_nd = target_nd * (255.0 / 80.0)
            pred_nd = pred_nd * (255.0 / 80.0)
        mask_nd = (target_nd > 0.0) * (pred_nd > (cfg.MODEL.DISPLAY_EPSILON / 255.0))
        if cfg.MODEL.FRAME_SKIP_OUT == 1:
            # Evaluate every 5th frame so the 30-frame output matches the
            # 6-frame skip-5 setting.
            target_nd = target_nd[4::5, :, :, :, :]
            pred_nd = pred_nd[4::5, :, :, :, :]
            mask_nd = mask_nd[4::5, :, :, :, :]
        evaluator.update(target_nd.asnumpy(), pred_nd.asnumpy(), mask_nd.asnumpy())
    evaluator.print_stat_readable()
    filename = 'test_result_%03d' % (start_iter_id)
    if on_train:
        filename += '_on_train.txt'
    else:
        filename += '.txt'
    evaluator.save_txt_readable(path=os.path.join(base_dir, filename))
Example #5
0
def predict(args, num_samples, save_path=None, mode='display', extend='none'):
    """Generate predictions from a trained SZO model and visualize or save them.

    mode can be either display or save.
    Under display mode, num_samples gifs and comparisons are saved,
    plus a histogram of predicted pixel values.
    Under save mode, num_samples sequences of pngs are saved in different
    directories.
    extend can be none, recursive or onetime:
      - 'none': predict cfg.MODEL.OUT_LEN frames in one pass;
      - 'recursive': feed predictions back into the encoder to extend the
        horizon to 30 frames (regression only, no frame skipping);
      - 'onetime': predict all 30 frames in a single pass (no frame skipping).
    """
    # Single-device inference only.
    assert len(args.ctx) == 1
    base_dir = get_base_dir(args)
    # With optical-flow input the grey-scale frame lives at channel
    # cfg.SZO.DATA.IMAGE_CHANNEL; otherwise channel 0. (assumption — confirm)
    chan = cfg.SZO.DATA.IMAGE_CHANNEL if cfg.MODEL.OPTFLOW_AS_INPUT else 0
    if extend == 'recursive':
        assert not cfg.MODEL.PROBLEM_FORM == 'classification', 'recursive generation is not allowed under classification mode'
        assert (cfg.MODEL.FRAME_SKIP_IN) == 1 and (cfg.MODEL.FRAME_SKIP_OUT==1), '"extend" should be "none" when frame_skip is not 1'
        iter_outlen = 30
        model_outlen = cfg.MODEL.OUT_LEN
    elif extend == 'onetime':
        assert (cfg.MODEL.FRAME_SKIP_IN) == 1 and (cfg.MODEL.FRAME_SKIP_OUT==1), '"extend" should be "none" when frame_skip is not 1'
        iter_outlen = 30
        model_outlen = 30
    else:
        iter_outlen = cfg.MODEL.OUT_LEN
        model_outlen = cfg.MODEL.OUT_LEN
    szo_iterator = SZOIterator(rec_paths=cfg.SZO_TEST_DATA_PATHS,
                               in_len=cfg.MODEL.IN_LEN,
                               out_len=iter_outlen,
                               batch_size=1,
                               frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                               frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                               ctx=args.ctx)  # there can be no ground truth available
    no_gt = szo_iterator.no_gt()
    szo_nowcasting = SZONowcastingFactory(batch_size=1,
                                          ctx_num=1,
                                          in_seq_len=cfg.MODEL.IN_LEN,
                                          out_seq_len=model_outlen,
                                          frame_stack=cfg.MODEL.FRAME_STACK)

    # discrim_net / loss_D_net are unused here; only encoder+forecaster run.
    encoder_net, forecaster_net, loss_net, discrim_net, loss_D_net = \
        encoder_forecaster_build_networks(
            factory=szo_nowcasting,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # load parameters
    # assume parameter files are available
    start_iter_id = latest_iter_id(base_dir)
    encoder_net.load_params(os.path.join(base_dir, 'encoder_net'+'-%04d.params'%(start_iter_id)))
    forecaster_net.load_params(os.path.join(base_dir, 'forecaster_net'+'-%04d.params'%(start_iter_id)))
    # initial states
    states = EncoderForecasterStates(factory=szo_nowcasting, ctx=args.ctx[0])
    # generate samples
    for i in range(num_samples):
        new_frame_dat, folder_names = szo_iterator.get_sample_name_pair(fix_shift=True)
        # Scale uint8 pixels to [0, 1]; target keeps only the grey channel.
        data_nd_d = new_frame_dat[0:cfg.MODEL.IN_LEN,:,:,:,:] / 255.0
        if not no_gt:
            target_nd_d = new_frame_dat[cfg.MODEL.IN_LEN:,:,chan,:,:] / 255.0
            target_nd_d = target_nd_d.expand_dims(axis=2)
        else:
            target_nd_d = None
        states.reset_all()
        pred_nd_d1 = get_prediction(data_nd_d, states, encoder_net, forecaster_net)
        if extend == 'recursive':
            # Second pass: feed the last IN_LEN predicted frames back in,
            # then stitch the two predictions into a 30-frame sequence.
            # (slice arithmetic assumes IN_LEN + 2*OUT_LEN >= 30 — confirm)
            states.reset_all()
            pred_nd_d2 = get_prediction(pred_nd_d1[30-cfg.MODEL.IN_LEN-cfg.MODEL.OUT_LEN: 30-cfg.MODEL.OUT_LEN,:,:,:,:], 
                                        states, encoder_net, forecaster_net)
            pred_nd_d = mx.nd.concat(pred_nd_d1[:30-cfg.MODEL.OUT_LEN,:,:,:,:],pred_nd_d2, dim=0)
        else:
            pred_nd_d = pred_nd_d1
        if cfg.MODEL.PROBLEM_FORM == 'classification':
            pred_nd_d = from_class_to_image(pred_nd_d)
        
        # Back to pixel range for visualization/saving.
        data_nd_d = (data_nd_d*255.0).clip(0, 255.0)
        target_nd_d = (target_nd_d*255.0).clip(0, 255.0) if not no_gt else None
        pred_nd_d = (pred_nd_d*255.0).clip(0, 255.0)

        if save_path is None:
            save_path = base_dir
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        if mode == 'display':
            # Two renderings per sample: with and without default_as_0.
            display_path1 = os.path.join(save_path, 'prediction_'+str(i))
            display_path2 = os.path.join(save_path, 'prediction_'+str(i)+'_')
            if not os.path.exists(display_path1):
                os.mkdir(display_path1)
            if not os.path.exists(display_path2):
                os.mkdir(display_path2)

            if not no_gt:
                save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path1, default_as_0=True)
                save_prediction(data_nd_d[:,0,chan,:,:], target_nd_d[:,0,0,:,:], pred_nd_d[:,0,0,:,:], display_path2, default_as_0=False)
            else:
                save_prediction(data_nd_d[:,0,chan,:,:], None, pred_nd_d[:,0,0,:,:], display_path1, default_as_0=True)
                save_prediction(data_nd_d[:,0,chan,:,:], None, pred_nd_d[:,0,0,:,:], display_path2, default_as_0=False)

            # Histogram of predicted pixel values, for sanity checking.
            plt.hist(pred_nd_d.asnumpy().reshape([-1]), bins=100)
            plt.savefig(os.path.join(base_dir, 'hist'+str(i)))
            plt.close('all')
        elif mode == 'save':
            gt_path = os.path.join(save_path, 'groundtruth')
            pred_path = os.path.join(save_path, 'prediction')
            if (not no_gt) and (not os.path.exists(gt_path)):
                os.mkdir(gt_path)
            if not os.path.exists(pred_path):
                os.mkdir(pred_path)
            folder_name = folder_names[0][-1]

            # Save every 5th frame of ground truth / prediction as pngs.
            if not no_gt:
                save_prediction(mx.nd.zeros([1]), target_nd_d[::5,0,0,:,:], pred_nd_d[::5,0,0,:,:], None, default_as_0=False, mode='save', folder_name=folder_name, gt_path=gt_path, pred_path=pred_path)
            else:
                save_prediction(mx.nd.zeros([1]), None, pred_nd_d[::5,0,0,:,:], None, default_as_0=False, mode='save', folder_name=folder_name, gt_path=gt_path, pred_path=pred_path)

        else:
            raise NotImplementedError
Example #6
0
def train(args):
    """Train the SZO encoder-forecaster model, optionally with a GAN loss.

    Supports resuming from the latest checkpoint (parameters and,
    unless ``args.resume_param_only``, optimizer states and loss
    histories), periodic validation, loss-curve plotting, sample
    visualization, and checkpointing.

    Parameters
    ----------
    args : argparse.Namespace
        Provides ``ctx`` (list of MXNet contexts), ``resume`` and
        ``resume_param_only`` flags, plus whatever ``get_base_dir`` needs.
    """
    base_dir = get_base_dir(args)
    logging_config(folder=base_dir)
    save_cfg(dir_path=base_dir, source=cfg.MODEL)

    # The adversarial loss is only used for regression outputs.
    use_gan = (cfg.MODEL.PROBLEM_FORM == 'regression') and (cfg.MODEL.GAN_G_LAMBDA > 0.0)
    if cfg.MODEL.FRAME_SKIP_OUT > 1:
        # here assume the target spans the last 30 frames
        iter_outlen = cfg.MODEL.OUT_LEN
    else:
        iter_outlen = 30
    model_outlen = cfg.MODEL.OUT_LEN  # training on 30 costs too much memory

    # The iterator must provide the full output sequence as target.
    train_szo_iter = SZOIterator(rec_paths=cfg.SZO_TRAIN_DATA_PATHS,
                                 in_len=cfg.MODEL.IN_LEN,
                                 out_len=iter_outlen,
                                 batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                 frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                 frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                                 ctx=args.ctx)
    valid_szo_iter = SZOIterator(rec_paths=cfg.SZO_TEST_DATA_PATHS,
                                 in_len=cfg.MODEL.IN_LEN,
                                 out_len=iter_outlen,
                                 batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                 frame_skip_in=cfg.MODEL.FRAME_SKIP_IN,
                                 frame_skip_out=cfg.MODEL.FRAME_SKIP_OUT,
                                 ctx=args.ctx)

    # The model still generates cfg.MODEL.OUT_LEN outputs at a time.
    szo_nowcasting = SZONowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=model_outlen,
        frame_stack=cfg.MODEL.FRAME_STACK)
    # discrim_net, loss_D_net will be None if use_gan is False.
    encoder_net, forecaster_net, loss_net, discrim_net, loss_D_net = \
        encoder_forecaster_build_networks(
            factory=szo_nowcasting,
            context=args.ctx)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # (merged the two consecutive `if use_gan:` blocks into one)
    if use_gan:
        discrim_net.summary()
        loss_D_net.summary()
        loss_types = ('mse', 'gdl', 'gan', 'dis')
    else:
        if cfg.MODEL.PROBLEM_FORM == 'regression':
            loss_types = ('mse', 'gdl')
        elif cfg.MODEL.PROBLEM_FORM == 'classification':
            loss_types = ('ce',)
        else:
            raise NotImplementedError
    # Try to load a checkpoint.
    if args.resume:
        start_iter_id = latest_iter_id(base_dir)
        encoder_net.load_params(os.path.join(
            base_dir, 'encoder_net' + '-%04d.params' % (start_iter_id)))
        forecaster_net.load_params(os.path.join(
            base_dir, 'forecaster_net' + '-%04d.params' % (start_iter_id)))
        synchronize_kvstore(encoder_net)
        synchronize_kvstore(forecaster_net)
        if not args.resume_param_only:
            encoder_net.load_optimizer_states(os.path.join(
                base_dir, 'encoder_net' + '-%04d.states' % (start_iter_id)))
            forecaster_net.load_optimizer_states(os.path.join(
                base_dir, 'forecaster_net' + '-%04d.states' % (start_iter_id)))
        if use_gan:
            discrim_net.load_params(os.path.join(
                base_dir, 'discrim_net' + '-%04d.params' % (start_iter_id)))
            # Synchronize once after loading the parameters.
            # (a redundant second synchronize_kvstore call was removed)
            synchronize_kvstore(discrim_net)
            if not args.resume_param_only:
                discrim_net.load_optimizer_states(os.path.join(
                    base_dir, 'discrim_net' + '-%04d.states' % (start_iter_id)))
    else:
        start_iter_id = -1

    # Loss histories: restore from disk when resuming with optimizer state,
    # reconciling the stored keys with the current loss_types.
    if args.resume and (not args.resume_param_only):
        with open(os.path.join(base_dir, 'train_loss_dicts.pkl'), 'rb') as f:
            train_loss_dicts = pickle.load(f)
        with open(os.path.join(base_dir, 'valid_loss_dicts.pkl'), 'rb') as f:
            valid_loss_dicts = pickle.load(f)
        for dicts in (train_loss_dicts, valid_loss_dicts):
            keys_to_delete = []
            key_len = 0
            for k in dicts.keys():
                key_len = len(dicts[k])
                if k not in loss_types:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del dicts[k]
            # Pad newly-introduced loss types with zeros so curves align.
            for k in loss_types:
                if k not in dicts.keys():
                    dicts[k] = [0] * key_len
    else:
        train_loss_dicts = {}
        valid_loss_dicts = {}
        for dicts in (train_loss_dicts, valid_loss_dicts):
            for typ in loss_types:
                dicts[typ] = []

    states = EncoderForecasterStates(factory=szo_nowcasting, ctx=args.ctx[0])
    # State layouts must be batch-major ('N' first).
    for info in szo_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]
    for info in szo_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find('N') == 0, \
            "Layout=%s is not supported!" % info["__layout__"]

    # Per-loss accumulator, flushed every cfg.MODEL.DRAW_EVERY iterations.
    cumulative_loss = {}
    for k in train_loss_dicts.keys():
        cumulative_loss[k] = 0.0

    iter_id = start_iter_id + 1
    fix_shift = bool(cfg.MODEL.OPTFLOW_AS_INPUT)
    # With optical-flow input the grey-scale frame lives at channel
    # cfg.SZO.DATA.IMAGE_CHANNEL — assumption; confirm against iterator.
    chan = cfg.SZO.DATA.IMAGE_CHANNEL if cfg.MODEL.OPTFLOW_AS_INPUT else 0
    # Bounded replay buffers of generated/real samples for GAN training.
    buffers = {}
    buffers['fake'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    buffers['true'] = deque([], maxlen=cfg.MODEL.TRAIN.GEN_BUFFER_LEN)
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
        data_nd = frame_dat[0:cfg.MODEL.IN_LEN, :, :, :, :] / 255.0  # scale to [0,1]
        # only take the channel of grey scale
        target_nd = frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), :, chan, :, :] / 255.0
        target_nd = target_nd.expand_dims(axis=2)
        states.reset_all()
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            mask_nd = target_nd < 1.0
        else:
            mask_nd = mx.nd.ones_like(target_nd)

        states, loss_dict, pred_nd, buffers = train_step(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            encoder_net=encoder_net, forecaster_net=forecaster_net,
            loss_net=loss_net, discrim_net=discrim_net,
            loss_D_net=loss_D_net, init_states=states,
            data_nd=data_nd, gt_nd=target_nd, mask_nd=mask_nd,
            factory=szo_nowcasting, iter_id=iter_id, buffers=buffers)
        for k in cumulative_loss.keys():
            cumulative_loss[k] += loss_dict[k + '_output']

        if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
            # Average validation losses over cfg.MODEL.VALID_LOOP batches.
            for i in range(cfg.MODEL.VALID_LOOP):
                states.reset_all()
                frame_dat_v = valid_szo_iter.sample(fix_shift=fix_shift)
                data_nd_v = frame_dat_v[0:cfg.MODEL.IN_LEN, :, :, :, :] / 255.0
                gt_nd_v = frame_dat_v[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), :, chan, :, :] / 255.0
                gt_nd_v = gt_nd_v.expand_dims(axis=2)
                if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
                    mask_nd_v = gt_nd_v < 1.0
                else:
                    mask_nd_v = mx.nd.ones_like(gt_nd_v)
                states, new_valid_loss_dicts, pred_nd_v = valid_step(
                    batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                    encoder_net=encoder_net, forecaster_net=forecaster_net,
                    loss_net=loss_net, discrim_net=discrim_net,
                    loss_D_net=loss_D_net, init_states=states,
                    data_nd=data_nd_v, gt_nd=gt_nd_v, mask_nd=mask_nd_v,
                    valid_loss_dicts=valid_loss_dicts, factory=szo_nowcasting,
                    iter_id=iter_id)
                if i == 0:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k].append(new_valid_loss_dicts[k])
                else:
                    for k in valid_loss_dicts.keys():
                        valid_loss_dicts[k][-1] += new_valid_loss_dicts[k]

            for k in valid_loss_dicts.keys():
                valid_loss_dicts[k][-1] /= cfg.MODEL.VALID_LOOP
                plot_loss_curve(os.path.join(base_dir, 'valid_' + k + '_loss'),
                                valid_loss_dicts[k])

        if (iter_id + 1) % cfg.MODEL.DRAW_EVERY == 0:
            for k in train_loss_dicts.keys():
                avg_loss = cumulative_loss[k] / cfg.MODEL.DRAW_EVERY
                # Clamp adversarial losses into [0, 1] to keep curves readable.
                if k in ('gan', 'dis'):
                    avg_loss = min(avg_loss, 1.0)
                if avg_loss < 0:
                    avg_loss = 0
                train_loss_dicts[k].append(avg_loss)
                cumulative_loss[k] = 0.0
                plot_loss_curve(os.path.join(base_dir, 'train_' + k + '_loss'),
                                train_loss_dicts[k])

        if (iter_id + 1) % cfg.MODEL.DISPLAY_EVERY == 0:
            # Render one fresh training sample alongside its prediction.
            new_frame_dat = train_szo_iter.sample(fix_shift=fix_shift)
            data_nd_d = new_frame_dat[0:cfg.MODEL.IN_LEN, :, :, :, :] / 255.0
            target_nd_d = new_frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN + cfg.MODEL.OUT_LEN), :, chan, :, :] / 255.0
            target_nd_d = target_nd_d.expand_dims(axis=2)
            states.reset_all()
            pred_nd_d = get_prediction(data_nd_d, states, encoder_net, forecaster_net)
            if cfg.MODEL.PROBLEM_FORM == 'classification':
                pred_nd_d = from_class_to_image(pred_nd_d)
            display_path1 = os.path.join(base_dir, 'display_' + str(iter_id))
            display_path2 = os.path.join(base_dir, 'display_' + str(iter_id) + '_')
            if not os.path.exists(display_path1):
                os.mkdir(display_path1)
            if not os.path.exists(display_path2):
                os.mkdir(display_path2)

            # Back to pixel range for visualization.
            data_nd_d = (data_nd_d * 255.0).clip(0, 255.0)
            target_nd_d = (target_nd_d * 255.0).clip(0, 255.0)
            pred_nd_d = (pred_nd_d * 255.0).clip(0, 255.0)
            save_prediction(data_nd_d[:, 0, chan, :, :], target_nd_d[:, 0, 0, :, :],
                            pred_nd_d[:, 0, 0, :, :], display_path1, default_as_0=True)
            save_prediction(data_nd_d[:, 0, chan, :, :], target_nd_d[:, 0, 0, :, :],
                            pred_nd_d[:, 0, 0, :, :], display_path2, default_as_0=False)
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(
                prefix=os.path.join(base_dir, "encoder_net"),
                epoch=iter_id,
                save_optimizer_states=True)
            forecaster_net.save_checkpoint(
                prefix=os.path.join(base_dir, "forecaster_net"),
                epoch=iter_id,
                save_optimizer_states=True)
            # BUG FIX: guard on use_gan rather than GAN_G_LAMBDA > 0 —
            # discrim_net is None when the lambda is positive but
            # PROBLEM_FORM is not 'regression', which would crash here.
            if use_gan:
                discrim_net.save_checkpoint(
                    prefix=os.path.join(base_dir, "discrim_net"),
                    epoch=iter_id,
                    save_optimizer_states=True)
            # Persist loss histories so curves survive a resume.
            with open(os.path.join(base_dir, 'train_loss_dicts.pkl'), 'wb') as f:
                pickle.dump(train_loss_dicts, f)
            with open(os.path.join(base_dir, 'valid_loss_dicts.pkl'), 'wb') as f:
                pickle.dump(valid_loss_dicts, f)
        iter_id += 1
Example #7
0
def main(args):
    """Train an HKO radar nowcasting encoder-forecaster on an HDF5 archive.

    Reads event metadata from a CSV, shuffles and splits it 95/5 into
    train/validation sets, then runs the training loop for
    ``cfg.MODEL.TRAIN.MAX_ITER`` iterations, checkpointing every
    ``cfg.MODEL.SAVE_ITER`` and benchmarking every ``cfg.MODEL.VALID_ITER``.

    Args:
        args: namespace providing ``save_dir``, ``data_dir``, ``data_csv``,
            ``date_start``, ``date_end`` and ``ctx`` (list of MXNet contexts).
    """
    # Frame stacking/skipping is not supported by this training loop.
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="train")
    save_cfg(dir_path=base_dir, source=cfg.MODEL)
    # Default metadata CSV lives next to the archive unless one is supplied.
    metadata_file = os.path.join(
        args.data_dir,
        'hdf_metadata.csv') if args.data_csv is None else args.data_csv

    all_data = h5py.File(os.path.join(args.data_dir, 'hdf_archives',
                                      'all_data.hdf5'),
                         'r',
                         libver='latest')
    # Grayscale (flag 0) mask marking outlier pixels to be excluded.
    outlier_mask = cv2.imread(os.path.join(args.data_dir, 'mask.png'), 0)

    metadata = pd.read_csv(metadata_file, index_col='id')
    metadata['start_datetime'] = pd.to_datetime(metadata['start_datetime'])
    metadata['end_datetime'] = pd.to_datetime(metadata['end_datetime'])
    # Optional half-open date window [date_start, date_end) on event start.
    if args.date_start is not None:
        metadata = metadata.loc[metadata['start_datetime'] >= args.date_start]
    if args.date_end is not None:
        metadata = metadata.loc[metadata['start_datetime'] < args.date_end]
    # NOTE(review): sample(frac=1) shuffles all rows, but no random_state is
    # given, so the 95/5 split differs between runs — confirm this is intended.
    sort_meta = metadata.sample(frac=1)
    split_idx = int(len(sort_meta) * 0.95)
    train_meta = sort_meta.iloc[:split_idx]
    test_meta = sort_meta.iloc[split_idx:]

    logging.info("Initializing data iterator with filter threshold %s" %
                 cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)
    train_model_iter = infinite_batcher(
        all_data,
        train_meta,
        outlier_mask,
        shuffle=False,
        filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD)

    # Training factory: global batch split evenly across the contexts.
    model_nowcasting = RadarNowcastingFactory(
        batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
        ctx_num=len(args.ctx),
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    # Separate batch-size-1 factory for online/validation use.
    model_nowcasting_online = RadarNowcastingFactory(
        batch_size=1,
        in_seq_len=cfg.MODEL.IN_LEN,
        out_seq_len=cfg.MODEL.OUT_LEN)
    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting,
            context=args.ctx)
    # Finetune/eval copies sharing parameters with the training networks.
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=model_nowcasting_online,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,
                                       load_iter=cfg.MODEL.LOAD_ITER,
                                       encoder_net=encoder_net,
                                       forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=model_nowcasting, ctx=args.ctx[0])
    # Only batch-major ('N' first) state layouts are supported downstream.
    for info in model_nowcasting.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    for info in model_nowcasting.init_forecaster_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    # TBPTT training is validated in "online" mode, otherwise "fixed".
    test_mode = "online" if cfg.MODEL.TRAIN.TBPTT else "fixed"
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        # sample a random minibatch
        try:
            frame_dat, _, mask_dat = next(train_model_iter)
        except StopIteration:
            # Data exhausted: stop training early.
            break
        else:
            # Fresh RNN states per minibatch (no state carried across batches).
            states.reset_all()
            data_nd = mx.nd.array(frame_dat[0:cfg.MODEL.IN_LEN, ...],
                                  ctx=args.ctx[0])
            # Target frames and their validity mask follow the input window.
            target_nd = mx.nd.array(
                frame_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                            cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            mask_nd = mx.nd.array(
                mask_dat[cfg.MODEL.IN_LEN:(cfg.MODEL.IN_LEN +
                                           cfg.MODEL.OUT_LEN), ...],
                ctx=args.ctx[0])
            states, _ = train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                   encoder_net=encoder_net,
                                   forecaster_net=forecaster_net,
                                   loss_net=loss_net,
                                   init_states=states,
                                   data_nd=data_nd,
                                   gt_nd=target_nd,
                                   mask_nd=mask_nd,
                                   iter_id=iter_id)
            # Periodic checkpointing (optimizer state is not saved here).
            if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
                encoder_net.save_checkpoint(prefix=os.path.join(
                    base_dir, "encoder_net"),
                                            epoch=iter_id)
                forecaster_net.save_checkpoint(prefix=os.path.join(
                    base_dir, "forecaster_net"),
                                               epoch=iter_id)
            # Periodic validation on the held-out 5% with the shared-weight
            # batch-size-1 networks.
            if (iter_id + 1) % cfg.MODEL.VALID_ITER == 0:
                test_model_iter = HDFIterator(
                    all_data,
                    test_meta,
                    outlier_mask,
                    batch_size=1,
                    shuffle=False,
                    filter_threshold=cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD
                )
                run_benchmark(model_factory=model_nowcasting_online,
                              context=args.ctx[0],
                              encoder_net=t_encoder_net,
                              forecaster_net=t_forecaster_net,
                              save_dir=os.path.join(
                                  base_dir, "iter%d_valid" % (iter_id + 1)),
                              mode=test_mode,
                              batcher=test_model_iter)
            iter_id += 1
Exemple #8
0
def train(args):
    """Train the MovingMNIST encoder-forecaster model.

    Samples batches from a ``MovingMNISTAdvancedIterator``, runs
    ``train_step`` for ``cfg.MODEL.TRAIN.MAX_ITER`` iterations, writes
    prediction/input/ground-truth GIFs every 100 iterations and saves
    checkpoints (with optimizer state) every ``cfg.MODEL.SAVE_ITER``.

    Args:
        args: namespace providing ``save_dir`` (output directory), ``ctx``
            (list of MXNet contexts) and ``resume`` (bool, resume from the
            latest checkpoint in ``save_dir``).
    """
    # Frame stacking/skipping is not supported by this training loop.
    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
    base_dir = args.save_dir
    logging_config(folder=base_dir, name="training")
    save_movingmnist_cfg(base_dir)
    # Data source: MovingMNIST with distractors, rotation, scale and
    # illumination variation; all ranges come from the config.
    mnist_iter = MovingMNISTAdvancedIterator(
        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
                                cfg.MOVINGMNIST.VELOCITY_UPPER),
        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
                              cfg.MOVINGMNIST.ROTATION_UPPER),
        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))

    # Per-device batch size: the global batch is split across the contexts.
    mnist_rnn = MovingMNISTFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE //
                                   len(args.ctx),
                                   in_seq_len=cfg.MODEL.IN_LEN,
                                   out_seq_len=cfg.MODEL.OUT_LEN)

    encoder_net, forecaster_net, loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx)
    # Shared-parameter copies on a single context, built for finetuning.
    t_encoder_net, t_forecaster_net, t_loss_net = \
        encoder_forecaster_build_networks(
            factory=mnist_rnn,
            context=args.ctx[0],
            shared_encoder_net=encoder_net,
            shared_forecaster_net=forecaster_net,
            shared_loss_net=loss_net,
            for_finetune=True)
    encoder_net.summary()
    forecaster_net.summary()
    loss_net.summary()
    # Resume the latest checkpoint (including optimizer state) if requested.
    if args.resume:
        encoder_net.load(prefix=os.path.join(base_dir, 'encoder_net'),
                         epoch=latest_iter_id(base_dir),
                         load_optimizer_states=True,
                         data_names=[
                             'data', 'ebrnn1_begin_state_h',
                             'ebrnn2_begin_state_h', 'ebrnn3_begin_state_h'
                         ],
                         label_names=[])  # change it next time
        forecaster_net.load(prefix=os.path.join(base_dir, 'forecaster_net'),
                            epoch=latest_iter_id(base_dir),
                            load_optimizer_states=True,
                            data_names=[
                                'fbrnn1_begin_state_h', 'fbrnn2_begin_state_h',
                                'fbrnn3_begin_state_h'
                            ],
                            label_names=[])  # change it next time
    # Begin to load the model if load_dir is not empty
    if len(cfg.MODEL.LOAD_DIR) > 0:
        load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR,
                          load_iter=cfg.MODEL.LOAD_ITER,
                          encoder_net=encoder_net,
                          forecaster_net=forecaster_net)
    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
    # NOTE(review): states are reset once before the loop, not per batch
    # (unlike the HKO loop) — confirm this carry-over is intended.
    states.reset_all()
    # Only batch-major ('N' first) state layouts are supported downstream.
    for info in mnist_rnn.init_encoder_state_info:
        assert info["__layout__"].find(
            'N') == 0, "Layout=%s is not supported!" % info["__layout__"]
    iter_id = 0
    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                                         seqlen=cfg.MOVINGMNIST.IN_LEN +
                                         cfg.MOVINGMNIST.OUT_LEN)
        # Pixel values are rescaled from [0, 255] to [0, 1].
        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                              ctx=args.ctx[0]) / 255.0
        # FIX: slice with cfg.MOVINGMNIST.IN_LEN (the length the batch was
        # sampled with) instead of cfg.MODEL.IN_LEN, which mixed two config
        # namespaces and desynchronized input/target if the values differ.
        target_nd = mx.nd.array(frame_dat[cfg.MOVINGMNIST.IN_LEN:(
            cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                ctx=args.ctx[0]) / 255.0
        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                   encoder_net=encoder_net,
                   forecaster_net=forecaster_net,
                   loss_net=loss_net,
                   init_states=states,
                   data_nd=data_nd,
                   gt_nd=target_nd,
                   mask_nd=None,
                   iter_id=iter_id)
        if (iter_id + 1) % 100 == 0:
            # Draw a fresh batch purely for visualization.
            new_frame_dat, _ = mnist_iter.sample(
                batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
                seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
            # FIX: index the freshly sampled new_frame_dat; the original
            # indexed the stale frame_dat, silently discarding the new batch.
            data_nd = mx.nd.array(new_frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...],
                                  ctx=args.ctx[0]) / 255.0
            target_nd = mx.nd.array(new_frame_dat[cfg.MOVINGMNIST.IN_LEN:(
                cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
                                    ctx=args.ctx[0]) / 255.0
            pred_nd = mnist_get_prediction(data_nd=data_nd,
                                           states=states,
                                           encoder_net=encoder_net,
                                           forecaster_net=forecaster_net)
            # Save the first sample of the batch (channel 0) as GIFs.
            save_gif(pred_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "pred.gif"))
            save_gif(data_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "in.gif"))
            save_gif(target_nd.asnumpy()[:, 0, 0, :, :],
                     os.path.join(base_dir, "gt.gif"))
        # Periodic checkpoint including optimizer state so runs can resume.
        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
            encoder_net.save_checkpoint(prefix=os.path.join(
                base_dir, "encoder_net"),
                                        epoch=iter_id,
                                        save_optimizer_states=True)
            forecaster_net.save_checkpoint(prefix=os.path.join(
                base_dir, "forecaster_net"),
                                           epoch=iter_id,
                                           save_optimizer_states=True)
        iter_id += 1