Example #1
    def __init__(self,
                 preprocessor_spec,
                 policy_spec,
                 exploration_spec=None,
                 **kwargs):
        """
        Args:
            preprocessor_spec (Union[list,dict,PreprocessorSpec]):
                - A dict if the state from the Env will come in as a ContainerSpace (e.g. Dict). In this case,
                    each key in this dict specifies which value in the incoming dict should go through which
                    PreprocessorStack.
                - A list with layer specs.
                - A PreprocessorStack object.

            policy_spec (Union[dict,Policy]): A specification dict for a Policy object or a Policy object directly.

            exploration_spec (Union[dict,Exploration]): A specification dict for an Exploration object or an Exploration
                object directly.
        """
        super(ActorComponent,
              self).__init__(scope=kwargs.pop("scope", "actor-component"),
                             **kwargs)

        self.preprocessor = PreprocessorStack.from_spec(preprocessor_spec)
        self.policy = Policy.from_spec(policy_spec)
        self.num_nn_inputs = self.policy.neural_network.num_inputs
        self.exploration = Exploration.from_spec(exploration_spec)

        self.tuple_merger = ContainerMerger(is_tuple=True,
                                            merge_tuples_into_one=True)
        self.tuple_splitter = ContainerSplitter(
            tuple_length=self.num_nn_inputs)

        self.add_components(self.policy, self.exploration, self.preprocessor,
                            self.tuple_merger, self.tuple_splitter)
Example #2
    def __init__(self,
                 policy_spec,
                 minimum_batch_size=1,
                 maximum_batch_size=1024,
                 timeout_ms=100,
                 scope="dynamic-batching-policy",
                 **kwargs):
        """
        Args:
            policy_spec (Union[Policy,dict]): A spec dict to construct the Policy that is wrapped by this
                DynamicBatchingPolicy, or a Policy object directly.
            minimum_batch_size (int): The minimum batch size to use. Default: 1.
            maximum_batch_size (int): The maximum batch size to use. Default: 1024.
            timeout_ms (int): The timeout in ms to use when waiting for items on the queue.
                Default: 100ms.
        """
        super(DynamicBatchingPolicy, self).__init__(
            # 5=state-values, logits, probabilities, log-probs, internal_states
            graph_fn_num_outputs=dict(
                _graph_fn_get_state_values_logits_probabilities_log_probs=5),
            scope=scope,
            **kwargs)

        # The wrapped, backend-specific policy object.
        self.policy = Policy.from_spec(policy_spec)

        # Hack: link these in so that parent components can call APIs of the distribution directly.
        self.action_adapter = self.policy.action_adapter
        self.distribution = self.policy.distribution
        self.deterministic = True

        # Dynamic batching options.
        self.minimum_batch_size = minimum_batch_size
        self.maximum_batch_size = maximum_batch_size
        self.timeout_ms = timeout_ms

        self.add_components(self.policy)

        # TODO: for now, only define this one API-method as this is the only one used in IMPALA.
        # TODO: Generalize this component so it can wrap arbitrary other components and simulate their API.
        self.define_api_method(
            "get_state_values_logits_probabilities_log_probs",
            self._graph_fn_get_state_values_logits_probabilities_log_probs)
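
For illustration, a short sketch of how this wrapper might be constructed (the Policy spec contents are assumptions, not from the original source):

    # Hypothetical sketch: wrap a Policy spec so that individual calls get batched dynamically.
    batching_policy = DynamicBatchingPolicy(
        policy_spec=dict(network_spec=[dict(type="dense", units=256)]),  # illustrative Policy spec
        minimum_batch_size=1,     # flush even a single request once timeout_ms has elapsed
        maximum_batch_size=512,   # cap how many requests are merged into one forward pass
        timeout_ms=50,            # wait at most 50 ms for the batch to fill up
    )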