Example #1
    def __init__(
        self,
        observability: str,
        latent_space: str,
        obs_spaces: collections.OrderedDict,
        # attr_embed_size,
        z_dims: int,
        z_std_clip_max: float,
        goal_vector_obs_space,
        hidden_size: int,
        base_model: str,
        # base_kwargs: Dict,
    ):
        super().__init__()

        assert latent_space in ['gaussian']

        # self.encoder_type = encoder_type
        self.latent_space = latent_space
        self.z_dims = z_dims
        self.z_std_clip_max = z_std_clip_max
        self.hidden_size = hidden_size
        self.observability = observability
        assert len(goal_vector_obs_space.shape) == 1
        self.goal_vector_dims = goal_vector_obs_space.shape[0]
        # self.base_model = base_model

        self.base = base_model
        assert self.is_recurrent

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        self.fc12 = init_(
            nn.Linear(hidden_size + self.goal_vector_dims, 2 * z_dims))
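
The `fc12` head packs both Gaussian parameters into a single linear layer: the first `z_dims` outputs are the mean and the rest parameterize the standard deviation, which `z_std_clip_max` bounds. A minimal sketch of how such a head is typically consumed, assuming a softplus-then-clamp on the std half (the forward pass is not shown above, so this is an illustration rather than the repository's exact code):

import torch
import torch.nn.functional as F

def split_gaussian_params(out: torch.Tensor, z_dims: int,
                          z_std_clip_max: float):
    # `out` has shape (batch, 2 * z_dims): mean | raw std.
    mean, std_raw = torch.split(out, z_dims, dim=-1)
    # Softplus keeps the std positive; the clamp bounds it from above.
    std = F.softplus(std_raw).clamp(max=z_std_clip_max)
    return torch.distributions.Normal(mean, std)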
Example #2
    def __init__(
        self,
        observability: str,
        action_dims: int,
        latent_space: str,
        obs_spaces: collections.OrderedDict,
        # attr_embed_size,
        z_dims: int,
        z_std_clip_max: float,
        hidden_size: int,
        base_model: str,
        base_kwargs: Dict,
        policy_base_kwargs: Dict,
    ):
        # assert 'goal_vector' not in policy_base_kwargs['obs_spaces']
        new_base_kwargs = policy_base_kwargs.copy()
        self._z_dims = z_dims
        self.goal_vector_obs_space = \
            policy_base_kwargs['obs_spaces']['goal_vector']
        self.goal_vector_dims = self.goal_vector_obs_space.shape[0]
        new_base_kwargs['obs_spaces'].pop('goal_vector')
        self.z_std_clip_max = z_std_clip_max

        super().__init__(
            observability=observability,
            action_dims=action_dims,
            base_model=base_model,
            base_kwargs=new_base_kwargs,
        )

        init_ = lambda m: init(m, init_normc_,
                               lambda x: nn.init.constant_(x, 0))

        self.actor_net = nn.Sequential(
            init_(nn.Linear(hidden_size + z_dims, hidden_size)),
            nn.Tanh(),
        )

        self.critic_net = nn.Sequential(
            init_(nn.Linear(hidden_size + z_dims, hidden_size)),
            nn.Tanh(),
            init_(nn.Linear(hidden_size, 1)),
        )

        self.z_enc_net = init_(
            nn.Linear(hidden_size + self.goal_vector_dims, 2 * z_dims))

        # self.ib_encoder = IBSupervisedEncoder(
        #     observability=observability,
        #     latent_space=latent_space,
        #     obs_spaces=obs_spaces,
        #     z_dims=z_dims,
        #     goal_vector_obs_space=self.goal_vector_obs_space,
        #     z_std_clip_max=z_std_clip_max,
        #     hidden_size=hidden_size,
        #     base_model=self.base,
        #     # base_kwargs=base_kwargs,
        # )
        assert 'goal_vector' not in self.base.obs_keys
Example #3
    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()

        init_ = lambda m: init(m, init_normc_,
                               lambda x: nn.init.constant_(x, 0))

        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        self.logstd = AddBias(torch.zeros(num_outputs))
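
`AddBias` stores the log-std as a learnable, state-independent bias. A sketch of the forward pass this head implies, assuming `AddBias` adds its stored bias to its input so that feeding zeros recovers the log-std (the actual `forward` is not shown above):

import torch

def diag_gaussian_forward(self, x):
    # Mean is state-dependent; log-std is a learned constant.
    action_mean = self.fc_mean(x)
    action_logstd = self.logstd(torch.zeros_like(action_mean))
    return torch.distributions.Normal(action_mean, action_logstd.exp())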
Example #4
    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()

        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               gain=0.01)

        self.linear = init_(nn.Linear(num_inputs, num_outputs))
Example #5
    def __init__(
        self,
        num_inputs: int,
        recurrent: bool = False,
        hidden_size: int = 64,
        use_critic: bool = True,
        critic_detach: bool = True,
    ):
        super().__init__(recurrent, num_inputs, hidden_size)
        self.critic_detach = critic_detach

        self.use_critic = use_critic
        if recurrent:
            num_inputs = hidden_size

        init_ = lambda m: init(m, init_normc_,
                               lambda x: nn.init.constant_(x, 0))

        self.actor = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)),
            # nn.LeakyReLU(0.1),
            # nn.ELU(),
            nn.Tanh(),
            # init_(nn.Linear(hidden_size, hidden_size)),
            # nn.LeakyReLU(0.1),
            # nn.ELU(),
            # init_(nn.Linear(hidden_size, hidden_size)),
            # nn.LeakyReLU(0.1),
        )

        if use_critic:
            self.critic = nn.Sequential(
                init_(nn.Linear(num_inputs, hidden_size)),
                # nn.LeakyReLU(0.1),
                # nn.ELU(),
                nn.Tanh(),
                # init_(nn.Linear(hidden_size, hidden_size)),
                # # nn.LeakyReLU(0.1),
                # nn.ELU(),
                # init_(nn.Linear(hidden_size, hidden_size)),
                # nn.LeakyReLU(0.1),
            )
            self.critic_linear = init_(nn.Linear(hidden_size, 1))

        self.train()
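
`critic_detach` is only stored here; presumably it controls whether value-loss gradients flow back into the shared input features. A hedged sketch of how the two flags might be applied downstream (illustrative only; the class's real forward and any recurrent handling are not shown):

import torch

def forward_sketch(self, x):
    hidden_actor = self.actor(x)
    value = None
    if self.use_critic:
        # Detaching stops the value head from updating shared inputs.
        critic_in = x.detach() if self.critic_detach else x
        value = self.critic_linear(self.critic(critic_in))
    return value, hidden_actor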
Example #6
    def __init__(
        self,
        embed_type: str,
        input_attr_dims: Tuple[int],
        embed_size: int,
        hidden_size: int,  # TODO: stop overloading hidden_size
        # num_attributes,
        output_size: Optional[int] = None):
        super().__init__()

        self.embed_type = embed_type
        self.embed_size = embed_size
        self.input_attr_dims = input_attr_dims
        # self.num_attributes = num_attributes

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        if embed_type == 'one-hot':
            self.main_embed = nn.Embedding(input_attr_dims, embed_size)

        elif embed_type == 'k-hot':
            assert len(input_attr_dims) > 1, \
                "Use one-hot for single attribute"
            assert output_size is not None, \
                "Output size needed for k-hot embeddings"

            self.main_embed = nn.ModuleList([
                nn.Sequential(
                    nn.Embedding(dim, embed_size),
                    init_(nn.Linear(embed_size, hidden_size)),
                    nn.LeakyReLU(0.1),
                    init_(nn.Linear(hidden_size, hidden_size)),
                ) for dim in input_attr_dims
            ])

            self.fc = nn.Sequential(
                init_(nn.Linear(len(input_attr_dims) * hidden_size,
                                hidden_size)),
                nn.LeakyReLU(0.1),
                init_(nn.Linear(hidden_size, output_size)),
            )
        self._init_embedding_tables()
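
In the 'k-hot' branch, each attribute index is looked up in its own embedding tower and the per-attribute features are concatenated before `self.fc`. A minimal sketch of that flow, assuming `attrs` is a LongTensor of shape (batch, num_attributes) (the real forward pass is not shown):

import torch

def k_hot_forward(self, attrs):
    # attrs[:, i] selects the i-th attribute's indices for the batch.
    per_attr = [tower(attrs[:, i])
                for i, tower in enumerate(self.main_embed)]
    # Concatenation yields (batch, num_attributes * hidden_size),
    # matching the input size of self.fc.
    return self.fc(torch.cat(per_attr, dim=-1))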
Example #7
    def __init__(self,
                 input_channels,
                 omega_option_dims,
                 input_attr_dims,
                 recurrent=False,
                 hidden_size=512,
                 pretrained_encoder=False,
                 agent_cfg_dims=None):
        super().__init__(recurrent, hidden_size, hidden_size)

        self.input_channels = input_channels
        self.pretrained_encoder = pretrained_encoder
        self.agent_cfg_dims = agent_cfg_dims
        self.input_attr_dims = input_attr_dims
        self.hidden_size = hidden_size

        self.omega_option_dims = omega_option_dims

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        self.omega_fc_actor = nn.Sequential(
            init_(nn.Linear(omega_option_dims, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
        )

        self.omega_fc_critic = nn.Sequential(
            init_(nn.Linear(omega_option_dims, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
        )

        encoder_dim = (2 * 2 + 7 * 7 + 15 * 15) * 3
        # base_feat_dim = encoder_dim + embed_size

        if pretrained_encoder:
            self.encoder = Encoder()
            self.after_encoder = nn.Sequential(
                init_(nn.Linear(encoder_dim, hidden_size)),
                nn.ReLU(),
            )
            self.triplet_fc = nn.Sequential(
                init_(nn.Linear(3 * hidden_size, hidden_size)),
                nn.ReLU(),
            )
        else:
            init_ = lambda m: init(m, nn.init.orthogonal_,
                                   lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))

            self.encoder = nn.Sequential(
                # 30 x 30
                init_(nn.Conv2d(input_channels, 10, 1, stride=1)),
                # nn.MaxPool2d(2, 2),
                nn.ReLU(),
                # 30 x 30
                init_(nn.Conv2d(10, 32, 4, stride=2, padding=1)),
                nn.ReLU(),
                # 15 x 15
                init_(nn.Conv2d(32, 32, 3, stride=2, padding=1)),
                # 8 x 8
                nn.ReLU(),
                Flatten(),
                init_(nn.Linear(32 * 8 * 8, hidden_size)),
                nn.ReLU(),
            )

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        if agent_cfg_dims is not None:
            fc_input_size = hidden_size + agent_cfg_dims
        else:
            fc_input_size = hidden_size

        self.actor_fc = nn.Sequential(
            init_(nn.Linear(fc_input_size + hidden_size, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
        )

        self.critic_linear = nn.Sequential(
            init_(nn.Linear(fc_input_size + hidden_size, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, 1)),
        )

        # self.target_embed = AttributeEmbedding(
        #     embed_type='k-hot',
        #     input_dim=self.input_attr_dims,
        #     embed_size=self.hidden_size,
        #     hidden_size=self.hidden_size,
        #     output_size=self.hidden_size)

        self.train()
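
The `fc_input_size + hidden_size` inputs of `actor_fc` and `critic_linear` suggest the encoded observation (plus agent config, if present) is concatenated with the omega features. A sketch of that wiring under those assumptions (`obs`, `omega`, and `agent_cfg` are hypothetical argument names; the real forward is not shown):

import torch

def forward_sketch(self, obs, omega, agent_cfg=None):
    feat = self.encoder(obs)
    if self.agent_cfg_dims is not None:
        feat = torch.cat([feat, agent_cfg], dim=-1)
    # Actor and critic each get their own projection of the option.
    actor_in = torch.cat([feat, self.omega_fc_actor(omega)], dim=-1)
    critic_in = torch.cat([feat, self.omega_fc_critic(omega)], dim=-1)
    return self.actor_fc(actor_in), self.critic_linear(critic_in)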
Example #8
    def __init__(self,
                 input_channels,
                 goal_output_size,
                 goal_attr_dims,
                 state_encoder_hidden_size,
                 recurrent=False,
                 pretrained_encoder=False,
                 agent_cfg_dims=None):

        super().__init__(recurrent, state_encoder_hidden_size,
                         state_encoder_hidden_size)

        # self.model = model
        self.input_channels = input_channels
        self.pretrained_encoder = pretrained_encoder
        self.agent_cfg_dims = agent_cfg_dims
        self.goal_attr_dims = goal_attr_dims
        self.state_encoder_hidden_size = state_encoder_hidden_size

        self.goal_output_size = goal_output_size

        self.goal_fc_actor = nn.Sequential(
            nn.Linear(goal_output_size, goal_output_size),
            nn.ReLU(),
        )

        self.goal_fc_critic = nn.Sequential(
            nn.Linear(goal_output_size, goal_output_size),
            nn.ReLU(),
        )

        encoder_dim = (2 * 2 + 7 * 7 + 15 * 15) * 3
        # base_feat_dim = encoder_dim + embed_size

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        self.map_encoder = nn.Sequential(
            # 30 x 30
            init_(nn.Conv2d(input_channels, 10, 1, stride=1)),
            # nn.MaxPool2d(2, 2),
            nn.ReLU(),
            # 30 x 30
            init_(nn.Conv2d(10, 10, 4, stride=2, padding=1)),
            nn.ReLU(),
            # 15 x 15
            init_(nn.Conv2d(10, 10, 3, stride=2, padding=1)),
            # 8 x 8
            # nn.ReLU(),
            Flatten(),
            init_(nn.Linear(10 * 8 * 8, state_encoder_hidden_size)),
            nn.ReLU(),
        )

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        if agent_cfg_dims is not None:
            fc_input_size = self.state_encoder_hidden_size + \
                             self.agent_cfg_dims + \
                             self.goal_output_size
        else:
            # No agent config: only state and goal features remain.
            fc_input_size = self.state_encoder_hidden_size + \
                             self.goal_output_size

        self.actor_fc = nn.Sequential(
            nn.Linear(fc_input_size, state_encoder_hidden_size),
            nn.ReLU(),
            nn.Linear(state_encoder_hidden_size, state_encoder_hidden_size),
        )

        self.critic_linear = nn.Sequential(
            init_(nn.Linear(fc_input_size, state_encoder_hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(state_encoder_hidden_size, 1)),
        )

        self.train()
Example #9
    def __init__(
        self,
        obs_spaces: collections.OrderedDict,
        recurrent: bool = False,
        hidden_size: int = 64,
        use_critic: bool = True,
        critic_detach: bool = True,
    ):
        num_inputs = 0
        self.obs_keys = obs_spaces.keys()
        self.image_space = obs_spaces['image']
        mlp_obs_spaces = obs_spaces.copy()
        mlp_obs_spaces.update({
            'image':
            spaces.Box(
                low=0.0,  # Arbitrary value
                high=1.0,  # Arbitrary value
                shape=(hidden_size, ),
                dtype='float',
            )
        })
        self.mlp_obs_keys = mlp_obs_spaces.keys()

        super().__init__(
            obs_spaces=mlp_obs_spaces,
            recurrent=recurrent,
            hidden_size=hidden_size,
            use_critic=use_critic,
            critic_detach=critic_detach,
        )

        _neg_slope = 0.1
        init_ = lambda m: init(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
            nn.init.calculate_gain('leaky_relu', param=_neg_slope))

        NEW_CNN = True

        H, W, num_channels = self.image_space.shape
        if NEW_CNN:
            self.cnn = nn.Sequential(
                init_(nn.Conv2d(num_channels, 16, (3, 3), padding=1)),
                nn.ReLU(),
                # nn.MaxPool2d((2, 2)),
                init_(nn.Conv2d(16, 32, (2, 2))),
                nn.ReLU(),
                init_(nn.Conv2d(32, 64, (2, 2))),
                nn.ReLU(),
                Flatten(),
            )
        else:
            self.cnn = nn.Sequential(
                init_(nn.Conv2d(num_channels, 16, 1, stride=1)),
                # nn.LeakyReLU(_neg_slope),
                nn.ELU(),
                init_(nn.Conv2d(16, 8, 3, stride=1, padding=2)),
                # nn.LeakyReLU(_neg_slope),
                nn.ELU(),
                # init_(nn.Conv2d(64, 64, 5, stride=1, padding=2)),
                # nn.LeakyReLU(_neg_slope),
                Flatten(),
            )
        output_h_w, out_channels = utils.conv_sequential_output_shape((H, W),
                                                                      self.cnn)
        h_w_prod = output_h_w[0] * output_h_w[1]
        self.fc = nn.Sequential(
            init_(nn.Linear(out_channels * h_w_prod, hidden_size)),
            # nn.LeakyReLU(_neg_slope),
            # nn.ELU(),
        )
        self.apply(initialize_parameters)
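
`self.image_space.shape` is unpacked as `(H, W, num_channels)`, i.e. channels-last, while `nn.Conv2d` expects `(N, C, H, W)`; the CNN output then stands in for the 'image' entry advertised to the MLP base via `mlp_obs_spaces`. A sketch of the image path under those assumptions (the real forward is not shown):

def encode_image_sketch(self, image):
    # (N, H, W, C) -> (N, C, H, W) for Conv2d.
    x = image.permute(0, 3, 1, 2).contiguous()
    # Flattened conv features are projected down to hidden_size,
    # matching the Box shape declared in mlp_obs_spaces['image'].
    return self.fc(self.cnn(x))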
Example #10
    def __init__(
            self,
            input_channels,
            target_dim,
            target_embed_type,
            # num_attributes,
            embed_size,
            recurrent=False,
            hidden_size=512,
            input_size=64,
            pretrained_encoder=False,
            agent_cfg_dims=None):
        super().__init__(recurrent, hidden_size, hidden_size)

        self.input_size = input_size
        self.target_dim = target_dim
        self.embed_size = embed_size
        self.input_channels = input_channels
        self.pretrained_encoder = pretrained_encoder
        self.agent_cfg_dims = agent_cfg_dims

        # self.num_attributes = num_attributes
        self.target_embed_type = target_embed_type

        self.target_embed_actor = AttributeEmbedding(
            embed_type=target_embed_type,
            input_dim=target_dim,
            embed_size=embed_size,
            hidden_size=hidden_size,
            output_size=hidden_size)

        self.target_embed_critic = AttributeEmbedding(
            embed_type=target_embed_type,
            input_dim=target_dim,
            embed_size=embed_size,
            hidden_size=hidden_size,
            output_size=hidden_size)

        self.init_embedding_weights()

        encoder_dim = (2 * 2 + 7 * 7 + 15 * 15) * 3
        # base_feat_dim = encoder_dim + embed_size

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        if pretrained_encoder:
            self.encoder = Encoder()
            self.after_encoder = nn.Sequential(
                nn.Linear(encoder_dim, hidden_size),
                nn.ReLU(),
            )
            self.triplet_fc = nn.Sequential(
                nn.Linear(3 * hidden_size, hidden_size),
                nn.ReLU(),
            )
        else:
            # [NOTE] : If we switch to some other gridworld, this has to be
            # taken care of.

            if self.input_size == 30:
                self.encoder = nn.Sequential(
                    # 30 x 30
                    init_(nn.Conv2d(input_channels, 10, 1, stride=1)),
                    # nn.MaxPool2d(2, 2),
                    nn.ReLU(),
                    # 30 x 30
                    init_(nn.Conv2d(10, 32, 4, stride=2, padding=1)),
                    nn.ReLU(),
                    # 15 x 15
                    init_(nn.Conv2d(32, 32, 3, stride=2, padding=1)),
                    # 8 x 8
                    nn.ReLU(),
                    Flatten(),
                    init_(nn.Linear(32 * 8 * 8, hidden_size)),
                    nn.ReLU(),
                )
            else:
                raise ValueError

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0))

        if agent_cfg_dims is not None:
            fc_input_size = hidden_size + agent_cfg_dims
        else:
            fc_input_size = hidden_size

        if target_embed_type == 'one-hot':
            self.actor_fc = nn.Sequential(
                nn.Linear(fc_input_size + embed_size, hidden_size),
                nn.ReLU(),
            )
            self.critic_linear = init_(nn.Linear(hidden_size + embed_size, 1))

        elif target_embed_type == 'k-hot':
            self.actor_fc = nn.Sequential(
                nn.Linear(fc_input_size + hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
            self.critic_linear = nn.Sequential(
                init_(nn.Linear(fc_input_size + hidden_size, hidden_size)),
                nn.ReLU(),
                init_(nn.Linear(hidden_size, 1)),
            )

        self.train()
Example #11
    def __init__(
        self,
        observability: str,
        latent_space: str,
        obs_spaces: collections.OrderedDict,
        # attr_embed_size,
        z_dims: int,
        z_std_clip_max: float,
        hidden_size: int,
        base_model: str,
        base_kwargs: Dict,
    ):
        super().__init__()

        # assert input_type in ['goal_and_initial_state']
        # assert encoder_type in ['single', 'poe']
        assert 'mission' not in base_kwargs['obs_spaces']
        # assert 'omega' in base_kwargs['obs_spaces']
        assert latent_space in ['gaussian']

        # self.encoder_type = encoder_type
        self.latent_space = latent_space
        self.z_dims = z_dims
        self.z_std_clip_max = z_std_clip_max
        self.hidden_size = hidden_size
        self.observability = observability
        self.base_model = base_model

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        if base_kwargs is None:
            base_kwargs = {}

        # if input_type == 'goal_and_initial_state':
        if observability == 'full':
            if self.base_model == 'mlp':
                self.base = FlattenMLPBase(**base_kwargs)
            elif self.base_model == 'cnn-mlp':
                self.base = CNNPlusMLPBase(**base_kwargs)
            else:
                raise ValueError
        else:
            raise NotImplementedError

        # # Only encode state observation if it is provided as input
        # self.use_state_encoder = \
        #     self.input_type == 'goal_and_initial_state'

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        # Encoding attributes
        # if encoder_type == 'single':
        # self.fc = nn.Sequential(
        #     init_(nn.Linear(omega_option_dims, hidden_size)),
        #     nn.LeakyReLU(0.1),
        #     init_(nn.Linear(hidden_size, hidden_size)),
        # )

        # if self.use_state_encoder:
        # self.fc = nn.Sequential(
        #     init_(nn.Linear(hidden_size, hidden_size)),
        #     nn.LeakyReLU(0.1),
        # )

        if self.latent_space == 'gaussian':
            self.fc12 = init_(nn.Linear(hidden_size, 2 * z_dims))

        elif self.latent_space == 'categorical':
            self.fc_logits = nn.Sequential(
                init_(nn.Linear(hidden_size, hidden_size)),
                nn.LeakyReLU(0.1),
            )
            self.dist = Categorical(hidden_size, self.z_dims)

        else:
            raise ValueError
Example #12
    def __init__(
        self,
        input_type: str,
        ic_mode: str,
        observability: str,
        option_space: str,
        omega_option_dims: int,
        hidden_size: int,
        base_model: str,
        base_kwargs: Dict,
    ):
        super().__init__()

        assert input_type in \
            ['final_state', 'final_and_initial_state']
        assert option_space in ['continuous', 'discrete']
        assert 'mission' not in base_kwargs['obs_spaces']
        assert ic_mode in ['vic', 'diyan', 'valor']

        if ic_mode != 'vic':
            input_type = 'final_state'

        self.input_type = input_type
        self.ic_mode = ic_mode
        self.hidden_size = hidden_size
        self.option_space = option_space
        self.omega_option_dims = omega_option_dims
        self.observability = observability
        self.base_model = base_model

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        if base_kwargs is None:
            base_kwargs = {}

        if ic_mode == 'valor':
            base_kwargs['recurrent'] = True

        if self.base_model == 'cnn-mlp' and \
                'image' not in base_kwargs['obs_spaces']:
            self.base_model = 'mlp'
            print("Switching to MLP for TrajectoryEncoder since no image"
                  " is present in obs_spaces!")

        assert observability == 'full'
        if self.base_model == 'mlp':
            self.base = FlattenMLPBase(**base_kwargs)
        elif self.base_model == 'cnn-mlp':
            self.base = CNNPlusMLPBase(**base_kwargs)
        else:
            raise ValueError

        state_feat_dim = base_kwargs['hidden_size']
        if input_type == 'final_and_initial_state':
            state_feat_dim *= 2

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        self.fc = nn.Sequential(
            init_(nn.Linear(state_feat_dim, hidden_size)),
            # nn.LeakyReLU(0.1),
            nn.ELU(),
            # init_(nn.Linear(hidden_size, hidden_size)),
            # nn.LeakyReLU(0.1),
            # init_(nn.Linear(hidden_size, hidden_size)),
        )

        if self.option_space == 'continuous':
            self.fc12 = init_(nn.Linear(hidden_size, 2 * omega_option_dims))
        else:
            self.fc_logits = nn.Sequential(
                init_(nn.Linear(hidden_size, hidden_size)),
                # nn.LeakyReLU(0.1),
                nn.ELU(),
            )
            self.dist = Categorical(hidden_size, self.omega_option_dims)
Example #13
    def __init__(
        self,
        observability: str,
        input_type: str,
        encoder_type: str,
        obs_spaces: collections.OrderedDict,
        attr_embed_size: int,
        option_space: str,
        omega_option_dims: int,
        hidden_size: int,
        base_model: str,
        base_kwargs: Dict,
    ):
        r'''
        Arguments:
            observability: Only 'full' is supported, i.e. the MDP
                setting.

            input_type: 'goal_and_initial_state' or 'goal_only'; the
                input is the goal specification (attributes) plus the
                initial state in the first setting, and just the goal
                specification in the second.

            encoder_type: 'single' or 'poe'. The latter is a product
                of experts (PoE) encoding which predicts the location
                and scale of a Gaussian Q(\omega | A_k) for each
                specified attribute k, and computes the final
                probability as the product of the predicted Gaussians
                together with the prior P(\omega) = N(0, I).

            input_attr_dims: Tuple specifying (n_0, n_1, ..., n_{K-1}),
                i.e. the number of values for each attribute.

            attr_embed_size: See 'embed_size' in
                models.AttributeEmbedding.

            hidden_size: Overloaded parameter specifying the hidden
                size of MLPs, AttributeEmbedding objects, and CNNs,
                if any.

            agent_cfg_dims: The size of the agent's config
                specification. This is needed when options are
                conditioned on the initial state observation, which
                includes the environment config and agent config.

            input_channels: The input channels of the CNN used to
                encode the environment config (cell map).
        '''
        super().__init__()

        assert input_type in ['goal_and_initial_state', 'goal_only']
        assert encoder_type in ['single', 'poe']
        assert 'mission' not in base_kwargs['obs_spaces']
        assert option_space in ['continuous', 'discrete']

        # self.agent_pos_dims = obs_spaces['agent_pos'].shape[0]
        self.input_attr_dims = obs_spaces['mission'].nvec
        self.attr_embed_size = attr_embed_size
        self.input_type = input_type
        self.encoder_type = encoder_type
        self.option_space = option_space
        self.omega_option_dims = omega_option_dims
        self.hidden_size = hidden_size
        self.observability = observability
        self.base_model = base_model

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               nn.init.calculate_gain('relu'))

        if base_kwargs is None:
            base_kwargs = {}

        if input_type == 'goal_and_initial_state':
            if observability == 'full':
                if self.base_model == 'mlp':
                    self.base = FlattenMLPBase(**base_kwargs)
                elif self.base_model == 'cnn-mlp':
                    self.base = CNNPlusMLPBase(**base_kwargs)
                else:
                    raise ValueError
            else:
                raise NotImplementedError

        # Only encode state observation if it is provided as input
        self.use_state_encoder = \
            self.input_type == 'goal_and_initial_state'

        init_ = lambda m: init(m, nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               np.sqrt(2))

        # Encoding attributes
        if encoder_type == 'single':
            self.main_embed = AttributeEmbedding(
                embed_type='k-hot',
                input_attr_dims=self.input_attr_dims,
                embed_size=self.attr_embed_size,
                hidden_size=self.hidden_size,
                output_size=self.hidden_size)

            if self.use_state_encoder:
                self.fc = nn.Sequential(
                    init_(nn.Linear(hidden_size * 2, hidden_size)),
                    nn.LeakyReLU(0.1),
                )

            if self.option_space == 'continuous':
                self.fc12 = init_(nn.Linear(hidden_size,
                                            2 * omega_option_dims))
            else:
                self.fc_logits = nn.Sequential(
                    init_(nn.Linear(hidden_size, hidden_size)),
                    nn.LeakyReLU(0.1),
                )
                self.dist = Categorical(hidden_size, self.omega_option_dims)

        elif encoder_type == 'poe':
            if self.option_space == 'discrete':
                raise NotImplementedError
            assert len(self.input_attr_dims) > 1, \
                "Use one-hot for single attribute"
            # assert output_size != None, \
            #     "Output size needed for k-hot embeddings"

            self.poe_embed = nn.ModuleList([
                nn.Sequential(
                    nn.Embedding(dim, hidden_size),
                    init_(nn.Linear(hidden_size, hidden_size)),
                    nn.LeakyReLU(0.1),
                    init_(nn.Linear(hidden_size, hidden_size)),
                ) for dim in self.input_attr_dims
            ])

            if self.use_state_encoder:
                self.fc_poe = nn.ModuleList([
                    nn.Sequential(
                        init_(nn.Linear(hidden_size + hidden_size,
                                        hidden_size)),
                        nn.LeakyReLU(0.1),
                        init_(nn.Linear(hidden_size, hidden_size)),
                    ) for _ in self.input_attr_dims
                ])

            self.fc12 = nn.ModuleList([
                init_(nn.Linear(hidden_size, 2 * omega_option_dims))
                for _ in self.input_attr_dims
            ])
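
The 'poe' branch's `fc12` heads emit a location and scale per attribute; per the docstring, these Gaussians Q(\omega | A_k) are multiplied together with the N(0, I) prior. A product of diagonal Gaussians is again Gaussian, with precisions summing and a precision-weighted mean; a minimal sketch of that fusion (illustrative, not the repository's exact forward):

import torch

def product_of_gaussians(means, stds):
    # means, stds: (K, batch, dims) -- one expert per attribute.
    # Prepend the standard-normal prior as one more expert.
    means = torch.cat([torch.zeros_like(means[:1]), means], dim=0)
    stds = torch.cat([torch.ones_like(stds[:1]), stds], dim=0)

    precisions = stds.pow(-2)
    joint_var = 1.0 / precisions.sum(dim=0)
    # Precision-weighted average of the expert means.
    joint_mean = joint_var * (precisions * means).sum(dim=0)
    return joint_mean, joint_var.sqrt()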