Example #1
  def __init__(self,
               time_step_spec,
               action_spec,
               policy_state_spec=(),
               info_spec=(),
               clip=True,
               name=None):
    """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      name: A name for this module. Defaults to the class name.
    """
    super(Base, self).__init__(name=name)
    common.assert_members_are_not_overridden(base_cls=Base, instance=self)

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy_state_spec = policy_state_spec
    self._info_spec = info_spec
    self._setup_specs()
    self._clip = clip
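For context, a minimal sketch of how the user-provided specs above are typically built before being passed to a subclass (assumes tf_agents is installed; names and shapes are illustrative):

```python
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# A 4-dimensional float observation; time_step_spec() bundles it with the
# standard step_type/reward/discount specs into a `TimeStep` spec.
observation_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)

# A scalar binary action, bounded as the docstring requires.
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=1)
```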
Example #2
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 name=None):
        """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
        Usually provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Usually provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info.
        Provided by the subclass, not directly by the user.
      name: A name for this module. Defaults to the class name.
    """
        super(Base, self).__init__(name=name)
        common.assert_members_are_not_overridden(base_cls=Base, instance=self)

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._info_spec = info_spec
        self._setup_specs()
Example #3
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 clip=True,
                 emit_log_probability=False,
                 automatic_state_reset=True,
                 name=None):
        """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      emit_log_probability: Emit log-probabilities of actions, if supported. If
        True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
        Please consult utility methods provided in policy_step for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary info_spec or a namedtuple with the field 'log_probability'.
      automatic_state_reset:  If `True`, then `get_initial_policy_state` is used
        to clear state in `action()` and `distribution()` for time steps
        where `time_step.is_first()`.
      name: A name for this module. Defaults to the class name.
    """
        super(Base, self).__init__(name=name)
        common.assert_members_are_not_overridden(base_cls=Base, instance=self)

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._emit_log_probability = emit_log_probability
        if emit_log_probability:
            log_probability_spec = tensor_spec.BoundedTensorSpec(
                shape=(),
                dtype=tf.float32,
                maximum=0,
                minimum=-float('inf'),
                name='log_probability')
            log_probability_spec = tf.nest.map_structure(
                lambda _: log_probability_spec, action_spec)
            info_spec = policy_step.set_log_probability(
                info_spec, log_probability_spec)

        self._info_spec = info_spec
        self._setup_specs()
        self._clip = clip
        self._action_fn = common.function_in_tf1()(self._action)
        self._automatic_state_reset = automatic_state_reset
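The `emit_log_probability` branch above mirrors the action structure with one scalar log-probability spec per action leaf. A standalone sketch of that wiring, with `policy_step.set_log_probability` replaced by a plain dict for illustration:

```python
import tensorflow as tf
from tf_agents.specs import tensor_spec

# A nested action spec with two leaves (illustrative).
action_spec = {
    'steer': tensor_spec.BoundedTensorSpec((), tf.float32, -1.0, 1.0),
    'gear': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 5),
}
log_probability_spec = tensor_spec.BoundedTensorSpec(
    shape=(),
    dtype=tf.float32,
    maximum=0,
    minimum=-float('inf'),
    name='log_probability')

# One log-probability spec per action leaf, same nest structure as the
# actions themselves.
log_probability_specs = tf.nest.map_structure(
    lambda _: log_probability_spec, action_spec)
info_spec = {'log_probability': log_probability_specs}
```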
Example #4
  def __init__(self,
               time_step_spec,
               action_spec,
               policy,
               collect_policy,
               train_sequence_length,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
        the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
    """
    common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
    if not isinstance(time_step_spec, ts.TimeStep):
      raise ValueError(
          "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
          .format(type(time_step_spec)))

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    if train_step_counter is None:
      train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    self._train_step_counter = train_step_counter
    self._train_fn = common.function_in_tf1()(self._train)
    self._initialize_fn = common.function_in_tf1()(self._initialize)
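The `train_step_counter` fallback above resolves to the TF1-compatible global step; a two-line sketch of what one `train()` call does to it:

```python
import tensorflow as tf

train_step_counter = tf.compat.v1.train.get_or_create_global_step()
train_step_counter.assign_add(1)  # incremented once per train op run
```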
Example #5
    def __init__(self, handle_auto_reset: bool = False):
        """Base class for Python RL environments.

    Args:
      handle_auto_reset: When `True` the base class will handle auto_reset of
        the Environment.
    """
        self._handle_auto_reset = handle_auto_reset
        self._current_time_step = None
        common.assert_members_are_not_overridden(base_cls=PyEnvironment,
                                                 instance=self,
                                                 denylist=('reset', 'step'))
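A minimal sketch of a subclass opting into the auto-reset behavior described above (a hypothetical toy environment; specs and episode logic are illustrative):

```python
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

class CountdownEnv(py_environment.PyEnvironment):
  """Toy environment whose episodes end after three steps."""

  def __init__(self):
    super(CountdownEnv, self).__init__(handle_auto_reset=True)
    self._count = 0

  def action_spec(self):
    return array_spec.BoundedArraySpec((), np.int32, minimum=0, maximum=1)

  def observation_spec(self):
    return array_spec.ArraySpec((1,), np.float32)

  def _reset(self):
    self._count = 0
    return ts.restart(np.zeros((1,), np.float32))

  def _step(self, action):
    self._count += 1
    observation = np.array([self._count], np.float32)
    if self._count >= 3:
      return ts.termination(observation, reward=1.0)
    return ts.transition(observation, reward=0.0)
```

With `handle_auto_reset=True`, the base class takes care of resetting after a terminal time step, so callers can keep calling `step()` without checking for episode ends themselves.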
Example #6
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedArraySpec,
                 policy_state_spec: types.NestedArraySpec = (),
                 info_spec: types.NestedArraySpec = (),
                 observation_and_action_constraint_splitter: Optional[
                     types.Splitter] = None):
        """Initialization of PyPolicy class.

    Args:
      time_step_spec: A `TimeStep` ArraySpec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedArraySpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of ArraySpec representing the policy state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of ArraySpec representing the policy info. Provided by
        the subclass, not directly by the user.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment. The function takes in a full observation and returns a
        tuple consisting of 1) the part of the observation intended as input to
        the network and 2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:

        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
          sure the provided `q_network` is compatible with the network-specific
          half of the output of the
          `observation_and_action_constraint_splitter`. In particular,
          `observation_and_action_constraint_splitter` will be called on the
          observation before passing to the network. If
          `observation_and_action_constraint_splitter` is None, action
          constraints are not applied.
    """
        common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
        common.assert_members_are_not_overridden(base_cls=PyPolicy,
                                                 instance=self)
        self._time_step_spec = tensor_spec.to_array_spec(time_step_spec)
        self._action_spec = tensor_spec.to_array_spec(action_spec)
        # TODO(kbanoop): rename policy_state to state.
        self._policy_state_spec = tensor_spec.to_array_spec(policy_state_spec)
        self._info_spec = tensor_spec.to_array_spec(info_spec)
        self._setup_specs()
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
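The splitter example from the docstring, runnable in isolation (the key names are the docstring's own):

```python
def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint']

observation = {'network_input': [0.1, -0.4], 'constraint': [1, 0, 1]}
network_input, action_mask = observation_and_action_constraint_splitter(
    observation)
```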
Example #7
  def __init__(self,
               time_step_spec,
               action_spec,
               policy,
               collect_policy,
               train_sequence_length,
               debug_summaries=False,
               summarize_grads_and_vars=False):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
        the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
    """
    common.assert_members_are_not_overridden(
        base_cls=BaseV2,
        instance=self,
        allowed_overrides=set(["_initialize", "_train"]))

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
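The `[B, T, ...]` convention from `train_sequence_length`, in concrete terms: for a non-RNN DQN agent (`train_sequence_length=2`), every tensor in `experience` carries a batch dimension and a fixed two-step time dimension. A sketch with illustrative shapes:

```python
import tensorflow as tf

batch_size = 4             # B
train_sequence_length = 2  # T, fixed to 2 for a non-RNN DQN agent
observations = tf.zeros([batch_size, train_sequence_length, 84, 84, 4])
rewards = tf.zeros([batch_size, train_sequence_length])
```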
Example #8
  def __init__(self, time_step_spec, action_spec, policy_state_spec=(),
               info_spec=()):
    """Initialization of Base class.

    Args:
      time_step_spec: A `TimeStep` ArraySpec of the expected time_steps.
        Usually provided by the user to the subclass.
      action_spec: A nest of BoundedArraySpec representing the actions.
        Usually provided by the user to the subclass.
      policy_state_spec: A nest of ArraySpec representing the policy state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of ArraySpec representing the policy info.
        Provided by the subclass, not directly by the user.
    """
    common.assert_members_are_not_overridden(base_cls=Base, instance=self)
    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    # TODO(kbanoop): rename policy_state to state.
    self._policy_state_spec = policy_state_spec
    self._info_spec = info_spec
    self._setup_specs()
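The array-spec counterparts used by these Python-side classes, for comparison with the tensor specs of the earlier examples (a sketch; assumes the installed tf_agents version lets `time_step_spec()` build from array specs):

```python
import numpy as np
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

observation_spec = array_spec.ArraySpec((4,), np.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = array_spec.BoundedArraySpec((), np.int32, minimum=0, maximum=1)
```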
Example #9
  def __init__(self):
    self._current_time_step = None
    common.assert_members_are_not_overridden(
        base_cls=PyEnvironment, instance=self, black_list=('reset', 'step'))
Example #10
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy_state_spec=(),
                 info_spec=(),
                 clip=True,
                 emit_log_probability=False,
                 automatic_state_reset=True,
                 observation_and_action_constraint_splitter=None,
                 validate_args=True,
                 name=None):
        """Initialization of TFPolicy class.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
        provided by the user to the subclass.
      action_spec: A nest of BoundedTensorSpec representing the actions. Usually
        provided by the user to the subclass.
      policy_state_spec: A nest of TensorSpec representing the policy_state.
        Provided by the subclass, not directly by the user.
      info_spec: A nest of TensorSpec representing the policy info. Provided by
        the subclass, not directly by the user.
      clip: Whether to clip actions to spec before returning them.  Default
        True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
        continuous actions for training.
      emit_log_probability: Emit log-probabilities of actions, if supported. If
        True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
        Please consult utility methods provided in policy_step for setting and
        retrieving these. When working with custom policies, either provide a
        dictionary info_spec or a namedtuple with the field 'log_probability'.
      automatic_state_reset:  If `True`, then `get_initial_policy_state` is used
        to clear state in `action()` and `distribution()` for time steps
        where `time_step.is_first()`.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment. The function takes in a full observation and returns a
        tuple consisting of 1) the part of the observation intended as input to
        the network and 2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:

        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
          sure the provided `q_network` is compatible with the network-specific
          half of the output of the
          `observation_and_action_constraint_splitter`. In particular,
          `observation_and_action_constraint_splitter` will be called on the
          observation before passing to the network. If
          `observation_and_action_constraint_splitter` is None, action
          constraints are not applied.
      validate_args: Python bool.  Whether to verify inputs to, and outputs of,
        functions like `action` and `distribution` against spec structures,
        dtypes, and shapes.

        Research code may prefer to set this value to `False` to allow iterating
        on input and output structures without being hamstrung by overly
        rigid checking (at the cost of harder-to-debug errors).

        See also `TFAgent.validate_args`.
      name: A name for this module. Defaults to the class name.
    """
        super(TFPolicy, self).__init__(name=name)
        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
        common.assert_members_are_not_overridden(base_cls=TFPolicy,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                'The `time_step_spec` must be an instance of `TimeStep`, but is `{}`.'
                .format(type(time_step_spec)))

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy_state_spec = policy_state_spec
        self._emit_log_probability = emit_log_probability
        self._validate_args = validate_args

        if emit_log_probability:
            log_probability_spec = tensor_spec.BoundedTensorSpec(
                shape=(),
                dtype=tf.float32,
                maximum=0,
                minimum=-float('inf'),
                name='log_probability')
            log_probability_spec = tf.nest.map_structure(
                lambda _: log_probability_spec, action_spec)
            info_spec = policy_step.set_log_probability(
                info_spec, log_probability_spec)

        self._info_spec = info_spec
        self._setup_specs()
        self._clip = clip
        self._action_fn = common.function_in_tf1()(self._action)
        self._automatic_state_reset = automatic_state_reset
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
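What `clip=True` amounts to conceptually: before actions are returned, continuous values are clipped to the bounds of the action spec (a plain-TF sketch; the real clipping happens inside the policy's `action()` path):

```python
import tensorflow as tf
from tf_agents.specs import tensor_spec

action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.float32, minimum=-1.0, maximum=1.0)
raw_action = tf.constant(1.7)
clipped_action = tf.clip_by_value(
    raw_action, action_spec.minimum, action_spec.maximum)  # -> 1.0
```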
Example #11
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy,
                 collect_policy,
                 train_sequence_length,
                 num_outer_dims=2,
                 training_data_spec=None,
                 train_argspec=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 enable_summaries=True,
                 train_step_counter=None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      training_data_spec: A nest of TensorSpec specifying the structure of data
        the train() function expects. If None, defaults to the trajectory_spec
        of the collect_policy.
      train_argspec: (Optional) Describes additional supported arguments
        to the `train` call.  This must be a `dict` mapping strings to nests
        of specs.  Overriding the `experience` arg is also supported.

        Some algorithms require additional arguments to the `train()` call, and
        while TF-Agents encourages most of these to be provided in the
        `policy_info` / `info` field of `experience`, sometimes the extra
        information doesn't fit well, e.g., when it doesn't come from the
        policy.

        **NOTE** kwargs will not have their outer dimensions validated.
        In particular, `train_sequence_length` is ignored for these inputs,
        and they may have any, or inconsistent, batch/time dimensions; only
        their inner shape dimensions are checked against `train_argspec`.

        Below is an example:

        ```python
        class MyAgent(TFAgent):
          def __init__(self, counterfactual_training, ...):
             collect_policy = ...
             train_argspec = None
             if counterfactual_training:
               train_argspec = dict(
                  counterfactual=collect_policy.trajectory_spec)
             super(...).__init__(
               ...
               train_argspec=train_argspec)

        my_agent = MyAgent(...)

        for ...:
          experience, counterfactual = next(experience_and_counterfactual_iter)
          loss_info = my_agent.train(experience, counterfactual=counterfactual)
        ```
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      TypeError: If `train_argspec` is not a `dict`.
      ValueError: If `train_argspec` has the keys `experience` or `weights`.
      TypeError: If any leaf nodes in `train_argspec` values are not
        subclasses of `tf.TypeSpec`.
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """
        def _each_isinstance(spec, spec_types):
            """Checks if each element of `spec` is instance of any of `spec_types`."""
            return all(
                [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

        if not _each_isinstance(time_step_spec, tf.TypeSpec):
            raise TypeError(
                "time_step_spec has to contain TypeSpec (TensorSpec, "
                "SparseTensorSpec, etc) objects, but received: {}".format(
                    time_step_spec))

        if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
            raise TypeError(
                "action_spec has to contain BoundedTensorSpec objects, but received: "
                "{}".format(action_spec))

        common.check_tf1_allowed()
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        self._training_data_spec = training_data_spec
        if train_argspec is None:
            train_argspec = {}
        else:
            if not isinstance(train_argspec, dict):
                raise TypeError(
                    "train_argspec must be a dict, but saw: {}".format(
                        train_argspec))
            train_argspec = dict(train_argspec)  # Create a local copy.
            if "weights" in train_argspec or "experience" in train_argspec:
                raise ValueError(
                    "train_argspec must not override 'weights' or "
                    "'experience' keys, but saw: {}".format(train_argspec))
            if not all(
                    isinstance(x, tf.TypeSpec)
                    for x in tf.nest.flatten(train_argspec)):
                raise TypeError(
                    "train_argspec contains non-TensorSpec objects: {}".format(
                        train_argspec))
        self._train_argspec = train_argspec
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
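The nested-spec check defined above, extracted so it can be tried standalone:

```python
import tensorflow as tf

def _each_isinstance(spec, spec_types):
  """Checks if each leaf of the `spec` nest is an instance of `spec_types`."""
  return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

specs = {'obs': tf.TensorSpec((4,), tf.float32),
         'mask': tf.TensorSpec((3,), tf.int32)}
assert _each_isinstance(specs, tf.TypeSpec)
```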
Example #12
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 policy,
                 collect_policy,
                 train_sequence_length,
                 num_outer_dims=2,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 enable_summaries=True,
                 train_step_counter=None):
        """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
        the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """
        common.tf_agents_gauge.get_cell("TFAgent").set(True)
        common.assert_members_are_not_overridden(base_cls=TFAgent,
                                                 instance=self)
        if not isinstance(time_step_spec, ts.TimeStep):
            raise ValueError(
                "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
                .format(type(time_step_spec)))

        if num_outer_dims not in [1, 2]:
            raise ValueError("num_outer_dims must be in [1, 2].")

        self._time_step_spec = time_step_spec
        self._action_spec = action_spec
        self._policy = policy
        self._collect_policy = collect_policy
        self._train_sequence_length = train_sequence_length
        self._num_outer_dims = num_outer_dims
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._enable_summaries = enable_summaries
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._train_step_counter = train_step_counter
        self._train_fn = common.function_in_tf1()(self._train)
        self._initialize_fn = common.function_in_tf1()(self._initialize)
Example #13
  def __init__(self, white_list=(), black_list=()):
    common.assert_members_are_not_overridden(
        base_cls=Base, instance=self, white_list=white_list,
        black_list=black_list)
Example #14
  def __init__(self):
    common.assert_members_are_not_overridden(
        base_cls=Base, instance=self, black_list=('reset', 'step'))
Example #15
  def __init__(self, allowlist=(), denylist=()):
    common.assert_members_are_not_overridden(
        base_cls=Base, instance=self, allowlist=allowlist, denylist=denylist)
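Finally, a sketch of what `assert_members_are_not_overridden` guards against throughout these constructors: a subclass silently shadowing a public member of the base class (illustrative classes; only members named in the allow/deny lists are exempt):

```python
from tf_agents.utils import common

class Guarded(object):

  def __init__(self):
    common.assert_members_are_not_overridden(base_cls=Guarded, instance=self)

  def step(self):
    return 'base behavior'

class Shadowing(Guarded):

  def step(self):  # shadows a guarded public member
    return 'subclass behavior'

Guarded()     # fine
Shadowing()   # expected to raise: `step` was overridden without an exemption
```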