Example #1
def get_dims(self,
             y=None,
             u_state=None,
             u_obs=None,
             n_timesteps=None,
             n_batch=None):
    # Infer the time and batch dimensions from the leading (T, B, ...) axes
    # of whichever tensor is available, in order of precedence.
    if y is not None:
        n_timesteps = y.shape[0]
        n_batch = y.shape[1]
    elif u_state is not None:
        n_timesteps = u_state.shape[0]
        n_batch = u_state.shape[1]
    elif u_obs is not None:
        n_timesteps = u_obs.shape[0]
        n_batch = u_obs.shape[1]
    elif n_timesteps is None or n_batch is None:
        # Without a tensor to read shapes from, both values must be given.
        raise ValueError(
            "either provide n_timesteps and n_batch directly, "
            "or provide any of (y, u_state, u_obs). "
            "Got the following types: "
            f"y: {type(y)}, "
            f"u_state: {type(u_state)}, "
            f"u_obs: {type(u_obs)}, "
            f"n_timesteps: {type(n_timesteps)}, "
            f"n_batch: {type(n_batch)}")
    return TensorDims(
        timesteps=n_timesteps,
        particle=self.n_particle,
        batch=n_batch,
        state=self.n_state,
        target=self.n_target,
        ctrl_target=self.n_ctrl_target,
        ctrl_state=self.n_ctrl_state,
    )
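
For context, a minimal runnable sketch of how this resolver behaves. TensorDims is reduced here to a NamedTuple stand-in with just the fields the method fills in, and DummyModel is a hypothetical host class, not part of the original code:

from typing import NamedTuple, Optional
import numpy as np

class TensorDims(NamedTuple):  # reduced stand-in for the repo's container
    timesteps: int
    particle: int
    batch: int
    state: int
    target: int
    ctrl_target: Optional[int]
    ctrl_state: Optional[int]

class DummyModel:  # hypothetical host class for the method above
    n_particle, n_state, n_target = 32, 10, 2
    n_ctrl_target = n_ctrl_state = None
    get_dims = get_dims  # bind the function defined above as a method

y = np.zeros((50, 8, 2))  # (timesteps, batch, target)
print(DummyModel().get_dims(y=y))  # timesteps=50 and batch=8 inferred from y
print(DummyModel().get_dims(n_timesteps=50, n_batch=8))  # explicit variant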
Example #2
    n_hidden_rnn: int
    reconstruction_weight: float
    rao_blackwellized: bool


dims_img = (1, 32, 32)
config = PymunkKVAEConfig(
    dataset_name=consts.Datasets.box,
    experiment_name="kvae",
    dims_img=dims_img,
    dims=TensorDims(
        timesteps=20,
        particle=1,  # KVAE does not use SMC
        batch=34,  # pytorch-magma bug with 32.
        state=10,
        target=int(np.prod(dims_img)),
        switch=None,  # --> n_hidden_rnn
        auxiliary=2,
        ctrl_target=None,
        ctrl_state=None,
    ),
    #
    batch_size_eval=4,
    num_samples_eval=256,
    prediction_length=40,
    n_epochs=400,
    lr=2e-3,
    lr_decay_rate=0.85,
    lr_decay_steps=20,
    grad_clip_norm=1500.0,
    weight_decay=0,
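
The excerpt cuts off before the config closes, but the optimiser-scheduling fields already shown can be made concrete. A small self-contained sketch, assuming the usual stepwise interpretation of lr_decay_rate and lr_decay_steps (multiply the learning rate by 0.85 every 20 epochs):

lr, decay_rate, decay_steps = 2e-3, 0.85, 20
for epoch in (0, 20, 100, 399):
    print(epoch, lr * decay_rate ** (epoch // decay_steps))
# 0 -> 2.0e-03, 20 -> 1.7e-03, 100 -> ~8.9e-04, 399 -> ~9.1e-05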
Example #3
def make_default_config(dataset_name):
    timefeat = TimeFeatType.timefeat
    (
        n_timefeat,
        n_staticfeat,
        n_latent,
        freq,
        cardinalities,
        prediction_length_rolling,
        prediction_length_full,
    ) = get_n_feat_and_freq(dataset_name=dataset_name, timefeat=timefeat)
    assert len(cardinalities["cardinalities_feat_static_cat"]) == 1
    n_static_embedding = min(
        50, (cardinalities["cardinalities_feat_static_cat"][0] + 1) // 2
    )
    n_ctrl_all = n_ctrl_static = n_ctrl_dynamic = 64

    # n_ctrl_static = n_static_embedding
    # n_ctrl_dynamic = 32
    # n_ctrl_all = n_ctrl_static + n_ctrl_dynamic  # we cat

    dims = TensorDims(
        timesteps=past_lengths[dataset_name],
        particle=10,
        batch=50,
        state=n_latent,
        target=1,
        switch=5,
        # ctrl_state=None,
        # ctrl_switch=n_staticfeat + n_timefeat,
        # ctrl_obs=n_staticfeat + n_timefeat,
        ctrl_state=n_ctrl_dynamic,
        ctrl_target=n_ctrl_static,
        ctrl_switch=n_ctrl_all,  # switch takes cat feats
        ctrl_encoder=n_ctrl_all,  # encoder takes cat feats
        timefeat=n_timefeat,
        staticfeat=n_staticfeat,
        cat_embedding=n_static_embedding,
        auxiliary=None,
    )

    config = RsglsIssmGtsExpConfig(
        experiment_name="rsgls",
        dataset_name=dataset_name,
        #
        n_epochs=50,
        n_epochs_no_resampling=5,
        n_epochs_freeze_gls_params=1,
        n_epochs_until_validate_loss=1,
        lr=5e-3,
        weight_decay=1e-5,
        grad_clip_norm=10.0,
        num_samples_eval=100,
        batch_size_val=100,  # 10
        # gpus=tuple(range(3, 4)),
        # dtype=torch.float64,
        # architecture, prior, etc.
        state_prior_scale=1.0,
        state_prior_loc=0.0,
        make_cov_from_cholesky_avg=True,
        extract_tail_chunks_for_train=False,
        switch_link_type=SwitchLinkType.individual,
        switch_link_dims_hidden=(64,),
        switch_link_activations=nn.LeakyReLU(0.1, inplace=True),
        recurrent_link_type=SwitchLinkType.individual,
        is_recurrent=True,
        n_base_A=20,
        n_base_B=20,
        n_base_C=20,
        n_base_D=20,
        n_base_Q=20,
        n_base_R=20,
        n_base_F=20,
        n_base_S=20,
        requires_grad_R=True,
        requires_grad_Q=True,
        requires_grad_S=True,
        # obs_to_switch_encoder=True,
        # state_to_switch_encoder=False,
        switch_prior_model_dims=tuple(),
        # TODO: made assumption that this is used for ctrl_state...
        input_transform_dims=(64,) + (dims.ctrl_state,),
        switch_transition_model_dims=(64,),
        # state_to_switch_encoder_dims=(64,),
        obs_to_switch_encoder_dims=(64,),
        b_fn_dims=tuple(),
        d_fn_dims=tuple(),  # (64,),
        switch_prior_model_activations=LeakyReLU(0.1, inplace=True),
        input_transform_activations=LeakyReLU(0.1, inplace=True),
        switch_transition_model_activations=LeakyReLU(0.1, inplace=True),
        # state_to_switch_encoder_activations=LeakyReLU(0.1, inplace=True),
        obs_to_switch_encoder_activations=LeakyReLU(0.1, inplace=True),
        b_fn_activations=LeakyReLU(0.1, inplace=True),
        d_fn_activations=LeakyReLU(0.1, inplace=True),
        # initialisation
        init_scale_A=0.95,
        init_scale_B=0.0,
        init_scale_C=None,
        init_scale_D=0.0,
        init_scale_R_diag=[1e-5, 1e-1],
        init_scale_Q_diag=[1e-4, 1e0],
        init_scale_S_diag=[1e-5, 1e-1],
        # set from outside due to dependencies.
        dims=dims,
        freq=freq,
        time_feat=timefeat,
        add_trend=add_trend_map[dataset_name],
        prediction_length_rolling=prediction_length_rolling,
        prediction_length_full=prediction_length_full,
        normalisation_params=normalisation_params[dataset_name],
        LRinv_logdiag_scaling=1.0,
        LQinv_logdiag_scaling=1.0,
        A_scaling=1.0,
        B_scaling=1.0,
        C_scaling=1.0,
        D_scaling=1.0,
        LSinv_logdiag_scaling=1.0,
        F_scaling=1.0,
        eye_init_A=True,
    )
    return config
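
The embedding-size heuristic near the top of the function is worth sanity-checking in isolation. A self-contained worked example of min(50, (cardinality + 1) // 2), with made-up cardinalities:

# roughly half the number of categories, capped at 50
for cardinality in (6, 99, 370):
    print(cardinality, min(50, (cardinality + 1) // 2))
# 6 -> 3, 99 -> 50, 370 -> 50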
Example #4
def make_default_config(dataset_name):
    timefeat = TimeFeatType.timefeat
    (
        n_timefeat,
        n_staticfeat,
        n_latent,
        freq,
        cardinalities,
        prediction_length_rolling,
        prediction_length_full,
    ) = get_n_feat_and_freq(dataset_name=dataset_name, timefeat=timefeat)
    assert len(cardinalities["cardinalities_feat_static_cat"]) == 1
    n_static_embedding = min(
        50, (cardinalities["cardinalities_feat_static_cat"][0] + 1) // 2)
    n_ctrl_all = n_ctrl_static = n_ctrl_dynamic = 64

    # n_ctrl_static = n_static_embedding
    # n_ctrl_dynamic = 32
    # n_ctrl_all = n_ctrl_static + n_ctrl_dynamic  # we cat

    dims = TensorDims(
        timesteps=past_lengths[dataset_name],
        particle=1,
        batch=50,
        state=16,  # n_latent,
        target=1,
        switch=None,
        ctrl_state=n_ctrl_dynamic,
        ctrl_target=n_ctrl_static,
        ctrl_switch=n_ctrl_all,
        ctrl_encoder=None,  # KVAE uses pseudo-obs only, no controls for enc.
        timefeat=n_timefeat,
        staticfeat=n_staticfeat,
        cat_embedding=n_static_embedding,
        auxiliary=5,
    )

    config = KvaeGtsExpConfig(
        experiment_name="kvae",
        dataset_name=dataset_name,
        #
        n_epochs=50,
        n_epochs_until_validate_loss=1,
        lr=5e-3,
        weight_decay=1e-5,
        grad_clip_norm=10.0,
        num_samples_eval=100,
        # Note: These batch sizes barely fit on the GPU.
        batch_size_val=(
            10 if dataset_name in ["exchange_rate_nips", "wiki2000_nips"]
            else 2
        ),
        # architecture, prior, etc.
        state_prior_scale=1.0,
        state_prior_loc=0.0,
        make_cov_from_cholesky_avg=True,
        extract_tail_chunks_for_train=False,
        switch_link_type=SwitchLinkType.shared,
        switch_link_dims_hidden=tuple(),  # linear used in KVAE LSTM -> alpha
        switch_link_activations=tuple(),
        # they have 1 Dense layer after LSTM.
        recurrent_link_type=SwitchLinkType.shared,
        n_hidden_rnn=50,
        rao_blackwellized=True,
        reconstruction_weight=1.0,  # They use 0.3 w/o rao-BW.
        dims_encoder=(64, 64),
        dims_decoder=(64, 64),
        activations_encoder=LeakyReLU(0.1, inplace=True),
        activations_decoder=LeakyReLU(0.1, inplace=True),
        n_base_A=20,
        n_base_B=20,
        n_base_C=20,
        n_base_D=None,  # KVAE does not have D
        n_base_Q=20,
        n_base_R=20,
        n_base_F=None,
        n_base_S=None,
        requires_grad_R=True,
        requires_grad_Q=True,
        requires_grad_S=None,
        input_transform_dims=(dims.ctrl_state,),
        input_transform_activations=LeakyReLU(0.1, inplace=True),
        # initialisation
        init_scale_A=0.95,
        init_scale_B=0.0,
        init_scale_C=None,
        init_scale_D=None,
        init_scale_R_diag=[1e-4, 1e-1],
        init_scale_Q_diag=[1e-4, 1e-1],
        init_scale_S_diag=None,
        # init_scale_S_diag=[1e-5, 1e0],
        # set from outside due to dependencies.
        dims=dims,
        freq=freq,
        time_feat=timefeat,
        add_trend=add_trend_map[dataset_name],
        prediction_length_rolling=prediction_length_rolling,
        prediction_length_full=prediction_length_full,
        normalisation_params=normalisation_params[dataset_name],
        LRinv_logdiag_scaling=1.0,
        LQinv_logdiag_scaling=1.0,
        A_scaling=1.0,
        B_scaling=1.0,
        C_scaling=1.0,
        D_scaling=1.0,
        LSinv_logdiag_scaling=1.0,
        F_scaling=1.0,
        eye_init_A=True,
    )
    return config
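
The n_base_* fields reflect KVAE's mixture-of-base-matrices parameterisation: an RNN emits mixture weights alpha_t, and the effective transition matrix is a convex combination of K learned bases. A hedged numpy sketch of just the combination step; the dirichlet draw stands in for the RNN-derived softmax weights, which are not shown in this excerpt:

import numpy as np

K, n_state = 20, 16  # n_base_A and dims.state from the config above
A_base = np.random.randn(K, n_state, n_state)  # K learned base matrices
alpha = np.random.dirichlet(np.ones(K))  # stand-in for the RNN's mixture weights
A_t = np.einsum("k,kij->ij", alpha, A_base)  # A_t = sum_k alpha_k * A^(k)
print(A_t.shape)  # (16, 16)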
Example #5
    switch_prior_scale: float
    switch_prior_loc: float
    requires_grad_switch_prior: bool


dims_img = (1, 32, 32)
config = PymunkASGLSConfig(
    dataset_name=consts.Datasets.box,
    experiment_name="arsgls",
    dims_img=dims_img,
    dims=TensorDims(
        timesteps=20,
        particle=32,
        batch=34,  # pytorch-magma bug with 32.
        state=10,
        target=int(np.prod(dims_img)),
        switch=8,
        auxiliary=2,
        ctrl_target=None,
        ctrl_state=None,
    ),
    #
    batch_size_eval=8,
    num_samples_eval=256,
    prediction_length=40,
    n_epochs=400,
    lr=2e-3,
    lr_decay_rate=0.85,
    lr_decay_steps=20,
    grad_clip_norm=1500.0,
    weight_decay=0,
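
For reference, a value like grad_clip_norm is typically passed to torch's built-in clipping helper at each optimisation step; the training loop itself is not part of this excerpt, so this is only a sketch:

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
loss = (param ** 2).sum()
loss.backward()
# rescale gradients so their total norm does not exceed the configured value
torch.nn.utils.clip_grad_norm_([param], max_norm=1500.0)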
Example #6
    b_fn_dims: tuple
    b_fn_activations: (nn.Module, tuple)
    d_fn_dims: tuple
    d_fn_activations: (nn.Module, tuple)

    n_epochs_no_resampling: int
    weight_decay: float
    grad_clip_norm: float


dims = TensorDims(
    timesteps=50,
    particle=64,
    batch=100,
    state=3,
    target=2,
    switch=5,
    auxiliary=None,
    ctrl_target=None,
    ctrl_state=None,
)
config = PendulumSGLSConfig(
    dataset_name=consts.Datasets.pendulum_3D_coord,
    experiment_name="default",
    batch_size_eval=1000,
    num_samples_eval=100,
    lr=1e-2,
    weight_decay=1e-5,
    grad_clip_norm=500.0,
    n_epochs=100,
    n_epochs_no_resampling=10,
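
The excerpt ends mid-call, but n_epochs_no_resampling is worth a note: for SMC-trained models like this SGLS, particle resampling is presumably disabled during the first epochs to stabilise early training. A hypothetical sketch of that schedule (the helper name is made up):

def resampling_active(epoch: int, n_epochs_no_resampling: int = 10) -> bool:
    # resample particles only once the warm-up phase is over
    return epoch >= n_epochs_no_resampling

print([resampling_active(e) for e in (0, 9, 10, 99)])
# [False, False, True, True]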