Example #1
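A two-hierarchy encoder: two chained MLP stems, each followed by loc / scale_tril heads (Linear and DefaultScaleTransform) that parametrise MultivariateNormal distributions over the auxiliary and switch variables.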
    def __init__(self, config):
        self.num_hierarchies = 2

        dim_in = config.dims.target + config.dims.ctrl_encoder
        dim_out_1 = config.dims.auxiliary
        dim_out_2 = config.dims.switch

        dims_stems = config.dims_encoder
        activations_stems = config.activations_encoders
        dim_in_stem_2 = dims_stems[0][-1] if len(dims_stems[0]) > 0 else dim_in
        dim_in_dist_params_1 = (dims_stems[0][-1]
                                if len(dims_stems[0]) > 0 else dim_in)
        dim_in_dist_params_2 = (dims_stems[1][-1]
                                if len(dims_stems[1]) > 0 else dim_in_stem_2)

        super().__init__(
            allow_cat_inputs=True,
            stem=nn.ModuleList([
                MLP(
                    dim_in=dim_in,
                    dims=dims_stems[0],
                    activations=activations_stems[0],
                ),
                MLP(
                    dim_in=dim_in_stem_2,
                    dims=dims_stems[1],
                    activations=activations_stems[1],
                ),
            ]),
            dist_params=nn.ModuleList([
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_1,
                            out_features=dim_out_1,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_1,
                        dim_out_1,
                    ),
                }),
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_2,
                            out_features=dim_out_2,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_2,
                        dim_out_2,
                    ),
                }),
            ]),
            dist_cls=[MultivariateNormal, MultivariateNormal],
        )
Example #2
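A conditional Gaussian: an MLP stem plus loc / scale_tril heads, wrapped in a ParametrisedConditionalDistribution with MultivariateNormal as the distribution class.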
    def __init__(self, config):
        super().__init__()
        (
            dim_in,
            dim_out,
            dims_stem,
            activations_stem,
            dim_in_dist_params,
        ) = _extract_dims_from_cfg(config)

        self.conditional_dist = ParametrisedConditionalDistribution(
            stem=MLP(
                dim_in=dim_in, dims=dims_stem, activations=activations_stem,
            ),
            dist_params=nn.ModuleDict(
                {
                    "loc": nn.Sequential(Linear(dim_in_dist_params, dim_out),),
                    "scale_tril": DefaultScaleTransform(
                        dim_in_dist_params,
                        dim_out,
                    ),
                }
            ),
            dist_cls=MultivariateNormal,
        )
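For orientation, the following is a minimal, self-contained sketch of the stem-plus-heads pattern shared by the Gaussian examples, written in plain PyTorch. Only MLP, ParametrisedConditionalDistribution and DefaultScaleTransform above are the repository's own classes; every name below is illustrative and not part of that API.

import torch
from torch import nn
from torch.distributions import MultivariateNormal


class ConditionalGaussianSketch(nn.Module):
    """A stem encodes the input; separate heads emit loc and scale_tril."""

    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.stem = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.LeakyReLU(0.1))
        self.loc = nn.Linear(dim_hidden, dim_out)
        # Stand-in for DefaultScaleTransform: predict a positive diagonal
        # and embed it as a lower-triangular Cholesky factor.
        self.scale_diag = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        h = self.stem(x)
        scale_tril = torch.diag_embed(
            nn.functional.softplus(self.scale_diag(h)) + 1e-4)
        return MultivariateNormal(loc=self.loc(h), scale_tril=scale_tril)


# Usage sketch:
# dist = ConditionalGaussianSketch(dim_in=8, dim_hidden=32, dim_out=4)(torch.randn(16, 8))
# sample = dist.rsample()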
Example #3
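The same Gaussian pattern, but the dimensions come from _extract_dims_from_cfg_obs and are passed straight to the parent constructor with allow_cat_inputs=True.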
 def __init__(self, config):
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg_obs(config=config)
     super().__init__(
         allow_cat_inputs=True,
         stem=MLP(
             dim_in=dim_in,
             dims=dims_stem,
             activations=activations_stem,
         ),
         dist_params=nn.ModuleDict({
             "loc":
             nn.Sequential(
                 Linear(
                     in_features=dim_in_dist_params,
                     out_features=dim_out,
                 ), ),
             "scale_tril":
             DefaultScaleTransform(
                 dim_in_dist_params,
                 dim_out,
             ),
         }),
         dist_cls=MultivariateNormal,
     )
Example #4
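A one-hot categorical: if dim_in is None, an unconditional ParametrisedOneHotCategorical with learnable zero-initialised logits; otherwise a conditional OneHotCategorical with an MLP stem and a linear logits head.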
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config=config)
     if dim_in is None:
         self.dist = ParametrisedOneHotCategorical(
             logits=torch.zeros(dim_out), requires_grad=True,
         )
     else:
         self.dist = ParametrisedConditionalDistribution(
             stem=MLP(
                 dim_in=dim_in,
                 dims=dims_stem,
                 activations=activations_stem,
             ),
             dist_params=nn.ModuleDict(
                 {
                     "logits": nn.Sequential(
                         Linear(
                             in_features=dim_in_dist_params,
                             out_features=dim_out,
                         ),
                         # SigmoidLimiter(limits=[-10, 10]),
                     )
                 }
             ),
             dist_cls=OneHotCategorical,
         )
Example #5
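A conditional Gaussian whose loc is predicted by a Linear head, while the scale_tril head is a Constant module emitting a fixed zero-valued (dim_out, dim_out) matrix appended to the input's batch shape.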
    def __init__(self, config):
        super().__init__()
        (
            dim_in,
            dim_out,
            dims_stem,
            activations_stem,
            dim_in_dist_params,
        ) = _extract_dims_from_cfg(config)

        self.conditional_dist = ParametrisedConditionalDistribution(
            stem=MLP(
                dim_in=dim_in, dims=dims_stem, activations=activations_stem,
            ),
            dist_params=nn.ModuleDict(
                {
                    "loc": nn.Sequential(Linear(dim_in_dist_params, dim_out),),
                    "scale_tril": Constant(
                        val=0,
                        shp_append=(dim_out, dim_out),
                        n_dims_from_input=-1,  # x.shape[:-1]
                    ),
                }
            ),
            dist_cls=MultivariateNormal,
        )
Example #6
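A conditional Gaussian from the target dimension to the auxiliary dimension: MLP stem plus loc / scale_tril heads, passed to the parent constructor with allow_cat_inputs=True.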
    def __init__(self, config):
        dim_in = config.dims.target
        dim_out = config.dims.auxiliary
        dims_stem = config.dims_encoder
        activations_stem = config.activations_decoder
        dim_in_dist_params = dims_stem[-1] if len(dims_stem) > 0 else dim_in

        super().__init__(
            allow_cat_inputs=True,
            stem=MLP(
                dim_in=dim_in,
                dims=dims_stem,
                activations=activations_stem,
            ),
            dist_params=nn.ModuleDict({
                "loc":
                nn.Sequential(
                    Linear(
                        in_features=dim_in_dist_params,
                        out_features=dim_out,
                    ), ),
                "scale_tril":
                DefaultScaleTransform(
                    dim_in_dist_params,
                    dim_out,
                ),
            }),
            dist_cls=MultivariateNormal,
        )
Example #7
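A conditional OneHotCategorical whose logits head is a Linear layer followed by a SigmoidLimiter that bounds the logits to [-5, 5].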
 def __init__(self, config):
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg_obs(config=config)
     super().__init__(
         allow_cat_inputs=True,
         stem=MLP(
             dim_in=dim_in,
             dims=dims_stem,
             activations=activations_stem,
         ),
         dist_params=nn.ModuleDict({
             "logits":
             nn.Sequential(
                 Linear(
                     in_features=dim_in_dist_params,
                     out_features=dim_out,
                 ),
                 SigmoidLimiter(limits=[-5, 5]),
             )
         }),
         dist_cls=OneHotCategorical,
     )
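The categorical counterpart, again as a rough plain-PyTorch sketch rather than the repository's classes; the bounding step assumes SigmoidLimiter squashes its input into the given range, which the snippets suggest but do not show.

import torch
from torch import nn
from torch.distributions import OneHotCategorical


class ConditionalCategoricalSketch(nn.Module):
    def __init__(self, dim_in, dim_hidden, n_classes, limit=5.0):
        super().__init__()
        self.stem = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.LeakyReLU(0.1))
        self.logits = nn.Linear(dim_hidden, n_classes)
        self.limit = limit  # stand-in for SigmoidLimiter(limits=[-5, 5])

    def forward(self, x):
        logits = self.logits(self.stem(x))
        # Squash the logits into (-limit, limit) to keep them bounded.
        logits = 2.0 * self.limit * (torch.sigmoid(logits) - 0.5)
        return OneHotCategorical(logits=logits)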
Example #8
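A simple input transform: one MLP over the concatenated time and static features.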
 def __init__(self, config):
     super().__init__()
     self.mlp = MLP(
         dim_in=config.dims.timefeat + config.dims.staticfeat,
         dims=config.input_transform_dims,
         activations=config.input_transform_activations,
     )
Example #9
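Passes the state-space model configuration to the parent constructor: state / (pseudo-)observation / switch / control dimensions, the numbers of base matrices A, B, C, D, Q, R, switch-link settings, optional b(s) and d(s) offset MLPs, and the initialisation and scaling options.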
 def __init__(self, config):
     super().__init__(
         n_state=config.dims.state,
         n_obs=config.dims.auxiliary,
         # NOTE: SSM (pseudo) obs are Model auxiliary.
         n_switch=config.dims.switch,
         n_ctrl_state=config.dims.ctrl_state,
         n_ctrl_obs=config.dims.ctrl_target,
         n_base_A=config.n_base_A,
         n_base_B=config.n_base_B,
         n_base_C=config.n_base_C,
         n_base_D=config.n_base_D,
         n_base_Q=config.n_base_Q,
         n_base_R=config.n_base_R,
         switch_link_type=config.switch_link_type,
         switch_link_dims_hidden=config.switch_link_dims_hidden,
         switch_link_activations=config.switch_link_activations,
         b_fn=MLP(
             dim_in=config.dims.switch,  # f: b(s)
             dims=tuple(config.b_fn_dims) + (config.dims.state, ),
             activations=config.b_fn_activations,
         ) if config.b_fn_dims else None,
         d_fn=MLP(
             dim_in=config.dims.switch,  # f: d(s)
             dims=tuple(config.d_fn_dims) + (config.dims.target, ),
             activations=config.d_fn_activations,
         ) if config.d_fn_dims else None,
         init_scale_A=config.init_scale_A,
         init_scale_B=config.init_scale_B,
         init_scale_C=config.init_scale_C,
         init_scale_D=config.init_scale_D,
         init_scale_R_diag=config.init_scale_R_diag,
         init_scale_Q_diag=config.init_scale_Q_diag,
         full_cov_R=False,
         full_cov_Q=False,
         requires_grad_R=config.requires_grad_R,
         requires_grad_Q=config.requires_grad_Q,
         LRinv_logdiag_scaling=config.LRinv_logdiag_scaling,
         LQinv_logdiag_scaling=config.LQinv_logdiag_scaling,
         A_scaling=config.A_scaling,
         B_scaling=config.B_scaling,
         C_scaling=config.C_scaling,
         D_scaling=config.D_scaling,
         eye_init_A=config.eye_init_A,
     )
Example #10
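An input transform that embeds the static categorical features and feeds the embedding, concatenated with the time features, into an MLP.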
 def __init__(self, config):
     super().__init__()
     self.embedding = nn.Embedding(
         num_embeddings=config.dims.staticfeat,
         embedding_dim=config.dims.cat_embedding,
     )
     self.mlp = MLP(
         dim_in=config.dims.cat_embedding + config.dims.timefeat,
         dims=config.input_transform_dims,
         activations=config.input_transform_activations,
     )
Example #11
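Builds the parent model around a CompositeISSM constructed from the data frequency (optionally with trend), together with base matrices B, D, Q, R, switch-link settings, and optional b(s) / d(s) offset MLPs; R and Q are hard-coded as trainable here.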
 def __init__(self, config):
     super().__init__(
         issm=CompositeISSM.get_from_freq(freq=config.freq,
                                          add_trend=config.add_trend),
         n_state=config.dims.state,
         n_obs=config.dims.target,
         n_ctrl_state=config.dims.ctrl_state,
         n_ctrl_obs=config.dims.ctrl_target,
         n_switch=config.dims.switch,
         n_base_B=config.n_base_B,
         n_base_D=config.n_base_D,
         n_base_R=config.n_base_R,
         n_base_Q=config.n_base_Q,
         switch_link_type=config.switch_link_type,
         switch_link_dims_hidden=config.switch_link_dims_hidden,
         switch_link_activations=config.switch_link_activations,
         make_cov_from_cholesky_avg=config.make_cov_from_cholesky_avg,
         b_fn=MLP(
             dim_in=config.dims.switch,  # f: b(s)
             dims=tuple(config.b_fn_dims) + (config.dims.state, ),
             activations=config.b_fn_activations,
         ) if config.b_fn_dims else None,
         d_fn=MLP(
             dim_in=config.dims.switch,  # f: d(s)
             dims=tuple(config.d_fn_dims) + (config.dims.target, ),
             activations=config.d_fn_activations,
         ) if config.d_fn_dims else None,
         init_scale_R_diag=config.init_scale_R_diag,
         init_scale_Q_diag=config.init_scale_Q_diag,
         init_scale_A=config.init_scale_A,
         requires_grad_R=True,
         requires_grad_Q=True,
         LRinv_logdiag_scaling=config.LRinv_logdiag_scaling,
         LQinv_logdiag_scaling=config.LQinv_logdiag_scaling,
         B_scaling=config.B_scaling,
         D_scaling=config.D_scaling,
         eye_init_A=config.eye_init_A,
     )
Example #12
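An input transform with a static-feature embedding (whose dimension must equal ctrl_target, as the assert enforces) and a separate MLP for the dynamic time features.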
    def __init__(self, config):
        super().__init__()
        # static_transform_dims = config.input_transform_dims
        # static_transform_activations = config.input_transform_activations
        dynamic_transform_dims = config.input_transform_dims
        dynamic_transform_activations = config.input_transform_activations

        assert config.dims.cat_embedding == config.dims.ctrl_target

        self.embedding_static = nn.Embedding(
            num_embeddings=config.dims.staticfeat,
            embedding_dim=config.dims.cat_embedding,
        )
        self.mlp_dynamic = MLP(
            dim_in=config.dims.timefeat,
            dims=dynamic_transform_dims,
            activations=dynamic_transform_activations,
        )
Example #13
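A switch-link module: an MLP terminated by a Softmax that maps its input to normalised weights, with the output names stored on the module.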
 def __init__(
     self,
     dim_in,
     dim_out,
     names: (list, tuple),
     dims_hidden: Tuple[int] = tuple(),
     activations_hidden: nn.Module = nn.LeakyReLU(0.1, inplace=True),
     norm_type: Optional[NormalizationType] = NormalizationType.none,
 ):
     super().__init__()
     if isinstance(activations_hidden, nn.Module):
         activations_hidden = (activations_hidden, ) * len(dims_hidden)
     self.names = names
     self.link = MLP(
         dim_in=dim_in,
         dims=dims_hidden + (dim_out, ),
         activations=activations_hidden + (nn.Softmax(dim=-1), ),
         norm_type=norm_type,
     )
Example #14
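A switch prior: in the unconditional case, a ParametrisedMultivariateNormal whose mean and isotropic covariance come from switch_prior_loc and switch_prior_scale (via make_inv_tril_parametrization); otherwise the usual MLP stem with loc / scale_tril heads.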
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config=config)
     if dim_in is None:
         covariance_matrix = (
             (config.switch_prior_scale ** 2) or 1.0
         ) * torch.eye(dim_out)
         LVinv_tril, LVinv_logdiag = make_inv_tril_parametrization(
             covariance_matrix
         )
         self.dist = ParametrisedMultivariateNormal(
             m=torch.ones(config.dims.switch) * config.switch_prior_loc,
             LVinv_tril=LVinv_tril,
             LVinv_logdiag=LVinv_logdiag,
             requires_grad_m=config.requires_grad_switch_prior,
             requires_diag_LVinv_tril=False,
             requires_diag_LVinv_logdiag=config.requires_grad_switch_prior,
         )
     else:
         self.dist = ParametrisedConditionalDistribution(
             stem=MLP(
                 dim_in=dim_in,
                 dims=dims_stem,
                 activations=activations_stem,
             ),
             dist_params=nn.ModuleDict(
                 {
                     "loc": nn.Sequential(
                         Linear(dim_in_dist_params, dim_out),
                     ),
                     "scale_tril": DefaultScaleTransform(
                         dim_in_dist_params, dim_out,
                     ),
                 }
             ),
             dist_cls=MultivariateNormal,
         )
Example #15
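A collection of switch links: one Softmax-terminated MLP per entry of names_and_dims_out, registered under its name via self.update (the class is presumably an nn.ModuleDict subclass).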
 def __init__(
     self,
     dim_in,
     names_and_dims_out: Dict[str, int],
     dims_hidden: Tuple[int] = tuple(),
     activations_hidden: nn.Module = nn.LeakyReLU(0.1, inplace=True),
     norm_type: Optional[NormalizationType] = NormalizationType.none,
 ):
     super().__init__()
     if isinstance(activations_hidden, nn.Module):
         activations_hidden = (activations_hidden, ) * len(dims_hidden)
     for name, dim_out in names_and_dims_out.items():
         self.update({
             name:
             MLP(
                 dim_in=dim_in,
                 dims=dims_hidden + (dim_out, ),
                 activations=activations_hidden + (nn.Softmax(dim=-1), ),
                 norm_type=norm_type,
             )
         })
Example #16
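Another conditional OneHotCategorical, here with the logits bounded to [-10, 10] by SigmoidLimiter; dim_in_dist_params is recomputed from dims_stem right after being extracted from the config.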
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config)
     dim_in_dist_params = dims_stem[-1] if len(dims_stem) > 0 else dim_in
     self.conditional_dist = ParametrisedConditionalDistribution(
         stem=MLP(
             dim_in=dim_in, dims=dims_stem, activations=activations_stem,
         ),
         dist_params=nn.ModuleDict(
             {
                 "logits": nn.Sequential(
                     Linear(dim_in_dist_params, dim_out),
                     SigmoidLimiter(limits=[-10, 10]),
                 )
             }
         ),
         dist_cls=OneHotCategorical,
     )
Example #17
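An image-input variant of Example #1: the first stem is a three-layer CNN (Reshape, three Conv2d/ReLU blocks, flatten), the second an MLP on the flattened convolutional features; each level has loc / scale_tril heads parametrising MultivariateNormal distributions over the auxiliary and switch variables.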
    def __init__(self, config):
        self.num_hierarchies = 2
        if config.dims.ctrl_encoder not in [None, 0]:
            raise ValueError(
                "no controls. would require different architecture "
                "or mixing with images.")
        shp_enc_out, dim_out_flat_conv = compute_cnn_output_filters_and_dims(
            dims_img=config.dims_img,
            dims_filter=config.dims_filter,
            kernel_sizes=config.kernel_sizes,
            strides=config.strides,
            paddings=config.paddings,
        )

        assert config.dims_encoder[0] is None, (
            "first stem is a conv net. "
            "config is given differently...")
        dims_stem_2 = (  # TODO: really past self?
            32,
            32,
        )
        activations_stem_2 = nn.ReLU()
        dim_out_1 = config.dims.auxiliary
        dim_out_2 = config.dims.switch
        dim_in_dist_params_1 = dim_out_flat_conv
        dim_in_dist_params_2 = (dims_stem_2[-1]
                                if len(dims_stem_2) > 0 else dim_out_flat_conv)

        super().__init__(
            allow_cat_inputs=False,  # images and scalar...
            stem=nn.ModuleList([
                nn.Sequential(
                    Reshape(
                        config.dims_img),  # TxPxB will be flattened before.
                    Conv2d(
                        in_channels=config.dims_img[0],
                        out_channels=config.dims_filter[0],
                        kernel_size=config.kernel_sizes[0],
                        stride=config.strides[0],
                        padding=config.paddings[0],
                    ),
                    nn.ReLU(),
                    Conv2d(
                        in_channels=config.dims_filter[0],
                        out_channels=config.dims_filter[1],
                        kernel_size=config.kernel_sizes[1],
                        stride=config.strides[1],
                        padding=config.paddings[1],
                    ),
                    nn.ReLU(),
                    Conv2d(
                        in_channels=config.dims_filter[1],
                        out_channels=config.dims_filter[2],
                        kernel_size=config.kernel_sizes[2],
                        stride=config.strides[2],
                        padding=config.paddings[2],
                    ),
                    nn.ReLU(),
                    Reshape((dim_out_flat_conv, )),  # Flatten image dims
                ),
                MLP(
                    dim_in=dim_out_flat_conv,
                    dims=dims_stem_2,
                    activations=activations_stem_2,
                ),
            ]),
            dist_params=nn.ModuleList([
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_1,
                            out_features=dim_out_1,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_1,
                        dim_out_1,
                    ),
                }),
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_2,
                            out_features=dim_out_2,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_2,
                        dim_out_2,
                    ),
                }),
            ]),
            dist_cls=[MultivariateNormal, MultivariateNormal],
        )