Example 1
    def __init__(self, config):
        self.num_hierarchies = 2

        dim_in = config.dims.target + config.dims.ctrl_encoder
        dim_out_1 = config.dims.auxiliary
        dim_out_2 = config.dims.switch

        dims_stems = config.dims_encoder
        activations_stems = config.activations_encoders
        dim_in_stem_2 = dims_stems[0][-1] if len(dims_stems[0]) > 0 else dim_in
        dim_in_dist_params_1 = (dims_stems[0][-1]
                                if len(dims_stems[0]) > 0 else dim_in)
        dim_in_dist_params_2 = (dims_stems[1][-1]
                                if len(dims_stems[1]) > 0 else dim_in_stem_2)

        super().__init__(
            allow_cat_inputs=True,
            stem=nn.ModuleList([
                MLP(
                    dim_in=dim_in,
                    dims=dims_stems[0],
                    activations=activations_stems[0],
                ),
                MLP(
                    dim_in=dim_in_stem_2,
                    dims=dims_stems[1],
                    activations=activations_stems[1],
                ),
            ]),
            dist_params=nn.ModuleList([
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_1,
                            out_features=dim_out_1,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_1,
                        dim_out_1,
                    ),
                }),
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_2,
                            out_features=dim_out_2,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_2,
                        dim_out_2,
                    ),
                }),
            ]),
            dist_cls=[MultivariateNormal, MultivariateNormal],
        )
Example 2
    def __init__(self, config):
        super().__init__()
        (
            dim_in,
            dim_out,
            dims_stem,
            activations_stem,
            dim_in_dist_params,
        ) = _extract_dims_from_cfg(config)

        self.conditional_dist = ParametrisedConditionalDistribution(
            stem=MLP(
                dim_in=dim_in, dims=dims_stem, activations=activations_stem,
            ),
            dist_params=nn.ModuleDict(
                {
                    "loc": nn.Sequential(Linear(dim_in_dist_params, dim_out),),
                    "scale_tril": DefaultScaleTransform(
                        dim_in_dist_params,
                        dim_out,
                    ),
                }
            ),
            dist_cls=MultivariateNormal,
        )
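
Examples 2, 3, 8, 10, and 11 all unpack the same five values from _extract_dims_from_cfg, and Examples 4 and 5 use an _extract_dims_from_cfg_obs variant with the same return signature. The helper itself is not shown here; a minimal sketch of what it plausibly does, mirroring the inline pattern of Example 7 (the exact config field names are an assumption), is:

# Hypothetical sketch only; field names mirror Example 7 and may differ in the
# real helper.
def _extract_dims_from_cfg(config):
    dim_in = config.dims.target
    dim_out = config.dims.auxiliary
    dims_stem = config.dims_encoder
    activations_stem = config.activations_decoder
    # If the stem MLP is empty, the distribution-parameter heads see the raw input.
    dim_in_dist_params = dims_stem[-1] if len(dims_stem) > 0 else dim_in
    return dim_in, dim_out, dims_stem, activations_stem, dim_in_dist_params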
Example 3
    def __init__(self, config):
        super().__init__()
        (
            dim_in,
            dim_out,
            dims_stem,
            activations_stem,
            dim_in_dist_params,
        ) = _extract_dims_from_cfg(config)

        self.conditional_dist = ParametrisedConditionalDistribution(
            stem=MLP(
                dim_in=dim_in, dims=dims_stem, activations=activations_stem,
            ),
            dist_params=nn.ModuleDict(
                {
                    "loc": nn.Sequential(Linear(dim_in_dist_params, dim_out),),
                    "scale_tril": Constant(
                        val=0,
                        shp_append=(dim_out, dim_out),
                        n_dims_from_input=-1,  # x.shape[:-1]
                    ),
                }
            ),
            dist_cls=MultivariateNormal,
        )
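
Here the scale_tril head is a Constant module rather than a learned transform, so the Cholesky factor handed to MultivariateNormal is a constant all-zero (dim_out, dim_out) matrix broadcast over the input's leading batch dimensions (x.shape[:-1], per the inline comment). Constant is not defined in these examples; a hypothetical sketch consistent with the call above:

from torch import nn

# Hypothetical sketch of Constant as used in Example 3: ignore the input values
# and return a constant tensor shaped batch_dims + shp_append.
class Constant(nn.Module):
    def __init__(self, val, shp_append, n_dims_from_input=-1):
        super().__init__()
        self.val = val
        self.shp_append = tuple(shp_append)
        self.n_dims_from_input = n_dims_from_input

    def forward(self, x):
        batch_shape = tuple(x.shape[:self.n_dims_from_input])
        return x.new_full(batch_shape + self.shp_append, self.val)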
Example 4
 def __init__(self, config):
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg_obs(config=config)
     super().__init__(
         allow_cat_inputs=True,
         stem=MLP(
             dim_in=dim_in,
             dims=dims_stem,
             activations=activations_stem,
         ),
         dist_params=nn.ModuleDict({
             "loc":
             nn.Sequential(
                 Linear(
                     in_features=dim_in_dist_params,
                     out_features=dim_out,
                 ), ),
             "scale_tril":
             DefaultScaleTransform(
                 dim_in_dist_params,
                 dim_out,
             ),
         }),
         dist_cls=MultivariateNormal,
     )
Example 5
 def __init__(self, config):
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg_obs(config=config)
     super().__init__(
         allow_cat_inputs=True,
         stem=MLP(
             dim_in=dim_in,
             dims=dims_stem,
             activations=activations_stem,
         ),
         dist_params=nn.ModuleDict({
             "logits":
             nn.Sequential(
                 Linear(
                     in_features=dim_in_dist_params,
                     out_features=dim_out,
                 ),
                 SigmoidLimiter(limits=[-5, 5]),
             )
         }),
         dist_cls=OneHotCategorical,
     )
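
SigmoidLimiter(limits=[-5, 5]) bounds the raw logits before they parametrise the OneHotCategorical (Example 11 does the same with limits=[-10, 10]). Its implementation is not shown; a hypothetical sketch consistent with this usage, assuming a sigmoid rescaled to the given interval, would be:

import torch
from torch import nn

# Hypothetical sketch of SigmoidLimiter: map unbounded inputs smoothly into
# the open interval (limits[0], limits[1]).
class SigmoidLimiter(nn.Module):
    def __init__(self, limits):
        super().__init__()
        self.lo, self.hi = limits

    def forward(self, x):
        return self.lo + (self.hi - self.lo) * torch.sigmoid(x)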
Example 6
    def __init__(
        self,
        dim_in: int,
        dims: Tuple[int, ...],
        activations: Union[Tuple[nn.Module, ...], nn.Module, None],
        norm_type: Optional[NormalizationType] = None,
        # norm_type: Optional[NormalizationType] = NormalizationType.layer_learnable,
    ):
        super().__init__()
        assert isinstance(dims, (tuple, list))

        if not isinstance(activations, (tuple, list)):
            activations = tuple(activations for _ in range(len(dims)))

        dims_in = (dim_in, ) + tuple(dims[:-1])
        dims_out = dims
        for l, (n_in, n_out,
                activation) in enumerate(zip(dims_in, dims_out, activations)):
            self.add_module(name=f"linear_{l}", module=Linear(n_in, n_out))
            norm_layer = make_norm_layer(
                norm_type=norm_type,
                shp_features=[n_out],  # after Dense -> out
            )
            if norm_layer is not None:
                self.add_module(name=f"norm_{norm_type.name}_{l}",
                                module=norm_layer)
            if activation is not None:
                self.add_module(name=f"activation_{l}", module=activation)
Example 7
    def __init__(self, config):
        dim_in = config.dims.target
        dim_out = config.dims.auxiliary
        dims_stem = config.dims_encoder
        activations_stem = config.activations_decoder
        dim_in_dist_params = dims_stem[-1] if len(dims_stem) > 0 else dim_in

        super().__init__(
            allow_cat_inputs=True,
            stem=MLP(
                dim_in=dim_in,
                dims=dims_stem,
                activations=activations_stem,
            ),
            dist_params=nn.ModuleDict({
                "loc":
                nn.Sequential(
                    Linear(
                        in_features=dim_in_dist_params,
                        out_features=dim_out,
                    ), ),
                "scale_tril":
                DefaultScaleTransform(
                    dim_in_dist_params,
                    dim_out,
                ),
            }),
            dist_cls=MultivariateNormal,
        )
Example 8
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config=config)
     if dim_in is None:
         self.dist = ParametrisedOneHotCategorical(
             logits=torch.zeros(dim_out), requires_grad=True,
         )
     else:
         self.dist = ParametrisedConditionalDistribution(
             stem=MLP(
                 dim_in=dim_in,
                 dims=dims_stem,
                 activations=activations_stem,
             ),
             dist_params=nn.ModuleDict(
                 {
                     "logits": nn.Sequential(
                         Linear(
                             in_features=dim_in_dist_params,
                             out_features=dim_out,
                         ),
                         # SigmoidLimiter(limits=[-10, 10]),
                     )
                 }
             ),
             dist_cls=OneHotCategorical,
         )
Example 9
 def __init__(self, config):
     shp_enc_out, dim_out_flat_conv = compute_cnn_output_filters_and_dims(
         dims_img=config.dims_img,
         dims_filter=config.dims_conv,
         kernel_sizes=config.kernel_sizes_conv,
          strides=config.strides_conv,
         paddings=config.paddings_conv,
     )
     super().__init__(
         stem=nn.Sequential(
             Reshape(config.dims_img),  # TxPxB will be flattened before.
             Conv2d(
                 in_channels=config.dims_img[0],
                 out_channels=config.dims_conv[0],
                 kernel_size=config.kernel_sizes_conv[0],
                 stride=config.strides_conv[0],
                 padding=config.paddings_conv[0],
             ),
             nn.ReLU(),
             Conv2d(
                 in_channels=config.dims_conv[0],
                 out_channels=config.dims_conv[1],
                 kernel_size=config.kernel_sizes_conv[1],
                 stride=config.strides_conv[1],
                 padding=config.paddings_conv[1],
             ),
             nn.ReLU(),
             Conv2d(
                 in_channels=config.dims_conv[1],
                 out_channels=config.dims_conv[2],
                 kernel_size=config.kernel_sizes_conv[2],
                 stride=config.strides_conv[2],
                 padding=config.paddings_conv[2],
             ),
             nn.ReLU(),
             Reshape((dim_out_flat_conv, )),  # Flatten image dims
         ),
         dist_params=nn.ModuleDict({
             "logits":
             Linear(
                 in_features=dim_out_flat_conv,
                 out_features=config.dims.switch,
             ),
         }),
         dist_cls=OneHotCategorical,
     )
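
compute_cnn_output_filters_and_dims is not shown, but its role is clear from the call sites: given the image shape and the per-layer filter counts, kernel sizes, strides, and paddings, it returns the convolutional output shape and its flattened size, which Reshape((dim_out_flat_conv,)) and the final Linear rely on. A hypothetical sketch using standard convolution arithmetic (out = floor((in + 2 * padding - kernel) / stride) + 1), assuming scalar kernel sizes, strides, and paddings per layer:

# Hypothetical sketch only; the real helper may treat non-square kernels or the
# asymmetric ZeroPad2d padding of Example 13 differently.
def compute_cnn_output_filters_and_dims(dims_img, dims_filter, kernel_sizes,
                                        strides, paddings):
    _, height, width = dims_img  # (channels, H, W)
    for k, s, p in zip(kernel_sizes, strides, paddings):
        height = (height + 2 * p - k) // s + 1
        width = (width + 2 * p - k) // s + 1
    shp_enc_out = (dims_filter[-1], height, width)
    return shp_enc_out, dims_filter[-1] * height * width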
Example 10
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config=config)
     if dim_in is None:
         covariance_matrix = (
             (config.switch_prior_scale ** 2) or 1.0
         ) * torch.eye(dim_out)
         LVinv_tril, LVinv_logdiag = make_inv_tril_parametrization(
             covariance_matrix
         )
         self.dist = ParametrisedMultivariateNormal(
             m=torch.ones(config.dims.switch) * config.switch_prior_loc,
             LVinv_tril=LVinv_tril,
             LVinv_logdiag=LVinv_logdiag,
             requires_grad_m=config.requires_grad_switch_prior,
             requires_diag_LVinv_tril=False,
             requires_diag_LVinv_logdiag=config.requires_grad_switch_prior,
         )
     else:
         self.dist = ParametrisedConditionalDistribution(
             stem=MLP(
                 dim_in=dim_in,
                 dims=dims_stem,
                 activations=activations_stem,
             ),
             dist_params=nn.ModuleDict(
                 {
                     "loc": nn.Sequential(
                         Linear(dim_in_dist_params, dim_out),
                     ),
                     "scale_tril": DefaultScaleTransform(
                         dim_in_dist_params, dim_out,
                     ),
                 }
             ),
             dist_cls=MultivariateNormal,
         )
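
When dim_in is None, the switch prior is an unconditional Gaussian whose covariance is handled through the Cholesky factor of the precision matrix: make_inv_tril_parametrization apparently splits that factor into a strictly lower-triangular part and the log of its diagonal, a common way to keep the diagonal positive under gradient updates. A hypothetical sketch consistent with this usage:

import torch

# Hypothetical sketch: Cholesky-factorise the precision (inverse covariance)
# and store it as strict lower triangle plus log-diagonal.
def make_inv_tril_parametrization(covariance_matrix):
    precision = torch.inverse(covariance_matrix)
    L = torch.linalg.cholesky(precision)          # precision = L @ L.T
    LVinv_tril = torch.tril(L, diagonal=-1)       # strictly lower-triangular part
    LVinv_logdiag = torch.log(torch.diagonal(L))  # log of the positive diagonal
    return LVinv_tril, LVinv_logdiag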
Example 11
 def __init__(self, config):
     super().__init__()
     (
         dim_in,
         dim_out,
         dims_stem,
         activations_stem,
         dim_in_dist_params,
     ) = _extract_dims_from_cfg(config)
     dim_in_dist_params = dims_stem[-1] if len(dims_stem) > 0 else dim_in
     self.conditional_dist = ParametrisedConditionalDistribution(
         stem=MLP(
             dim_in=dim_in, dims=dims_stem, activations=activations_stem,
         ),
         dist_params=nn.ModuleDict(
             {
                 "logits": nn.Sequential(
                     Linear(dim_in_dist_params, dim_out),
                     SigmoidLimiter(limits=[-10, 10]),
                 )
             }
         ),
         dist_cls=OneHotCategorical,
     )
Example 12
    def __init__(self, config):
        self.num_hierarchies = 2
        if config.dims.ctrl_encoder not in [None, 0]:
            raise ValueError(
                "Control inputs are not supported by this encoder; using "
                "them would require a different architecture or mixing "
                "them with the images.")
        shp_enc_out, dim_out_flat_conv = compute_cnn_output_filters_and_dims(
            dims_img=config.dims_img,
            dims_filter=config.dims_filter,
            kernel_sizes=config.kernel_sizes,
            strides=config.strides,
            paddings=config.paddings,
        )

        assert config.dims_encoder[0] is None, (
            "first stem is a conv net. "
            "config is given differently...")
        dims_stem_2 = (  # TODO: hard-coded stem widths; make configurable?
            32,
            32,
        )
        activations_stem_2 = nn.ReLU()
        dim_out_1 = config.dims.auxiliary
        dim_out_2 = config.dims.switch
        dim_in_dist_params_1 = dim_out_flat_conv
        dim_in_dist_params_2 = (dims_stem_2[-1]
                                if len(dims_stem_2) > 0 else dim_out_flat_conv)

        super().__init__(
            allow_cat_inputs=False,  # images and scalar...
            stem=nn.ModuleList([
                nn.Sequential(
                    Reshape(
                        config.dims_img),  # TxPxB will be flattened before.
                    Conv2d(
                        in_channels=config.dims_img[0],
                        out_channels=config.dims_filter[0],
                        kernel_size=config.kernel_sizes[0],
                        stride=config.strides[0],
                        padding=config.paddings[0],
                    ),
                    nn.ReLU(),
                    Conv2d(
                        in_channels=config.dims_filter[0],
                        out_channels=config.dims_filter[1],
                        kernel_size=config.kernel_sizes[1],
                        stride=config.strides[1],
                        padding=config.paddings[1],
                    ),
                    nn.ReLU(),
                    Conv2d(
                        in_channels=config.dims_filter[1],
                        out_channels=config.dims_filter[2],
                        kernel_size=config.kernel_sizes[2],
                        stride=config.strides[2],
                        padding=config.paddings[2],
                    ),
                    nn.ReLU(),
                    Reshape((dim_out_flat_conv, )),  # Flatten image dims
                ),
                MLP(
                    dim_in=dim_out_flat_conv,
                    dims=dims_stem_2,
                    activations=activations_stem_2,
                ),
            ]),
            dist_params=nn.ModuleList([
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_1,
                            out_features=dim_out_1,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_1,
                        dim_out_1,
                    ),
                }),
                nn.ModuleDict({
                    "loc":
                    nn.Sequential(
                        Linear(
                            in_features=dim_in_dist_params_2,
                            out_features=dim_out_2,
                        ), ),
                    "scale_tril":
                    DefaultScaleTransform(
                        dim_in_dist_params_2,
                        dim_out_2,
                    ),
                }),
            ]),
            dist_cls=[MultivariateNormal, MultivariateNormal],
        )
Example 13
 def __init__(self, config):
     shp_enc_out, dim_out_flat_conv = compute_cnn_output_filters_and_dims(
         dims_img=config.dims_img,
         dims_filter=config.dims_filter,
         kernel_sizes=config.kernel_sizes,
         strides=config.strides,
         paddings=config.paddings,
     )
     if not config.requires_grad_Q and isinstance(config.init_scale_Q_diag,
                                                  float):
         fixed_max_scale = True
     elif config.requires_grad_Q and not isinstance(
             config.init_scale_Q_diag, float):
         fixed_max_scale = False
     else:
         raise ValueError("unclear what encoder scale rectifier to use.")
     super().__init__(
         stem=nn.Sequential(
             Reshape(config.dims_img),  # TxPxB will be flattened before.
             nn.ZeroPad2d(padding=[0, 1, 0, 1]),
             Conv2d(
                 in_channels=config.dims_img[0],
                 out_channels=config.dims_filter[0],
                 kernel_size=config.kernel_sizes[0],
                 stride=config.strides[0],
                 padding=0,
             ),
             nn.ReLU(),
             nn.ZeroPad2d(padding=[0, 1, 0, 1]),
             Conv2d(
                 in_channels=config.dims_filter[0],
                 out_channels=config.dims_filter[1],
                 kernel_size=config.kernel_sizes[1],
                 stride=config.strides[1],
                 padding=0,
             ),
             nn.ReLU(),
             nn.ZeroPad2d(padding=[0, 1, 0, 1]),
             Conv2d(
                 in_channels=config.dims_filter[1],
                 out_channels=config.dims_filter[2],
                 kernel_size=config.kernel_sizes[2],
                 stride=config.strides[2],
                 padding=0,
             ),
             nn.ReLU(),
             Reshape((dim_out_flat_conv, )),  # Flatten image dims
         ),
         dist_params=nn.ModuleDict({
             "loc":
             Linear(
                 in_features=dim_out_flat_conv,
                 out_features=config.dims.auxiliary,
             ),
             "scale":
             nn.Sequential(
                 Linear(
                     in_features=dim_out_flat_conv,
                     out_features=config.dims.auxiliary,
                 ),
                 ScaledSqrtSigmoid(max_scale=config.init_scale_Q_diag),
             ) if fixed_max_scale else DefaultScaleTransform(
                 dim_out_flat_conv,
                 config.dims.auxiliary,
                 make_diag_cov_matrix=False,
             ),
         }),
         dist_cls=IndependentNormal,
     )
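
Unlike Example 9, this encoder pads each feature map explicitly with nn.ZeroPad2d(padding=[0, 1, 0, 1]) and runs the convolutions with padding=0. The argument order is (left, right, top, bottom), so one pixel of zeros is added on the right and bottom only, reminiscent of TensorFlow-style "SAME" padding. A small shape check:

import torch
from torch import nn

x = torch.randn(1, 1, 32, 32)
y = nn.ZeroPad2d(padding=[0, 1, 0, 1])(x)  # pads right and bottom by one pixel
print(y.shape)  # torch.Size([1, 1, 33, 33])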
Example 14
 def __init__(self, config):
     shp_enc_out, dim_out_flat_conv = compute_cnn_output_filters_and_dims(
         dims_img=config.dims_img,
         dims_filter=config.dims_filter,
         kernel_sizes=config.kernel_sizes,
         strides=config.strides,
         paddings=config.paddings,
     )
     super().__init__(
         stem=nn.Sequential(
             Linear(
                 in_features=config.dims.auxiliary,
                 out_features=int(np.prod(shp_enc_out)),
             ),
             Reshape(shp_enc_out),  # TxPxB will be flattened before.
             Conv2d(
                 in_channels=shp_enc_out[0],
                 out_channels=config.dims_filter[-1] *
                 config.upscale_factor**2,
                 kernel_size=config.kernel_sizes[-1],
                 stride=1,  # Pixelshuffle instead.
                 padding=config.paddings[-1],
             ),
             nn.PixelShuffle(upscale_factor=config.upscale_factor),
             nn.ReLU(),
             Conv2d(
                 in_channels=config.dims_filter[-1],
                 out_channels=config.dims_filter[-2] *
                 config.upscale_factor**2,
                 kernel_size=config.kernel_sizes[-2],
                 stride=1,  # Pixelshuffle instead.
                 padding=config.paddings[-2],
             ),
             nn.PixelShuffle(upscale_factor=config.upscale_factor),
             nn.ReLU(),
             Conv2d(
                 in_channels=config.dims_filter[-2],
                 out_channels=config.dims_filter[-3] *
                 config.upscale_factor**2,
                 kernel_size=config.kernel_sizes[-3],
                 stride=1,  # Pixelshuffle instead.
                 padding=config.paddings[-3],
             ),
             nn.PixelShuffle(upscale_factor=config.upscale_factor),
             nn.ReLU(),
         ),
         dist_params=nn.ModuleDict({
             "logits":
             nn.Sequential(
                 Conv2d(
                     in_channels=config.dims_filter[-3],
                     out_channels=1,
                     kernel_size=1,
                     stride=1,
                     padding=0,
                 ),
                 Reshape((config.dims.target, )),
             )
         }),
         dist_cls=IndependentBernoulli,
     )
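
In this decoder each block produces dims_filter[i] * upscale_factor**2 channels so that nn.PixelShuffle(upscale_factor) can rearrange them into dims_filter[i] channels at upscale_factor times the spatial resolution; after three such blocks, a 1x1 convolution maps to a single-channel logit image that is flattened to config.dims.target for the IndependentBernoulli likelihood. The channel arithmetic, with illustrative numbers:

import torch
from torch import nn

x = torch.randn(1, 32 * 2 ** 2, 8, 8)      # e.g. 32 target channels, upscale_factor = 2
y = nn.PixelShuffle(upscale_factor=2)(x)   # (C * r**2, H, W) -> (C, r * H, r * W)
print(y.shape)  # torch.Size([1, 32, 16, 16])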