def __init__(self, preprocessor_spec, policy_spec, exploration_spec=None, **kwargs):
    """
    Builds the actor: a preprocessor stack, a policy and an exploration component, plus two tuple
    helpers (merger/splitter), and registers all of them as sub-components.

    Args:
        preprocessor_spec (Union[list,dict,PreprocessorStack]):
            - A dict if the state from the Env will come in as a ContainerSpace (e.g. Dict). In this
              case, each key in this dict specifies which value in the incoming dict should go through
              which PreprocessorStack.
            - A list with layer specs.
            - A PreprocessorStack object.
        policy_spec (Union[dict,Policy]): A specification dict for a Policy object or a Policy object
            directly.
        exploration_spec (Union[dict,Exploration]): A specification dict for an Exploration object or
            an Exploration object directly.

    Keyword Args:
        scope (str): Scope name passed to the parent constructor. Default: "actor-component".
        **kwargs: Any remaining keyword args are forwarded to the parent constructor.
    """
    # Pop "scope" first so it is not forwarded twice via **kwargs.
    component_scope = kwargs.pop("scope", "actor-component")
    super(ActorComponent, self).__init__(scope=component_scope, **kwargs)

    # Sub-components are created from their specs in a fixed order (preprocessor, policy,
    # exploration); keep this order in case spec construction has order-dependent side effects.
    self.preprocessor = PreprocessorStack.from_spec(preprocessor_spec)
    self.policy = Policy.from_spec(policy_spec)
    # Input count of the policy's network; used below as the splitter's tuple length.
    self.num_nn_inputs = self.policy.neural_network.num_inputs
    self.exploration = Exploration.from_spec(exploration_spec)

    # Tuple helpers. The merger is configured with is_tuple/merge_tuples_into_one; the splitter
    # produces tuples of length `num_nn_inputs`.
    self.tuple_merger = ContainerMerger(is_tuple=True, merge_tuples_into_one=True)
    self.tuple_splitter = ContainerSplitter(tuple_length=self.num_nn_inputs)

    sub_components = (
        self.policy,
        self.exploration,
        self.preprocessor,
        self.tuple_merger,
        self.tuple_splitter,
    )
    self.add_components(*sub_components)
def __init__(self, policy_spec, minimum_batch_size=1, maximum_batch_size=1024, timeout_ms=100,
             scope="dynamic-batching-policy", **kwargs):
    """
    Wraps a Policy and stores dynamic-batching options for it.

    Args:
        policy_spec (Union[Policy,dict]): A spec dict to construct the Policy that is wrapped by this
            DynamicBatchingPolicy, or a Policy object directly.
        minimum_batch_size (int): The minimum batch size to use. Default: 1.
        maximum_batch_size (int): The maximum batch size to use. Default: 1024.
        timeout_ms (int): The timeout in ms to use when waiting for items on the queue. Default: 100ms.
        scope (str): Scope name passed to the parent constructor. Default: "dynamic-batching-policy".
        **kwargs: Forwarded to the parent constructor.
    """
    # The batched graph fn is declared with 5 outputs; this count must match what
    # `_graph_fn_get_state_values_logits_probabilities_log_probs` actually returns.
    num_outputs_per_graph_fn = dict(
        _graph_fn_get_state_values_logits_probabilities_log_probs=5
    )
    super(DynamicBatchingPolicy, self).__init__(
        graph_fn_num_outputs=num_outputs_per_graph_fn, scope=scope, **kwargs
    )

    # The wrapped, backend-specific policy object.
    wrapped_policy = Policy.from_spec(policy_spec)
    self.policy = wrapped_policy

    # HACK: expose the wrapped policy's action adapter and distribution directly, in case
    # parent components call APIs of the distribution on this object.
    self.action_adapter = wrapped_policy.action_adapter
    self.distribution = wrapped_policy.distribution
    self.deterministic = True

    # Dynamic batching options (assigned left to right).
    self.minimum_batch_size, self.maximum_batch_size, self.timeout_ms = (
        minimum_batch_size,
        maximum_batch_size,
        timeout_ms,
    )

    self.add_components(self.policy)

    # TODO: for now, only define this one API-method as this is the only one used in IMPALA.
    # TODO: Generalize this component so it can wrap arbitrary other components and simulate their API.
    api_name = "get_state_values_logits_probabilities_log_probs"
    self.define_api_method(
        api_name, self._graph_fn_get_state_values_logits_probabilities_log_probs
    )