def __init__(self,
             observation_spaces_dict: Dict[Union[str, int], spaces.Dict],
             action_spaces_dict: Dict[Union[str, int], spaces.Dict],
             networks: CollectionOfConfigType):
        """Build the single state-action critic and store it under key ``0``.

        :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
        :param action_spaces_dict: Action spaces, forwarded to the parent constructor.
        :param networks: Collection holding exactly one network config; its first entry
            is instantiated as the critic.
        """
        super().__init__(observation_spaces_dict, action_spaces_dict)
        assert len(networks) == 1
        critic_config = networks[0]

        action_space_flat = flat_structured_space(self._action_spaces_dict)
        critic_input_shapes = flat_structured_shapes(self._obs_shapes)
        action_heads = action_space_flat.spaces

        if all(self._only_discrete_spaces.values()):
            # Purely discrete setting: the critic emits one vector of q-values per action head
            #   (one entry for each possible action of that head).
            critic_output_shapes = {
                head_name + '_q_values': (head_space.n, )
                for head_name, head_space in action_heads.items()
            }
        else:
            # General setting: the actions become additional critic inputs and a single q-value
            #   is returned for the given state and action.
            for head_name, head_space in action_heads.items():
                critic_input_shapes[head_name] = (
                    (head_space.n, ) if isinstance(head_space, spaces.Discrete)
                    else head_space.sample().shape)
            critic_output_shapes = {'q_value': (1, )}

        # Instantiate the critic network from its config.
        registry = Factory(base_type=nn.Module)
        critic = registry.instantiate(critic_config,
                                      obs_shapes=critic_input_shapes,
                                      output_shapes=critic_output_shapes)
        self._critics = {0: critic}
Esempio n. 2
0
    def __init__(self,
                 observation_spaces_dict: Dict[Union[str, int], spaces.Dict],
                 agent_counts_dict: Dict[StepKeyType, int],
                 networks: CollectionOfConfigType):
        """Build one critic network per entry in ``networks``.

        Every critic except the first receives one additional input entry, keyed by
        ``self.prev_value_key`` and shaped ``self.prev_value_shape``
        (presumably the value estimate of the preceding critic -- confirm against the base class).

        :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
        :param agent_counts_dict: Agent counts, forwarded to the parent constructor.
        :param networks: Network configs; normalised to a dict via ``list_to_dict``.
        """
        super().__init__(observation_spaces_dict, agent_counts_dict)

        # Instantiate the critics in the iteration order of the config dict.
        registry = Factory(base_type=nn.Module)
        networks = list_to_dict(networks)
        self._critics = dict()
        for position, (step_key, step_config) in enumerate(networks.items()):
            input_shapes = self._obs_shapes[step_key]
            if position != 0:
                # Work on a copy so the stored observation shapes stay untouched.
                input_shapes = dict(input_shapes)
                input_shapes[self.prev_value_key] = self.prev_value_shape
            self._critics[step_key] = registry.instantiate(
                step_config, obs_shapes=input_shapes)
Esempio n. 3
0
    def __init__(self,
                 observation_spaces_dict: Dict[StepKeyType, spaces.Dict],
                 agent_counts_dict: Dict[StepKeyType, int],
                 networks: ConfigType,
                 stack_observations: bool):
        """Build a single critic operating on the flattened observation shapes.

        After construction ``self._obs_shapes`` is replaced by ``{0: <flattened shapes>}`` and the
        critic is stored under key ``0`` in ``self._critics``.

        :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
        :param agent_counts_dict: Agent counts, forwarded to the parent constructor.
        :param networks: Container holding exactly one network config (its first entry is used).
        :param stack_observations: If True, the observation shapes are passed through
            ``stacked_shapes`` together with the agent counts before being flattened.
        """
        super().__init__(observation_spaces_dict, agent_counts_dict)
        assert len(networks) == 1
        self.stack_observations = stack_observations
        critic_config = networks[0]

        if self.stack_observations:
            structured_shapes = stacked_shapes(self._obs_shapes,
                                               self._agent_counts_dict)
        else:
            structured_shapes = self._obs_shapes
        merged_shapes = flat_structured_shapes(structured_shapes)
        self._obs_shapes = {0: merged_shapes}

        # Instantiate the critic network from its config.
        registry = Factory(base_type=nn.Module)
        critic = registry.instantiate(critic_config, obs_shapes=merged_shapes)
        self._critics = {0: critic}
Esempio n. 4
0
def test_raises_exception_on_invalid_registry_value():
    """Instantiating a config whose ``_target_`` cannot be resolved raises ``ImportError``."""
    registry = Factory(base_type=DummyObservationConversion)
    # Keep only the call under test inside the context manager, so that an error raised while
    # setting up the registry cannot make this test pass by accident.
    with pytest.raises(ImportError):
        registry.instantiate(config={"_target_": "wrong_key"})
Esempio n. 5
0
def test_raises_exception_on_invalid_type():
    """Passing an instance that is not of the registry's base type raises ``AssertionError``."""
    registry = Factory(base_type=DummyObservationConversion)
    # Keep the registry construction outside the context manager, so that an assertion failing
    # during setup is reported as an error instead of satisfying the expectation.
    with pytest.raises(AssertionError):
        registry.instantiate(config=DictActionConversion())
Esempio n. 6
0
def test_returns_if_arg_already_instantiated():
    """An already instantiated object of the base type is handed back by ``instantiate``."""
    obs_conv = CustomDummyObservationConversion(attr=1)
    registry = Factory(base_type=DummyObservationConversion)
    obj = registry.instantiate(config=obs_conv)
    # Compare against the object that was passed in; the former ``obj == obj`` check was a
    # tautology and could never fail.
    assert obj is obs_conv