def setUp(self):
  super().setUp()
  self._data_context = data_converter.DataContext(
      time_step_spec=ts.TimeStep(
          step_type=(),
          reward=tf.TensorSpec((), tf.float32),
          discount=tf.TensorSpec((), tf.float32),
          observation=()),
      action_spec={'action1': tf.TensorSpec((), tf.float32)},
      info_spec=(),
      policy_state_spec=(),
      use_half_transition=True)
  self._data_context_with_state = data_converter.DataContext(
      time_step_spec=ts.TimeStep(
          step_type=(),
          reward=tf.TensorSpec((), tf.float32),
          discount=tf.TensorSpec((), tf.float32),
          observation=tf.TensorSpec((2,), tf.float32)),
      action_spec={'action1': tf.TensorSpec((), tf.float32)},
      info_spec=(),
      policy_state_spec=[
          tf.TensorSpec((2,), tf.float32),
          tf.TensorSpec((2,), tf.float32)
      ],
      use_half_transition=True)
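For orientation (a sketch, not part of the original tests): a context built with `use_half_transition=True` is the kind consumed by `data_converter.AsHalfTransition`, whose call signature appears in the agent code later in this section. Written as if inside a test method of this class:

```python
# Sketch only: pair a half-transition context with its matching converter.
# squeeze_time_dim=True corresponds to the stateless (policy_state_spec=())
# case, as in the agent code below.
converter = data_converter.AsHalfTransition(
    self._data_context, squeeze_time_dim=True)
```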
def setUp(self):
  super(AsTrajectoryTest, self).setUp()
  self._data_context = data_converter.DataContext(
      time_step_spec=ts.TimeStep(
          step_type=(),
          reward=tf.TensorSpec((), tf.float32),
          discount=tf.TensorSpec((), tf.float32),
          observation=()),
      action_spec={'action1': tf.TensorSpec((), tf.float32)},
      info_spec=())
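An illustrative sketch (mirroring the style of these tests rather than quoting one, and written as if inside a test method): a converter built from this context can be exercised with a batched trajectory matching the specs above, shaped `[B=1, T=2]`.

```python
# Illustrative values only: a two-step batched trajectory whose structure
# matches the context specs above ((), scalar float specs, dict action).
traj = trajectory.Trajectory(
    step_type=(),
    observation=(),
    action={'action1': tf.constant([[1.0, 2.0]])},
    policy_info=(),
    next_step_type=(),
    reward=tf.constant([[1.0, 2.0]]),
    discount=tf.constant([[0.9, 0.9]]))
converter = data_converter.AsTrajectory(self._data_context)
converted = converter(traj)  # Validated against the context's specs.
```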
def _setup_data_converter(self, q_network, gamma, n_step_update):
  if q_network.state_spec:
    if not self._in_graph_bellman_update:
      self._data_context = data_converter.DataContext(
          time_step_spec=self._time_step_spec,
          action_spec=self._action_spec,
          info_spec=self._collect_policy.info_spec,
          policy_state_spec=self._q_network.state_spec,
          use_half_transition=True)
      self._as_transition = data_converter.AsHalfTransition(
          self.data_context, squeeze_time_dim=False)
    else:
      self._data_context = data_converter.DataContext(
          time_step_spec=self._time_step_spec,
          action_spec=self._action_spec,
          info_spec=self._collect_policy.info_spec,
          policy_state_spec=self._q_network.state_spec,
          use_half_transition=False)
      self._as_transition = data_converter.AsTransition(
          self.data_context,
          squeeze_time_dim=False,
          prepend_t0_to_next_time_step=True)
  else:
    if not self._in_graph_bellman_update:
      self._data_context = data_converter.DataContext(
          time_step_spec=self._time_step_spec,
          action_spec=self._action_spec,
          info_spec=self._collect_policy.info_spec,
          policy_state_spec=self._q_network.state_spec,
          use_half_transition=True)
      self._as_transition = data_converter.AsHalfTransition(
          self.data_context, squeeze_time_dim=True)
    else:
      # This reduces the n-step return and removes the extra time dimension,
      # allowing the rest of the computations to be independent of the
      # n-step parameter.
      self._as_transition = data_converter.AsNStepTransition(
          self.data_context, gamma=gamma, n=n_step_update)
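Whichever branch runs, the resulting `self._as_transition` is consumed the same way downstream. A hedged sketch of the typical usage pattern in TF-Agents loss functions (not a verbatim excerpt from this agent):

```python
# Sketch: every converter above maps raw `experience` into a Transition
# tuple whose elements already have the batch/time layout the loss expects.
transition = self._as_transition(experience)
time_steps, policy_steps, next_time_steps = transition
actions = policy_steps.action
```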
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             train_argspec: Optional[Dict[Text,
                                          types.NestedTensorSpec]] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None,
             validate_args: bool = True):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set
      `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the
      number of time steps required from tensors in `experience` as passed
      to `train()`. All tensors in `experience` will be shaped
      `[B, T, ...]` but for certain agents, `T` should be fixed. For
      example, DQN requires transitions in the form of 2 time steps, so for
      a non-RNN DQN Agent, set this value to 2. For agents that don't care,
      or which can handle `T` unknown at graph build time (i.e. most
      RNN-based agents), set this argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of
      data the train() function expects. If None, defaults to the
      trajectory_spec of the collect_policy.
    train_argspec: (Optional) Describes additional supported arguments to
      the `train` call. This must be a `dict` mapping strings to nests of
      specs. Overriding the `experience` arg is also supported.

      Some algorithms require additional arguments to the `train()` call,
      and while TF-Agents encourages most of these to be provided in the
      `policy_info` / `info` field of `experience`, sometimes the extra
      information doesn't fit well, i.e., when it doesn't come from the
      policy.

      **NOTE** kwargs will not have their outer dimensions validated. In
      particular, `train_sequence_length` is ignored for these inputs, and
      they may have any, or inconsistent, batch/time dimensions; only their
      inner shape dimensions are checked against `train_argspec`.

      Below is an example:

      ```python
      class MyAgent(TFAgent):
        def __init__(self, counterfactual_training, ...):
          collect_policy = ...
          train_argspec = None
          if counterfactual_training:
            train_argspec = dict(
                counterfactual=collect_policy.trajectory_spec)
          super(...).__init__(
              ...
              train_argspec=train_argspec)

      my_agent = MyAgent(...)

      for ...:
        experience, counterfactual = next(experience_and_counterfactual_iter)
        loss_info = my_agent.train(experience,
                                   counterfactual=counterfactual)
      ```
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should
      additionally collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all*
      summaries using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    validate_args: Python bool. Whether to verify inputs to, and outputs
      of, functions like `train` and `preprocess_sequence` against spec
      structures, dtypes, and shapes. Research code may prefer to set this
      value to `False` to allow iterating on input and output structures
      without being hamstrung by overly rigid checking (at the cost of
      harder-to-debug errors).

      See also `TFPolicy.validate_args`.

  Raises:
    TypeError: If `validate_args is True` and `train_argspec` is not a
      `dict`.
    ValueError: If `validate_args is True` and `train_argspec` has the keys
      `experience` or `weights`.
    TypeError: If `validate_args is True` and any leaf nodes in
      `train_argspec` values are not subclasses of `tf.TypeSpec`.
    ValueError: If `validate_args is True` and `time_step_spec` is not an
      instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """
  if validate_args:

    def _each_isinstance(spec, spec_types):
      """Checks if each element of `spec` is instance of `spec_types`."""
      return all(
          [isinstance(s, spec_types) for s in tf.nest.flatten(spec)])

    if not _each_isinstance(time_step_spec, tf.TypeSpec):
      raise TypeError(
          "time_step_spec has to contain TypeSpec (TensorSpec, "
          "SparseTensorSpec, etc) objects, but received: {}".format(
              time_step_spec))

    if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      raise TypeError(
          "action_spec has to contain BoundedTensorSpec objects, but "
          "received: {}".format(action_spec))

  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec
  self._validate_args = validate_args
  # Data context for data collected directly from the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      info_spec=collect_policy.info_spec)
  # Data context for data passed to train().  May be different if
  # training_data_spec is provided.
  if training_data_spec is not None:
    data_context_info_spec = getattr(training_data_spec, "policy_info", ())
  else:
    data_context_info_spec = collect_policy.info_spec
  self._data_context = data_converter.DataContext(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      info_spec=data_context_info_spec)
  if train_argspec is None:
    train_argspec = {}
  elif validate_args:
    if not isinstance(train_argspec, dict):
      raise TypeError(
          "train_argspec must be a dict, but saw: {}".format(train_argspec))
    if "weights" in train_argspec or "experience" in train_argspec:
      raise ValueError(
          "train_argspec must not override 'weights' or "
          "'experience' keys, but saw: {}".format(train_argspec))
    if not all(
        isinstance(x, tf.TypeSpec) for x in tf.nest.flatten(train_argspec)):
      raise TypeError(
          "train_argspec contains non-TensorSpec objects: {}".format(
              train_argspec))
  train_argspec = dict(train_argspec)  # Create a local copy.
  self._train_argspec = train_argspec
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
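A minimal hypothetical subclass sketch showing how this constructor is meant to be invoked; `MyAgent` and `MyPolicy` are illustrative stand-ins, not part of TF-Agents:

```python
class MyAgent(TFAgent):
  """Illustrative subclass; the policy class is a hypothetical stand-in."""

  def __init__(self, time_step_spec, action_spec):
    policy = MyPolicy(time_step_spec, action_spec)  # hypothetical policy
    super(MyAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=2)  # e.g. DQN-style two-step transitions

  def _train(self, experience, weights=None):
    ...  # subclass-specific update; returns a LossInfo
```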
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set
      `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the
      number of time steps required from tensors in `experience` as passed
      to `train()`. All tensors in `experience` will be shaped
      `[B, T, ...]` but for certain agents, `T` should be fixed. For
      example, DQN requires transitions in the form of 2 time steps, so for
      a non-RNN DQN Agent, set this value to 2. For agents that don't care,
      or which can handle `T` unknown at graph build time (i.e. most
      RNN-based agents), set this argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of
      data the train() function expects. If None, defaults to the
      trajectory_spec of the collect_policy.
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should
      additionally collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all*
      summaries using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.

  Raises:
    TypeError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """
  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.tf_agents_gauge.get_cell(str(type(self))).set(True)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  time_step_spec = tensor_spec.from_spec(time_step_spec)
  action_spec = tensor_spec.from_spec(action_spec)

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec
  # Data context for data collected directly from the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      info_spec=collect_policy.info_spec)
  # Data context for data passed to train().  May be different if
  # training_data_spec is provided.
  if training_data_spec is not None:
    training_data_spec = tensor_spec.from_spec(training_data_spec)
    # training_data_spec can be anything; so build a data_context
    # via best-effort with fall-backs to the collect data spec.
    training_discount_spec = getattr(
        training_data_spec, "discount", time_step_spec.discount)
    training_observation_spec = getattr(
        training_data_spec, "observation", time_step_spec.observation)
    training_reward_spec = getattr(
        training_data_spec, "reward", time_step_spec.reward)
    training_step_type_spec = getattr(
        training_data_spec, "step_type", time_step_spec.step_type)
    training_policy_info_spec = getattr(
        training_data_spec, "policy_info", collect_policy.info_spec)
    training_action_spec = getattr(
        training_data_spec, "action", action_spec)
    self._data_context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(
            discount=training_discount_spec,
            observation=training_observation_spec,
            reward=training_reward_spec,
            step_type=training_step_type_spec),
        action_spec=training_action_spec,
        info_spec=training_policy_info_spec)
  else:
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
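Compared with the earlier variant, this constructor drops `train_argspec` and builds the train-side context field by field: fields present on `training_data_spec` override the collect specs, while anything missing falls back via the `getattr()` chain above. A hedged sketch of a spec that exercises this; the `name='return'` reward spec and the stripped `policy_info` are illustrative assumptions, not TF-Agents conventions:

```python
# Illustrative: stored training data replaces the per-step reward with a
# return-like scalar and drops policy_info; all other fields mirror the
# collect-time specs.
training_data_spec = trajectory.Trajectory(
    step_type=time_step_spec.step_type,
    observation=time_step_spec.observation,
    action=action_spec,
    policy_info=(),
    next_step_type=time_step_spec.step_type,
    reward=tf.TensorSpec((), tf.float32, name='return'),
    discount=time_step_spec.discount)
```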