def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
  """Initializes a `Network`.

  Args:
    input_tensor_spec: (Optional.) Nest of `tf.TypeSpec` describing the input
      observations. When omitted, `create_variables()` will fail unless a
      spec is provided to it.
    state_spec: Nest of `tensor_spec.TensorSpec` describing the state the
      network carries. Defaults to `()`, meaning the network is stateless.
    name: (Optional.) Name of the network.
  """
  # Keras autocasting may turn bfloat16 values into another dtype, which would
  # break the spec checks done by this class, so autocast is switched off.
  super(Network, self).__init__(name=name, autocast=False)
  common.check_tf1_allowed()
  # `summary()` needs this attribute to exist.
  self._is_graph_network = False
  if input_tensor_spec is None:
    self._input_tensor_spec = None
  else:
    self._input_tensor_spec = tensor_spec.from_spec(input_tensor_spec)
  # NOTE(ebrevdo): `output_tensor_spec` would have been the preferred name,
  # but keras.Layer already reserves it.
  self._network_output_spec = None
  self._state_spec = tensor_spec.from_spec(state_spec)
def __init__(
    self,
    env: tf_environment.TFEnvironment,
    policy: tf_policy.TFPolicy,
    observers: Sequence[Callable[[trajectory.Trajectory], Any]],
    transition_observers: Optional[Sequence[Callable[[trajectory.Transition],
                                                     Any]]] = None,
    max_steps: Optional[types.Int] = None,
    max_episodes: Optional[types.Int] = None,
    disable_tf_function: bool = False):
  """A driver that runs a TF policy in a TF environment.

  **Note** about bias when using batched environments with `max_episodes`:
  When using `max_episodes != None`, a `run` step "finishes" when
  `max_episodes` have been completely collected (hit a boundary).
  When used in conjunction with environments that have variable-length
  episodes, this skews the distribution of collected episodes' lengths:
  short episodes are seen more frequently than long ones.
  As a result, running an `env` of `N > 1` batched environments
  with `max_episodes >= 1` is not the same as running an env with `1`
  environment with `max_episodes >= 1`.

  Args:
    env: A tf_environment.Base environment.
    policy: A tf_policy.TFPolicy policy.
    observers: A list of observers that are notified after every step
      in the environment. Each observer is a callable(trajectory.Trajectory).
    transition_observers: A list of observers that are updated after every
      step in the environment. Each observer is a callable((TimeStep,
      PolicyStep, NextTimeStep)). The transition is shaped just as
      trajectories are for regular observers.
    max_steps: Optional maximum number of steps for each run() call. For
      batched or parallel environments, this is the maximum total number of
      steps summed across all environments. `None` or any value below 1
      means no step limit. Default: None.
    max_episodes: Optional maximum number of episodes for each run() call.
      For batched or parallel environments, this is the maximum total number
      of episodes summed across all environments. `None` or any value below
      1 means no episode limit. At least one of max_steps or max_episodes
      must be at least 1. If both are set, run() terminates when at least
      one of the conditions is satisfied. Default: None.
    disable_tf_function: If True the use of tf.function for the run method is
      disabled.

  Raises:
    ValueError: If neither `max_steps` nor `max_episodes` is at least 1 (for
      example when both are None or 0).
  """
  common.check_tf1_allowed()
  max_steps = max_steps or 0
  max_episodes = max_episodes or 0
  if max_steps < 1 and max_episodes < 1:
    raise ValueError(
        'Either `max_steps` or `max_episodes` should be greater than 0.')

  super(TFDriver, self).__init__(env, policy, observers, transition_observers)

  # The check above treats every value below 1 as "no limit", so map all of
  # them (not only 0) to infinity. With `x or np.inf` a negative value was
  # stored as the limit itself.
  self._max_steps = max_steps if max_steps >= 1 else np.inf
  self._max_episodes = max_episodes if max_episodes >= 1 else np.inf
  if not disable_tf_function:
    self.run = common.function(self.run, autograph=True)
def __init__(self,
             env,
             agent,
             adversary_agent,
             adversary_env,
             env_metrics=None,
             collect=True,
             disable_tf_function=False,
             debug=False,
             combined_population=False,
             flexible_protagonist=False):
  """Sets up a driver that runs the environment adversary and the agents.

  Args:
    env: A tf_environment.Base environment.
    agent: An AgentTrainPackage for the main learner agent.
    adversary_agent: An AgentTrainPackage for the second agent, the
      adversary's ally. May be None when using an unconstrained adversary
      environment.
    adversary_env: An AgentTrainPackage for the agent that controls the
      environment. It learns to set environment parameters that decrease the
      agent's score relative to the adversary_agent. May be None when using
      domain randomization.
    env_metrics: Global environment metrics to track (such as path length).
    collect: True when collecting episodes for training; False for eval.
    disable_tf_function: If True, `run` and `run_agent` are not wrapped in
      tf.function.
    debug: If True, emit informative logging statements.
    combined_population: If True, the whole population of protagonists plays
      every generated environment, and regret is calculated as the difference
      between the population's max and its average (there are no explicit
      antagonists).
    flexible_protagonist: If True, which agent acts as the protagonist when
      computing regret depends on which one has the lowest score.
  """
  common.check_tf1_allowed()

  self.debug = debug
  self.total_episodes_collected = 0

  if not disable_tf_function:
    self.run = common.function(self.run, autograph=True)
    self.run_agent = common.function(self.run_agent, autograph=True)

  # Environment and bookkeeping configuration.
  self.env = env
  self.env_metrics = env_metrics
  self.collect = collect

  # Participants of the adversarial game.
  self.agent = agent
  self.adversary_agent = adversary_agent
  self.adversary_env = adversary_env

  # Regret computation options.
  self.combined_population = combined_population
  self.flexible_protagonist = flexible_protagonist
def __init__(self,
             env: tf_environment.TFEnvironment,
             policy: tf_policy.TFPolicy,
             observers: Sequence[Callable[[trajectory.Trajectory], Any]],
             transition_observers: Optional[Sequence[Callable[
                 [trajectory.Transition], Any]]] = None,
             max_steps: Optional[types.Int] = None,
             max_episodes: Optional[types.Int] = None,
             disable_tf_function: bool = False):
  """A driver that runs a TF policy in a TF environment.

  Args:
    env: A tf_environment.Base environment.
    policy: A tf_policy.TFPolicy policy.
    observers: A list of observers that are notified after every step
      in the environment. Each observer is a callable(trajectory.Trajectory).
    transition_observers: A list of observers that are updated after every
      step in the environment. Each observer is a callable((TimeStep,
      PolicyStep, NextTimeStep)). The transition is shaped just as
      trajectories are for regular observers.
    max_steps: Optional maximum number of steps for each run() call. For
      batched or parallel environments, this is the maximum total number of
      steps summed across all environments. `None` or any value below 1
      means no step limit. Default: None.
    max_episodes: Optional maximum number of episodes for each run() call.
      For batched or parallel environments, this is the maximum total number
      of episodes summed across all environments. `None` or any value below
      1 means no episode limit. At least one of max_steps or max_episodes
      must be at least 1. If both are set, run() terminates when at least
      one of the conditions is satisfied. Default: None.
    disable_tf_function: If True the use of tf.function for the run method is
      disabled.

  Raises:
    ValueError: If neither `max_steps` nor `max_episodes` is at least 1 (for
      example when both are None or 0).
  """
  common.check_tf1_allowed()
  max_steps = max_steps or 0
  max_episodes = max_episodes or 0
  if max_steps < 1 and max_episodes < 1:
    raise ValueError(
        'Either `max_steps` or `max_episodes` should be greater than 0.'
    )

  super(TFDriver, self).__init__(env, policy, observers, transition_observers)

  # The check above treats every value below 1 as "no limit", so map all of
  # them (not only 0) to infinity. With `x or np.inf` a negative value was
  # stored as the limit itself.
  self._max_steps = max_steps if max_steps >= 1 else np.inf
  self._max_episodes = max_episodes if max_episodes >= 1 else np.inf
  if not disable_tf_function:
    self.run = common.function(self.run, autograph=True)
def __init__(self, data_spec, capacity, stateful_dataset=False):
  """Sets up the replay buffer.

  Args:
    data_spec: A spec, or a list/tuple/nest of specs, describing one item that
      can be stored in this buffer.
    capacity: Number of elements the replay buffer can hold.
    stateful_dataset: Whether the dataset contains stateful ops.
  """
  super(ReplayBuffer, self).__init__()
  common.check_tf1_allowed()
  # Assigned left to right, in the same order as the individual statements.
  self._data_spec, self._capacity, self._stateful_dataset = (
      data_spec, capacity, stateful_dataset)
def __init__(self, input_tensor_spec, state_spec, name):
  """Builds the `Network` base object.

  Args:
    input_tensor_spec: Nest of `tensor_spec.TensorSpec` describing the input
      observations.
    state_spec: Nest of `tensor_spec.TensorSpec` describing the state the
      network needs. Pass `()` when there is no state.
    name: Name of the network.
  """
  super(Network, self).__init__(name=name)
  common.check_tf1_allowed()
  self._input_tensor_spec, self._state_spec = input_tensor_spec, state_spec
def __init__(self, input_tensor_spec, state_spec, name=None):
  """Initializes the `Network` base.

  Args:
    input_tensor_spec: Nest of `tensor_spec.TensorSpec` describing the input
      observations.
    state_spec: Nest of `tensor_spec.TensorSpec` describing the state the
      network needs; pass `()` when there is none.
    name: (Optional.) Name of the network.
  """
  super(Network, self).__init__(name=name)
  common.check_tf1_allowed()
  self._is_graph_network = False  # `summary()` depends on this attribute.
  self._input_tensor_spec = input_tensor_spec
  self._state_spec = state_spec
def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
  """Constructs a `Network`.

  Args:
    input_tensor_spec: (Optional.) Nest of `tensor_spec.TensorSpec`
      describing the input observations. If omitted, `create_variables()`
      will fail unless a spec is provided to it.
    state_spec: Nest of `tensor_spec.TensorSpec` describing the network's
      state. Defaults to `()`, i.e. no state.
    name: (Optional.) Name of the network.
  """
  super(Network, self).__init__(name=name)
  common.check_tf1_allowed()

  # `summary()` requires this flag.
  self._is_graph_network = False

  # Specs are stored as given; no conversion happens here.
  self._input_tensor_spec = input_tensor_spec
  self._state_spec = state_spec
def __init__(self, input_tensor_spec=None, state_spec=(), name=None):
  """Creates the `Network` base instance.

  Args:
    input_tensor_spec: (Optional.) Nest of `tensor_spec.TensorSpec`
      describing the input observations. Without it, `create_variables()`
      will fail unless a spec is provided to it.
    state_spec: Nest of `tensor_spec.TensorSpec` describing the state the
      network needs. The default `()` means the network keeps no state.
    name: (Optional.) Name of the network.
  """
  super(Network, self).__init__(name=name)
  common.check_tf1_allowed()
  self._is_graph_network = False  # Needed for `summary()`.
  self._input_tensor_spec = input_tensor_spec
  # NOTE(ebrevdo): `output_tensor_spec` would have been the preferred name,
  # but keras.Layer already reserves it.
  self._network_output_spec = None
  self._state_spec = state_spec
def __init__(self,
             time_step_spec,
             action_spec,
             policy_state_spec=(),
             info_spec=(),
             clip=True,
             emit_log_probability=False,
             automatic_state_reset=True,
             observation_and_action_constraint_splitter=None,
             validate_args=True,
             name=None):
  """Initializes the TFPolicy base class.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
      supplied by the user to the subclass.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Usually supplied by the user to the subclass.
    policy_state_spec: A nest of TensorSpec representing the policy_state.
      Supplied by the subclass, not directly by the user.
    info_spec: A nest of TensorSpec representing the policy info. Supplied by
      the subclass, not directly by the user.
    clip: Whether to clip actions to spec before returning them. Default
      True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
      continuous actions for training.
    emit_log_probability: Emit log-probabilities of actions, if supported. If
      True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
      See the utility methods in policy_step for setting and retrieving
      these. For custom policies, provide either a dictionary info_spec or a
      namedtuple with the field 'log_probability'.
    automatic_state_reset: If `True`, `get_initial_policy_state` is used to
      clear state in `action()` and `distribution()` for time steps where
      `time_step.is_first()`.
    observation_and_action_constraint_splitter: A function used to process
      observations with action constraints. The constraints can indicate,
      for example, a mask of valid/invalid actions for a given state of the
      environment. The function takes a full observation and returns a tuple
      of 1) the part of the observation intended as input to the network and
      2) the constraint. An example
      `observation_and_action_constraint_splitter` could be as simple as:
      ```
      def observation_and_action_constraint_splitter(observation):
        return observation['network_input'], observation['constraint']
      ```
      *Note*: when using `observation_and_action_constraint_splitter`, make
      sure the provided `q_network` is compatible with the network-specific
      half of the splitter's output. In particular, the splitter is called on
      the observation before it is passed to the network. If
      `observation_and_action_constraint_splitter` is None, action
      constraints are not applied.
    validate_args: Python bool. Whether to verify inputs to, and outputs of,
      functions like `action` and `distribution` against spec structures,
      dtypes, and shapes.

      Research code may prefer to set this to `False` to iterate on input
      and output structures without overly rigid checking (at the cost of
      harder-to-debug errors).

      See also `TFAgent.validate_args`.
    name: A name for this module. Defaults to the class name.

  Raises:
    ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
  """
  super(TFPolicy, self).__init__(name=name)
  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
  common.assert_members_are_not_overridden(base_cls=TFPolicy, instance=self)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        'The `time_step_spec` must be an instance of `TimeStep`, but is `{}`.'
        .format(type(time_step_spec)))

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy_state_spec = policy_state_spec
  self._emit_log_probability = emit_log_probability
  self._validate_args = validate_args

  if emit_log_probability:
    # A scalar, non-positive float32 spec, repeated for every action leaf.
    scalar_log_prob_spec = tensor_spec.BoundedTensorSpec(
        shape=(),
        dtype=tf.float32,
        maximum=0,
        minimum=-float('inf'),
        name='log_probability')
    per_action_log_prob_spec = tf.nest.map_structure(
        lambda _: scalar_log_prob_spec, action_spec)
    info_spec = policy_step.set_log_probability(info_spec,
                                                per_action_log_prob_spec)
  self._info_spec = info_spec

  # `_setup_specs` runs after `_info_spec` is set and before the remaining
  # attributes, matching the original ordering.
  self._setup_specs()
  self._clip = clip
  self._action_fn = common.function_in_tf1()(self._action)
  self._automatic_state_reset = automatic_state_reset
  self._observation_and_action_constraint_splitter = (
      observation_and_action_constraint_splitter)
def __init__(self,
             time_step_spec,
             action_spec,
             policy,
             collect_policy,
             train_sequence_length,
             num_outer_dims=2,
             training_data_spec=None,
             train_argspec=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             enable_summaries=True,
             train_step_counter=None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.Base` representing the Agent's current
      policy.
    collect_policy: An instance of `tf_policy.Base` representing the Agent's
      current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]` but
      for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle `T`
      unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of data
      the train() function expects. If None, defaults to the trajectory_spec
      of the collect_policy.
    train_argspec: (Optional) Describes additional supported arguments to the
      `train` call. This must be a `dict` mapping strings to nests of specs.
      The keys `experience` and `weights` are reserved and may not be used
      here.

      Some algorithms require additional arguments to the `train()` call,
      and while TF-Agents encourages most of these to be provided in the
      `policy_info` / `info` field of `experience`, sometimes the extra
      information doesn't fit well, i.e., when it doesn't come from the
      policy.

      **NOTE** kwargs will not have their outer dimensions validated.
      In particular, `train_sequence_length` is ignored for these inputs,
      and they may have any, or inconsistent, batch/time dimensions; only
      their inner shape dimensions are checked against `train_argspec`.

      Below is an example:

      ```python
      class MyAgent(TFAgent):
        def __init__(self, counterfactual_training, ...):
           collect_policy = ...
           train_argspec = None
           if counterfactual_training:
             train_argspec = dict(
                counterfactual=collect_policy.trajectory_spec)
           super(...).__init__(
             ...
             train_argspec=train_argspec)

      my_agent = MyAgent(...)

      for ...:
        experience, counterfactual = next(experience_and_counterfactual_iter)
        loss_info = my_agent.train(experience, counterfactual=counterfactual)
      ```
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.

  Raises:
    TypeError: If any leaf of `time_step_spec` is not a `tf.TypeSpec`, or any
      leaf of `action_spec` is not a `tensor_spec.BoundedTensorSpec`.
    ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in [1, 2].
    TypeError: If `train_argspec` is not a `dict`.
    ValueError: If `train_argspec` has the keys `experience` or `weights`.
    TypeError: If any leaf nodes in `train_argspec` values are not subclasses
      of `tf.TypeSpec`.
  """

  def _each_isinstance(spec, spec_types):
    """Returns True if every leaf of `spec` is an instance of `spec_types`."""
    # Generator (not a list) so `all` can stop at the first mismatch.
    return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

  if not _each_isinstance(time_step_spec, tf.TypeSpec):
    raise TypeError(
        "time_step_spec has to contain TypeSpec (TensorSpec, "
        "SparseTensorSpec, etc) objects, but received: {}".format(
            time_step_spec))

  if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
    raise TypeError(
        "action_spec has to contain BoundedTensorSpec objects, but received: "
        "{}".format(action_spec))

  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
        .format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec

  if train_argspec is None:
    train_argspec = {}
  else:
    if not isinstance(train_argspec, dict):
      raise TypeError(
          "train_argspec must be a dict, but saw: {}".format(train_argspec))
    train_argspec = dict(train_argspec)  # Create a local copy.

  # "experience" and "weights" are the reserved positional/keyword names of
  # `train()`, so extra argspecs may not shadow them.
  if "weights" in train_argspec or "experience" in train_argspec:
    raise ValueError(
        "train_argspec must not override 'weights' or "
        "'experience' keys, but saw: {}".format(train_argspec))
  if not all(
      isinstance(x, tf.TypeSpec) for x in tf.nest.flatten(train_argspec)):
    raise TypeError(
        "train_argspec contains non-TensorSpec objects: {}".format(
            train_argspec))
  self._train_argspec = train_argspec

  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
def __init__(self, name, prefix='Metrics'):
  """Initializes the step metric.

  Args:
    name: Name of the metric; passed on to the base class constructor.
    prefix: Prefix stored on the metric (presumably used when naming its
      summaries; confirm against callers). Defaults to 'Metrics'.
  """
  super(TFStepMetric, self).__init__(name)
  common.check_tf1_allowed()
  self._prefix = prefix
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None):
  """Base-class constructor; subclasses are expected to call it.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`: the number of time
      steps required from tensors in `experience` as passed to `train()`.
      All tensors in `experience` are shaped `[B, T, ...]`, but some agents
      need a fixed `T`. For example, DQN requires transitions of 2 time
      steps, so a non-RNN DQN Agent sets this to 2. Agents that don't care,
      or that can handle `T` unknown at graph build time (i.e. most RNN-based
      agents), set this to `None`.
    num_outer_dims: The number of outer dimensions for the agent; must be 1
      or 2. With 2, training requires both a batch_size and a time dimension
      on every Tensor; with 1, only a batch_size outer dimension.
    training_data_spec: A nest of TensorSpec describing the structure of data
      that train() expects. If None, defaults to the trajectory_spec of the
      collect_policy.
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter incremented every time the train
      op is run. Defaults to the global_step.

  Raises:
    TypeError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """
  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.tf_agents_gauge.get_cell(str(type(self))).set(True)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
        .format(type(time_step_spec)))
  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  # Normalize both specs through `tensor_spec.from_spec` before storing them.
  time_step_spec = tensor_spec.from_spec(time_step_spec)
  action_spec = tensor_spec.from_spec(action_spec)

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  # Stored as passed in; only the local copy below is run through from_spec.
  self._training_data_spec = training_data_spec

  # Context describing data produced directly by the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      info_spec=collect_policy.info_spec)

  # Context describing data handed to train(); it differs from the collect
  # context only when a training_data_spec is supplied.
  if training_data_spec is None:
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)
  else:
    training_data_spec = tensor_spec.from_spec(training_data_spec)

    def _field_or(field_name, fallback):
      # training_data_spec may be any structure: take the named field when it
      # exists, otherwise fall back to the collect-side spec.
      return getattr(training_data_spec, field_name, fallback)

    self._data_context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(
            discount=_field_or("discount", time_step_spec.discount),
            observation=_field_or("observation", time_step_spec.observation),
            reward=_field_or("reward", time_step_spec.reward),
            step_type=_field_or("step_type", time_step_spec.step_type)),
        info_spec=_field_or("policy_info", collect_policy.info_spec),
        action_spec=_field_or("action", action_spec))

  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             train_argspec: Optional[Dict[Text, types.NestedTensorSpec]] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None,
             validate_args: bool = True):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]` but
      for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle `T`
      unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of data
      the train() function expects. If None, defaults to the trajectory_spec
      of the collect_policy.
    train_argspec: (Optional) Describes additional supported arguments to the
      `train` call. This must be a `dict` mapping strings to nests of specs.
      The keys `experience` and `weights` are reserved; when `validate_args`
      is True, using them raises a ValueError.

      Some algorithms require additional arguments to the `train()` call,
      and while TF-Agents encourages most of these to be provided in the
      `policy_info` / `info` field of `experience`, sometimes the extra
      information doesn't fit well, i.e., when it doesn't come from the
      policy.

      **NOTE** kwargs will not have their outer dimensions validated.
      In particular, `train_sequence_length` is ignored for these inputs,
      and they may have any, or inconsistent, batch/time dimensions; only
      their inner shape dimensions are checked against `train_argspec`.

      Below is an example:

      ```python
      class MyAgent(TFAgent):
        def __init__(self, counterfactual_training, ...):
           collect_policy = ...
           train_argspec = None
           if counterfactual_training:
             train_argspec = dict(
                counterfactual=collect_policy.trajectory_spec)
           super(...).__init__(
             ...
             train_argspec=train_argspec)

      my_agent = MyAgent(...)

      for ...:
        experience, counterfactual = next(experience_and_counterfactual_iter)
        loss_info = my_agent.train(experience, counterfactual=counterfactual)
      ```
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.
    validate_args: Python bool. Whether to verify inputs to, and outputs of,
      functions like `train` and `preprocess_sequence` against spec
      structures, dtypes, and shapes.

      Research code may prefer to set this value to `False` to allow
      iterating on input and output structures without being hamstrung by
      overly rigid checking (at the cost of harder-to-debug errors).

      See also `TFPolicy.validate_args`.

  Raises:
    TypeError: If `validate_args is True` and any leaf of `time_step_spec` is
      not a `tf.TypeSpec`, or any leaf of `action_spec` is not a
      `tensor_spec.BoundedTensorSpec`.
    TypeError: If `time_step_spec` is not an instance of `ts.TimeStep`
      (checked regardless of `validate_args`).
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
    TypeError: If `validate_args is True` and `train_argspec` is not a `dict`.
    ValueError: If `validate_args is True` and `train_argspec` has the keys
      `experience` or `weights`.
    TypeError: If `validate_args is True` and any leaf nodes in
      `train_argspec` values are not subclasses of `tf.TypeSpec`.
  """
  if validate_args:

    def _each_isinstance(spec, spec_types):
      """Returns True if every leaf of `spec` is an instance of `spec_types`."""
      # Generator (not a list) so `all` can stop at the first mismatch.
      return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

    if not _each_isinstance(time_step_spec, tf.TypeSpec):
      raise TypeError(
          "time_step_spec has to contain TypeSpec (TensorSpec, "
          "SparseTensorSpec, etc) objects, but received: {}".format(
              time_step_spec))

    if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      raise TypeError(
          "action_spec has to contain BoundedTensorSpec objects, but received: "
          "{}".format(action_spec))

  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.tf_agents_gauge.get_cell(str(type(self))).set(True)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
        .format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec
  self._validate_args = validate_args

  # Data context for data collected directly from the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      info_spec=collect_policy.info_spec)

  # Data context for data passed to train(). May be different if
  # training_data_spec is provided.
  if training_data_spec is not None:
    # training_data_spec can be anything; so build a data_context
    # via best-effort with fall-backs to the collect data spec.
    training_discount_spec = getattr(training_data_spec, "discount",
                                     time_step_spec.discount)
    training_observation_spec = getattr(training_data_spec, "observation",
                                        time_step_spec.observation)
    training_reward_spec = getattr(training_data_spec, "reward",
                                   time_step_spec.reward)
    training_step_type_spec = getattr(training_data_spec, "step_type",
                                      time_step_spec.step_type)
    training_policy_info_spec = getattr(training_data_spec, "policy_info",
                                        collect_policy.info_spec)
    training_action_spec = getattr(training_data_spec, "action", action_spec)
    self._data_context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(
            discount=training_discount_spec,
            observation=training_observation_spec,
            reward=training_reward_spec,
            step_type=training_step_type_spec),
        action_spec=training_action_spec,
        info_spec=training_policy_info_spec)
  else:
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)

  if train_argspec is None:
    train_argspec = {}
  elif validate_args:
    if not isinstance(train_argspec, dict):
      raise TypeError(
          "train_argspec must be a dict, but saw: {}".format(train_argspec))
    # "experience" and "weights" are the reserved names of `train()`'s own
    # arguments, so extra argspecs may not shadow them.
    if "weights" in train_argspec or "experience" in train_argspec:
      raise ValueError(
          "train_argspec must not override 'weights' or "
          "'experience' keys, but saw: {}".format(train_argspec))
    if not all(
        isinstance(x, tf.TypeSpec) for x in tf.nest.flatten(train_argspec)):
      raise TypeError(
          "train_argspec contains non-TensorSpec objects: {}".format(
              train_argspec))
  # Always keep a private copy so later mutation of the caller's dict does not
  # leak into the agent.
  train_argspec = dict(train_argspec)  # Create a local copy.
  self._train_argspec = train_argspec

  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
def __init__(self,
             time_step_spec,
             action_spec,
             policy,
             collect_policy,
             train_sequence_length,
             num_outer_dims=2,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             enable_summaries=True,
             train_step_counter=None):
    """Meant to be called by subclass constructors.

    Args:
      time_step_spec: A nest of tf.TypeSpec representing the time_steps.
        Provided by the user.
      action_spec: A nest of BoundedTensorSpec representing the actions.
        Provided by the user.
      policy: An instance of `tf_policy.Base` representing the Agent's current
        policy.
      collect_policy: An instance of `tf_policy.Base` representing the Agent's
        current data collection policy (used to set `self.step_spec`).
      train_sequence_length: A python integer or `None`, signifying the number
        of time steps required from tensors in `experience` as passed to
        `train()`.  All tensors in `experience` will be shaped `[B, T, ...]` but
        for certain agents, `T` should be fixed.  For example, DQN requires
        transitions in the form of 2 time steps, so for a non-RNN DQN Agent, set
        this value to 2.  For agents that don't care, or which can handle `T`
        unknown at graph build time (i.e. most RNN-based agents), set this
        argument to `None`.
      num_outer_dims: The number of outer dimensions for the agent. Must be
        either 1 or 2. If 2, training will require both a batch_size and time
        dimension on every Tensor; if 1, training will require only a batch_size
        outer dimension.
      debug_summaries: A bool; if true, subclasses should gather debug
        summaries.
      summarize_grads_and_vars: A bool; if true, subclasses should additionally
        collect gradient and variable summaries.
      enable_summaries: A bool; if false, subclasses should not gather any
        summaries (debug or otherwise); subclasses should gate *all* summaries
        using either `summaries_enabled`, `debug_summaries`, or
        `summarize_grads_and_vars` properties.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.

    Raises:
      TypeError: If any leaf of `time_step_spec` is not a `tf.TypeSpec`.
      TypeError: If any leaf of `action_spec` is not a
        `tensor_spec.BoundedTensorSpec`.
      ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
      ValueError: If `num_outer_dims` is not in [1, 2].
    """

    def _each_isinstance(spec, spec_types):
        """Checks if each element of `spec` is instance of any of `spec_types`."""
        # A generator (rather than a list) lets `all` stop at the first leaf
        # that fails the check instead of evaluating every leaf first.
        return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

    # Validate the leaf types of both specs before doing anything else.
    if not _each_isinstance(time_step_spec, tf.TypeSpec):
        raise TypeError("time_step_spec has to contain TypeSpec (TensorSpec, "
                        "SparseTensorSpec, etc) objects, but received: {}".format(
                            time_step_spec))

    if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
        raise TypeError(
            "action_spec has to contain BoundedTensorSpec objects, but received: "
            "{}".format(action_spec))

    common.check_tf1_allowed()
    common.tf_agents_gauge.get_cell("TFAgent").set(True)
    # NOTE(review): presumably verifies the subclass instance does not override
    # members that TFAgent reserves -- confirm in
    # `common.assert_members_are_not_overridden`.
    common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)

    # The leaf check above does not guarantee the top-level structure; the
    # outer container must specifically be a `ts.TimeStep`.
    if not isinstance(time_step_spec, ts.TimeStep):
        raise ValueError(
            "The `time_step_spec` must be an instance of `TimeStep`, but is `{}`."
            .format(type(time_step_spec)))

    if num_outer_dims not in [1, 2]:
        raise ValueError("num_outer_dims must be in [1, 2].")

    self._time_step_spec = time_step_spec
    self._action_spec = action_spec
    self._policy = policy
    self._collect_policy = collect_policy
    self._train_sequence_length = train_sequence_length
    self._num_outer_dims = num_outer_dims
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._enable_summaries = enable_summaries

    # Fall back to the global step when the caller supplies no counter.
    if train_step_counter is None:
        train_step_counter = tf.compat.v1.train.get_or_create_global_step()
    self._train_step_counter = train_step_counter

    # Wrap the subclass-provided `_train` / `_initialize` implementations with
    # `common.function_in_tf1()`.
    self._train_fn = common.function_in_tf1()(self._train)
    self._initialize_fn = common.function_in_tf1()(self._initialize)