def __init__(self,
             observation_spaces_dict: Dict[Union[str, int], spaces.Dict],
             action_spaces_dict: Dict[Union[str, int], spaces.Dict],
             networks: CollectionOfConfigType):
    """Set up a single state-action critic from exactly one network config.

    The critic's input and output shapes are derived from the flattened
    observation shapes and the flattened action space:

    * If every entry of ``self._only_discrete_spaces`` is truthy, the critic
      gets one ``<action_key>_q_values`` output per action head, sized by the
      head's ``n``, and only the observations as input.
    * Otherwise each action head is added to the critic inputs (``(n, )`` for
      ``spaces.Discrete`` heads, the shape of a sampled action for all other
      heads) and the critic has a single ``q_value`` output of shape ``(1, )``.

    :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
    :param action_spaces_dict: Action spaces, forwarded to the parent constructor.
    :param networks: Collection holding exactly one critic network config.
    """
    super().__init__(observation_spaces_dict, action_spaces_dict)

    assert len(networks) == 1
    critic_config = networks[0]

    action_heads = flat_structured_space(self._action_spaces_dict)
    critic_input_shapes = flat_structured_shapes(self._obs_shapes)

    critic_output_shapes = {}
    if all(self._only_discrete_spaces.values()):
        # Discrete-only case: one vector of q-values per action head.
        for head_name, head_space in action_heads.spaces.items():
            num_actions = head_space.n
            critic_output_shapes[head_name + '_q_values'] = (num_actions, )
    else:
        # General case: the actions become additional critic inputs and a
        # single scalar value is predicted for the (state, action) pair.
        for head_name, head_space in action_heads.spaces.items():
            if isinstance(head_space, spaces.Discrete):
                head_shape = (head_space.n, )
            else:
                head_shape = head_space.sample().shape
            critic_input_shapes[head_name] = head_shape
        critic_output_shapes['q_value'] = (1, )

    # Build the critic network through the factory.
    factory = Factory(base_type=nn.Module)
    critic = factory.instantiate(critic_config,
                                 obs_shapes=critic_input_shapes,
                                 output_shapes=critic_output_shapes)
    self._critics = {0: critic}
def __init__(self,
             observation_spaces_dict: Dict[Union[str, int], spaces.Dict],
             agent_counts_dict: Dict[StepKeyType, int],
             networks: CollectionOfConfigType):
    """Set up one critic per entry of ``networks``.

    The first critic receives the observation shapes stored under its key.
    Every later critic additionally receives an input named
    ``self.prev_value_key`` with shape ``self.prev_value_shape``.

    :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
    :param agent_counts_dict: Agent counts, forwarded to the parent constructor.
    :param networks: Critic network configs; converted with ``list_to_dict`` before use.
    """
    super().__init__(observation_spaces_dict, agent_counts_dict)

    # Build the critic networks through the factory.
    factory = Factory(base_type=nn.Module)
    networks = list_to_dict(networks)

    self._critics = dict()
    is_first = True
    for step_key, step_config in networks.items():
        input_shapes = self._obs_shapes[step_key]
        if not is_first:
            # All critics after the first one get the extra previous-value input.
            input_shapes = {**input_shapes, self.prev_value_key: self.prev_value_shape}
        self._critics[step_key] = factory.instantiate(step_config, obs_shapes=input_shapes)
        is_first = False
def __init__(self,
             observation_spaces_dict: Dict[StepKeyType, spaces.Dict],
             agent_counts_dict: Dict[StepKeyType, int],
             networks: ConfigType,
             stack_observations: bool):
    """Set up a single critic operating on the flattened observation shapes.

    When ``stack_observations`` is set, the observation shapes are first
    passed through ``stacked_shapes`` together with the agent counts and
    flattened afterwards. ``self._obs_shapes`` is replaced by the flattened
    result stored under key ``0``.

    :param observation_spaces_dict: Observation spaces, forwarded to the parent constructor.
    :param agent_counts_dict: Agent counts, forwarded to the parent constructor.
    :param networks: Collection holding exactly one critic network config.
    :param stack_observations: Whether to apply ``stacked_shapes`` before flattening.
    """
    super().__init__(observation_spaces_dict, agent_counts_dict)

    assert len(networks) == 1
    self.stack_observations = stack_observations
    critic_config = networks[0]

    shapes = self._obs_shapes
    if self.stack_observations:
        shapes = stacked_shapes(shapes, self._agent_counts_dict)
    flattened_shapes = flat_structured_shapes(shapes)
    self._obs_shapes = {0: flattened_shapes}

    # Build the critic network through the factory.
    factory = Factory(base_type=nn.Module)
    critic = factory.instantiate(critic_config, obs_shapes=flattened_shapes)
    self._critics = {0: critic}
def test_raises_exception_on_invalid_registry_value():
    """Expect ``ImportError`` when ``instantiate`` gets the ``_target_`` value ``"wrong_key"``."""
    # Create the factory outside the ``raises`` block so that only the
    # ``instantiate`` call itself can satisfy the expected exception.
    registry = Factory(base_type=DummyObservationConversion)

    with pytest.raises(ImportError):
        registry.instantiate(config={"_target_": "wrong_key"})
def test_raises_exception_on_invalid_type():
    """Expect ``AssertionError`` when ``instantiate`` gets a ``DictActionConversion`` instance
    while the factory's base type is ``DummyObservationConversion``."""
    # Set up both objects outside the ``raises`` block so that an error raised
    # during setup cannot be mistaken for the expected ``instantiate`` failure.
    registry = Factory(base_type=DummyObservationConversion)
    wrong_type_instance = DictActionConversion()

    with pytest.raises(AssertionError):
        registry.instantiate(config=wrong_type_instance)
def test_returns_if_arg_already_instantiated():
    """Passing an already constructed object to ``instantiate`` hands that object back."""
    obs_conv = CustomDummyObservationConversion(attr=1)
    registry = Factory(base_type=DummyObservationConversion)

    obj = registry.instantiate(config=obs_conv)

    # Compare against the object that was passed in; comparing ``obj`` with
    # itself is always true and would never detect a regression.
    assert obj is obs_conv