def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        """Create the state encoder and the inverse / forward prediction heads.

        Args:
            specs: Behavior spec. Its ``observation_shapes`` size the state
                encoder, and it is handed to ``ModelUtils.ActionFlattener``.
            settings: Curiosity settings. ``encoding_size`` is the width of
                the state encoding shared by every sub-network built here.
        """
        super().__init__()
        self._policy_specs = specs

        enc_size = settings.encoding_size
        # Two-layer observation encoder: no observation normalization, no
        # memory, simple visual encoder, hidden width equal to enc_size.
        encoder_cfg = NetworkSettings(
            normalize=False,
            hidden_units=enc_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(specs.observation_shapes, encoder_cfg)

        self._action_flattener = ModelUtils.ActionFlattener(specs)
        flat_action_size = self._action_flattener.flattened_size
        hidden_width = 256

        # Inverse head: (2 * enc_size) -> flattened action size. The doubled
        # input is presumably the current and next encodings concatenated;
        # confirm against the forward pass.
        inverse_trunk = LinearEncoder(enc_size * 2, 1, hidden_width)
        inverse_out = linear_layer(hidden_width, flat_action_size)
        self.inverse_model_action_prediction = torch.nn.Sequential(
            inverse_trunk, inverse_out
        )

        # Forward head: (enc_size + flattened action) -> enc_size, i.e. it
        # outputs a vector in the same space as the state encoding.
        forward_trunk = LinearEncoder(enc_size + flat_action_size, 1, hidden_width)
        forward_out = linear_layer(hidden_width, enc_size)
        self.forward_model_next_state_prediction = torch.nn.Sequential(
            forward_trunk, forward_out
        )
    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        """Create the discriminator's state encoder, MLP encoder and estimator.

        Args:
            specs: Behavior spec. Its ``observation_shapes`` size the state
                encoder, and it is handed to ``ModelUtils.ActionFlattener``.
            settings: GAIL settings. ``encoding_size``, ``use_actions`` and
                ``use_vail`` are read here, and the object is kept on
                ``self._settings``.

        When ``settings.use_vail`` is true, three extra members are created:
        a trainable ``_z_sigma`` vector, a ``_z_mu_layer`` projection of width
        ``self.z_size``, and a non-trainable ``_beta`` scalar initialised from
        ``self.initial_beta``. The estimator then reads ``z_size`` features
        instead of ``encoding_size``. ``z_size`` and ``initial_beta`` are
        attributes defined outside this view. The estimator ends in a sigmoid,
        producing one value in (0, 1).
        """
        super().__init__()
        self._policy_specs = specs
        self._use_vail = settings.use_vail
        self._settings = settings

        enc_size = settings.encoding_size
        body_cfg = NetworkSettings(
            normalize=False,
            hidden_units=enc_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(specs.observation_shapes, body_cfg)

        self._action_flattener = ModelUtils.ActionFlattener(specs)

        # With actions enabled, the flattened action plus one extra slot
        # (the "done" flag) are appended to the state encoding.
        extra_inputs = (
            self._action_flattener.flattened_size + 1 if settings.use_actions else 0
        )
        self.encoder = torch.nn.Sequential(
            linear_layer(enc_size + extra_inputs, enc_size),
            Swish(),
            linear_layer(enc_size, enc_size),
            Swish(),
        )

        if not settings.use_vail:
            estimator_in = enc_size
        else:
            estimator_in = self.z_size
            # Note: torch.ones(self.z_size) builds a 1-D tensor of length z_size.
            self._z_sigma = torch.nn.Parameter(
                torch.ones(self.z_size, dtype=torch.float), requires_grad=True
            )
            self._z_mu_layer = linear_layer(
                enc_size,
                self.z_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
            )
            # Registered as a Parameter but excluded from gradients; presumably
            # adjusted manually elsewhere (confirm against the update code).
            self._beta = torch.nn.Parameter(
                torch.tensor(self.initial_beta, dtype=torch.float),
                requires_grad=False,
            )

        self._estimator = torch.nn.Sequential(
            linear_layer(estimator_in, 1),
            torch.nn.Sigmoid(),
        )