def get_dims(self, y=None, u_state=None, u_obs=None, n_timesteps=None, n_batch=None):
    """Infer (timesteps, batch) and assemble the model's ``TensorDims``.

    The leading two axes of the first non-``None`` tensor among
    ``(y, u_state, u_obs)`` are interpreted as ``(timesteps, batch)``.
    If no tensor is given, *both* ``n_timesteps`` and ``n_batch`` must be
    passed explicitly.

    Args:
        y: observations, leading axes (T, B); optional.
        u_state: state controls, leading axes (T, B); optional.
        u_obs: observation controls, leading axes (T, B); optional.
        n_timesteps: explicit number of timesteps (fallback when no tensor given).
        n_batch: explicit batch size (fallback when no tensor given).

    Returns:
        TensorDims with the inferred (timesteps, batch) and the model's
        static dimensions (particle, state, target, ctrl_target, ctrl_state).

    Raises:
        ValueError: if no tensor is provided and either explicit size is missing.
    """
    if y is not None:
        n_timesteps, n_batch = y.shape[0], y.shape[1]
    elif u_state is not None:
        n_timesteps, n_batch = u_state.shape[0], u_state.shape[1]
    elif u_obs is not None:
        n_timesteps, n_batch = u_obs.shape[0], u_obs.shape[1]
    elif n_timesteps is None or n_batch is None:
        # BUGFIX: the original condition used `and`, so supplying exactly one
        # of the two sizes silently produced TensorDims with a None field.
        # Both are required when no tensor is given.  (The message also named
        # a nonexistent parameter `u_switch`; removed.)
        raise ValueError(
            "either provide n_timesteps and n_batch directly, "
            "or provide any of (y, u_state, u_obs). "
            f"Got following types: "
            f"y: {type(y)}, "
            f"u_state: {type(u_state)}, "
            f"u_obs: {type(u_obs)}, "
            f"n_timesteps: {type(n_timesteps)}, "
            f"n_batch: {type(n_batch)}")
    return TensorDims(
        timesteps=n_timesteps,
        particle=self.n_particle,
        batch=n_batch,
        state=self.n_state,
        target=self.n_target,
        ctrl_target=self.n_ctrl_target,
        ctrl_state=self.n_ctrl_state,
    )
n_hidden_rnn: int reconstruction_weight: float rao_blackwellized: bool dims_img = (1, 32, 32) config = PymunkKVAEConfig( dataset_name=consts.Datasets.box, experiment_name="kvae", dims_img=dims_img, dims=TensorDims( timesteps=20, particle=1, # KVAE does not use SMC batch=34, # pytorch-magma bug with 32. state=10, target=int(np.prod(dims_img)), switch=None, # --> n_hidden_rnn auxiliary=2, ctrl_target=None, ctrl_state=None, ), # batch_size_eval=4, num_samples_eval=256, prediction_length=40, n_epochs=400, lr=2e-3, lr_decay_rate=0.85, lr_decay_steps=20, grad_clip_norm=1500.0, weight_decay=0,
def make_default_config(dataset_name):
    """Build the default RSGLS experiment configuration for ``dataset_name``.

    Looks up dataset-dependent feature counts, frequency and prediction
    lengths via ``get_n_feat_and_freq``, derives the tensor dimensions, and
    returns a fully populated ``RsglsIssmGtsExpConfig``.
    """
    timefeat = TimeFeatType.timefeat
    (
        n_timefeat,
        n_staticfeat,
        n_latent,
        freq,
        cardinalities,
        prediction_length_rolling,
        prediction_length_full,
    ) = get_n_feat_and_freq(dataset_name=dataset_name, timefeat=timefeat)

    # Exactly one static categorical feature is assumed for these datasets.
    assert len(cardinalities["cardinalities_feat_static_cat"]) == 1
    # Embedding size heuristic: half the cardinality (rounded up), capped at 50.
    n_static_embedding = min(
        50, (cardinalities["cardinalities_feat_static_cat"][0] + 1) // 2
    )
    # All three control widths are currently tied to 64; the commented lines
    # below are the earlier per-source sizing that was replaced.
    n_ctrl_all = n_ctrl_static = n_ctrl_dynamic = 64
    # n_ctrl_static = n_static_embedding
    # n_ctrl_dynamic = 32
    # n_ctrl_all = n_ctrl_static + n_ctrl_dynamic  # we cat
    dims = TensorDims(
        timesteps=past_lengths[dataset_name],
        particle=10,
        batch=50,
        state=n_latent,
        target=1,
        switch=5,
        # ctrl_state=None,
        # ctrl_switch=n_staticfeat + n_timefeat,
        # ctrl_obs=n_staticfeat + n_timefeat,
        ctrl_state=n_ctrl_dynamic,
        ctrl_target=n_ctrl_static,
        ctrl_switch=n_ctrl_all,  # switch takes cat feats
        ctrl_encoder=n_ctrl_all,  # encoder takes cat feats
        timefeat=n_timefeat,
        staticfeat=n_staticfeat,
        cat_embedding=n_static_embedding,
        auxiliary=None,
    )
    config = RsglsIssmGtsExpConfig(
        experiment_name="rsgls",
        dataset_name=dataset_name,
        # NOTE(review): the flattened source is ambiguous here — n_epochs may
        # have been an active kwarg (n_epochs=50) rather than commented out;
        # confirm against the original file.
        # n_epochs=50,
        n_epochs_no_resampling=5,
        n_epochs_freeze_gls_params=1,
        n_epochs_until_validate_loss=1,
        lr=5e-3,
        weight_decay=1e-5,
        grad_clip_norm=10.0,
        num_samples_eval=100,
        batch_size_val=100,  # 10
        # gpus=tuple(range(3, 4)),
        # dtype=torch.float64,
        # architecture, prior, etc.
        state_prior_scale=1.0,
        state_prior_loc=0.0,
        make_cov_from_cholesky_avg=True,
        extract_tail_chunks_for_train=False,
        switch_link_type=SwitchLinkType.individual,
        switch_link_dims_hidden=(64,),
        switch_link_activations=nn.LeakyReLU(0.1, inplace=True),
        recurrent_link_type=SwitchLinkType.individual,
        is_recurrent=True,
        n_base_A=20,
        n_base_B=20,
        n_base_C=20,
        n_base_D=20,
        n_base_Q=20,
        n_base_R=20,
        n_base_F=20,
        n_base_S=20,
        requires_grad_R=True,
        requires_grad_Q=True,
        requires_grad_S=True,
        # obs_to_switch_encoder=True,
        # state_to_switch_encoder=False,
        switch_prior_model_dims=tuple(),
        # TODO: made assumption that this is used for ctrl_state...
        input_transform_dims=(64,) + (dims.ctrl_state,),
        switch_transition_model_dims=(64,),
        # state_to_switch_encoder_dims=(64,),
        obs_to_switch_encoder_dims=(64,),
        b_fn_dims=tuple(),
        d_fn_dims=tuple(),  # (64,),
        switch_prior_model_activations=LeakyReLU(0.1, inplace=True),
        input_transform_activations=LeakyReLU(0.1, inplace=True),
        switch_transition_model_activations=LeakyReLU(0.1, inplace=True),
        # state_to_switch_encoder_activations=LeakyReLU(0.1, inplace=True),
        obs_to_switch_encoder_activations=LeakyReLU(0.1, inplace=True),
        b_fn_activations=LeakyReLU(0.1, inplace=True),
        d_fn_activations=LeakyReLU(0.1, inplace=True),
        # initialisation
        init_scale_A=0.95,
        init_scale_B=0.0,
        init_scale_C=None,
        init_scale_D=0.0,
        init_scale_R_diag=[1e-5, 1e-1],
        init_scale_Q_diag=[1e-4, 1e0],
        init_scale_S_diag=[1e-5, 1e-1],
        # set from outside due to dependencies.
        dims=dims,
        freq=freq,
        time_feat=timefeat,
        add_trend=add_trend_map[dataset_name],
        prediction_length_rolling=prediction_length_rolling,
        prediction_length_full=prediction_length_full,
        normalisation_params=normalisation_params[dataset_name],
        LRinv_logdiag_scaling=1.0,
        LQinv_logdiag_scaling=1.0,
        A_scaling=1.0,
        B_scaling=1.0,
        C_scaling=1.0,
        D_scaling=1.0,
        LSinv_logdiag_scaling=1.0,
        F_scaling=1.0,
        eye_init_A=True,
    )
    return config
def make_default_config(dataset_name):
    """Build the default KVAE experiment configuration for ``dataset_name``.

    Looks up dataset-dependent feature counts, frequency and prediction
    lengths via ``get_n_feat_and_freq``, derives the tensor dimensions, and
    returns a fully populated ``KvaeGtsExpConfig``.
    """
    timefeat = TimeFeatType.timefeat
    (
        n_timefeat,
        n_staticfeat,
        n_latent,
        freq,
        cardinalities,
        prediction_length_rolling,
        prediction_length_full,
    ) = get_n_feat_and_freq(dataset_name=dataset_name, timefeat=timefeat)

    # Exactly one static categorical feature is assumed for these datasets.
    assert len(cardinalities["cardinalities_feat_static_cat"]) == 1
    # Embedding size heuristic: half the cardinality (rounded up), capped at 50.
    n_static_embedding = min(
        50, (cardinalities["cardinalities_feat_static_cat"][0] + 1) // 2)
    # All three control widths are currently tied to 64; the commented lines
    # below are the earlier per-source sizing that was replaced.
    n_ctrl_all = n_ctrl_static = n_ctrl_dynamic = 64
    # n_ctrl_static = n_static_embedding
    # n_ctrl_dynamic = 32
    # n_ctrl_all = n_ctrl_static + n_ctrl_dynamic  # we cat
    dims = TensorDims(
        timesteps=past_lengths[dataset_name],
        particle=1,
        batch=50,
        state=16,  # n_latent,
        target=1,
        switch=None,
        ctrl_state=n_ctrl_dynamic,
        ctrl_target=n_ctrl_static,
        ctrl_switch=n_ctrl_all,
        ctrl_encoder=None,  # KVAE uses pseudo-obs only, no controls for enc.
        timefeat=n_timefeat,
        staticfeat=n_staticfeat,
        cat_embedding=n_static_embedding,
        auxiliary=5,
    )
    # BUGFIX: the original list was ["exchange_rate_nips", "wiki2000_nips",
    # "wiki2000_nips"] -- "wiki2000_nips" appeared twice.  The duplicate is
    # dropped (membership semantics unchanged); it may have been intended to
    # name a third dataset -- TODO confirm.
    larger_batch_datasets = ("exchange_rate_nips", "wiki2000_nips")
    config = KvaeGtsExpConfig(
        experiment_name="kvae",
        dataset_name=dataset_name,
        # n_epochs=50,
        n_epochs_until_validate_loss=1,
        lr=5e-3,
        weight_decay=1e-5,
        grad_clip_norm=10.0,
        num_samples_eval=100,
        # Note: These batch sizes barely fit on the GPU.
        batch_size_val=10 if dataset_name in larger_batch_datasets else 2,
        # architecture, prior, etc.
        state_prior_scale=1.0,
        state_prior_loc=0.0,
        make_cov_from_cholesky_avg=True,
        extract_tail_chunks_for_train=False,
        switch_link_type=SwitchLinkType.shared,
        switch_link_dims_hidden=tuple(),  # linear used in KVAE LSTM -> alpha
        switch_link_activations=tuple(),  # they have 1 Dense layer after LSTM.
        recurrent_link_type=SwitchLinkType.shared,
        n_hidden_rnn=50,
        rao_blackwellized=True,
        reconstruction_weight=1.0,  # They use 0.3 w/o rao-BW.
        dims_encoder=(64, 64),
        dims_decoder=(64, 64),
        activations_encoder=LeakyReLU(0.1, inplace=True),
        activations_decoder=LeakyReLU(0.1, inplace=True),
        n_base_A=20,
        n_base_B=20,
        n_base_C=20,
        n_base_D=None,  # KVAE does not have D
        n_base_Q=20,
        n_base_R=20,
        n_base_F=None,
        n_base_S=None,
        requires_grad_R=True,
        requires_grad_Q=True,
        requires_grad_S=None,
        input_transform_dims=tuple() + (dims.ctrl_state,),
        input_transform_activations=LeakyReLU(0.1, inplace=True),
        # initialisation
        init_scale_A=0.95,
        init_scale_B=0.0,
        init_scale_C=None,
        init_scale_D=None,
        init_scale_R_diag=[1e-4, 1e-1],
        init_scale_Q_diag=[1e-4, 1e-1],
        init_scale_S_diag=None,
        # init_scale_S_diag=[1e-5, 1e0],
        # set from outside due to dependencies.
        dims=dims,
        freq=freq,
        time_feat=timefeat,
        add_trend=add_trend_map[dataset_name],
        prediction_length_rolling=prediction_length_rolling,
        prediction_length_full=prediction_length_full,
        normalisation_params=normalisation_params[dataset_name],
        LRinv_logdiag_scaling=1.0,
        LQinv_logdiag_scaling=1.0,
        A_scaling=1.0,
        B_scaling=1.0,
        C_scaling=1.0,
        D_scaling=1.0,
        LSinv_logdiag_scaling=1.0,
        F_scaling=1.0,
        eye_init_A=True,
    )
    return config
switch_prior_scale: float switch_prior_loc: float requires_grad_switch_prior: bool dims_img = (1, 32, 32) config = PymunkASGLSConfig( dataset_name=consts.Datasets.box, experiment_name="arsgls", dims_img=dims_img, dims=TensorDims( timesteps=20, particle=32, batch=34, # pytorch-magma bug with 32. state=10, target=int(np.prod(dims_img)), switch=8, auxiliary=2, ctrl_target=None, ctrl_state=None, ), # batch_size_eval=8, num_samples_eval=256, prediction_length=40, n_epochs=400, lr=2e-3, lr_decay_rate=0.85, lr_decay_steps=20, grad_clip_norm=1500.0, weight_decay=0,
b_fn_dims: tuple b_fn_activations: (nn.Module, tuple) d_fn_dims: tuple d_fn_activations: (nn.Module, tuple) n_epochs_no_resampling: int weight_decay: float grad_clip_norm: float dims = TensorDims( timesteps=50, particle=64, batch=100, state=3, target=2, switch=5, auxiliary=None, ctrl_target=None, ctrl_state=None, ) config = PendulumSGLSConfig( dataset_name=consts.Datasets.pendulum_3D_coord, experiment_name="default", batch_size_eval=1000, num_samples_eval=100, lr=1e-2, weight_decay=1e-5, grad_clip_norm=500.0, n_epochs=100, n_epochs_no_resampling=10,