Exemplo n.º 1
0
    def loss(
        self,
        experience_and_sample_info: ExperienceAndSampleInfo = None,
        reduce_op: tf.distribute.ReduceOp = tf.distribute.ReduceOp.SUM,
    ) -> tf_agent.LossInfo:
        """Computes loss for the experience.

        Since this calls agent.loss() it does not update gradients or
        increment the train step counter. Networks are called with
        `training=False` so statistics like batch norm are not updated.

        The computation runs with `self.train_summary_writer` as the default
        summary writer and inside `self.strategy.scope()`. Summaries are only
        recorded when `self.train_step` is a multiple of
        `self.summary_interval`; a falsy interval (0 or None) disables them.

        Args:
          experience_and_sample_info: A batch of experience and sample info.
            If `None`, `next(self._experience_iterator)` is used.
          reduce_op: a `tf.distribute.ReduceOp` value specifying how loss
            values should be aggregated across replicas.

        Returns:
          The loss info returned by `self._loss` for the batch.
        """
        def _summary_record_if():
            # Guard first: a zero/None interval must not reach tf.constant()
            # or the modulo below.
            if not self.summary_interval:
                return tf.constant(False)
            return tf.math.equal(
                self.train_step % tf.constant(self.summary_interval), 0)

        with self.train_summary_writer.as_default(), \
             common.soft_device_placement(), \
             tf.compat.v2.summary.record_if(_summary_record_if), \
             self.strategy.scope():
            # Only fall back to the iterator when the argument was omitted;
            # a truthiness test would also discard falsy-but-provided batches.
            if experience_and_sample_info is None:
                experience_and_sample_info = next(self._experience_iterator)

            return self._loss(experience_and_sample_info, reduce_op)
Exemplo n.º 2
0
 def write_metric_summaries(self):
   """Generates scalar summaries for the actor metrics.

   Every metric in `self._metrics` is written once with `self._train_step` as
   the step, and once more per entry in `self._reference_metrics` using that
   reference metric's result as the step. A `ValueError` raised while
   producing a summary is logged and the remaining summaries are still
   attempted. Does nothing when `self._metrics` is None.
   """
   if self._metrics is None:
     return
   with self._summary_writer.as_default(), \
        common.soft_device_placement(), \
        tf.summary.record_if(lambda: True):
     # Generate summaries against the train_step
     for m in self._metrics:
       tag = m.name
       try:
         tf.summary.scalar(
             name=os.path.join("Metrics/", self._name, tag),
             data=m.result(),
             step=self._train_step)
       except ValueError:
         logging.error("Scalar summary could not be written for metric %s",
                       m)
       # Generate summaries against the reference_metrics
       # NOTE(review): these tags are "<name>/Metrics/..." while the tag above
       # is "Metrics/<name>/..."; left as-is since changing it would move
       # existing summaries -- confirm whether the difference is intended.
       for reference_metric in self._reference_metrics:
         tag = "Metrics/{}/{}".format(m.name, reference_metric.name)
         try:
           tf.summary.scalar(
               name=os.path.join(self._name, tag),
               data=m.result(),
               step=reference_metric.result())
         except ValueError:
           # Report both sides of the pair so the failing reference metric
           # can actually be identified from the log.
           logging.error(
               "Scalar summary could not be written for metric %s against "
               "reference_metric %s", m, reference_metric)
Exemplo n.º 3
0
    def run(self, iterations=1, iterator=None):
        """Runs `iterations` iterations of training.

        Training runs with `self.train_summary_writer` as the default summary
        writer and inside `self.strategy.scope()`. Summaries are only recorded
        when `self.train_step` is a multiple of `self.summary_interval`; a
        falsy interval (0 or None) disables them. After training, every
        callable in `self.triggers` is invoked with the current train step.

        Args:
          iterations: Number of train iterations to perform per call to run.
            The iterations will be evaluated in a tf.while loop created by
            autograph. Final aggregated losses will be returned.
          iterator: The iterator to the dataset to use for training. If
            `None`, `self._experience_iterator` is used.

        Returns:
          The total loss computed before running the final step.
        """
        def _summary_record_if():
            # Guard first: a zero/None interval must not reach tf.constant()
            # or the modulo below.
            if not self.summary_interval:
                return tf.constant(False)
            return tf.math.equal(
                self.train_step % tf.constant(self.summary_interval), 0)

        with self.train_summary_writer.as_default(), \
             common.soft_device_placement(), \
             tf.compat.v2.summary.record_if(_summary_record_if), \
             self.strategy.scope():
            # Explicit None check: do not rely on the iterator's truthiness.
            if iterator is None:
                iterator = self._experience_iterator
            loss_info = self._train(iterations, iterator)

            # Triggers receive the step as a plain value, read once after
            # training has finished.
            train_step_val = self.train_step.numpy()
            for trigger in self.triggers:
                trigger(train_step_val)

            return loss_info
Exemplo n.º 4
0
  def run(self):
    """Train `num_batches` batches repeating for `num_epochs` of iterations.

    Before training, a fresh iterator over `self._normalization_dataset_fn()`
    is passed to `self._update_normalizers`, and the value it returns is
    stored in `self.num_frames_for_training`. The number of iterations handed
    to the wrapped learner is the total batch count divided by
    `self.num_replicas`. After training, every callable in the wrapped
    learner's `triggers` is invoked with its current train step.

    Summaries are only recorded when the wrapped learner's train step is a
    multiple of its `summary_interval`; a falsy interval (0 or None) disables
    them.

    Returns:
      The total loss computed before running the final step.
    """
    self._normalization_iterator = iter(self._normalization_dataset_fn())
    num_frames = self._update_normalizers(self._normalization_iterator)
    self.num_frames_for_training.assign(num_frames)

    def _summary_record_if():
      summary_interval = self._generic_learner.summary_interval
      # Guard first: a zero/None interval must not reach tf.constant() or the
      # modulo below.
      if not summary_interval:
        return tf.constant(False)
      return tf.math.equal(
          self._generic_learner.train_step % tf.constant(summary_interval), 0)

    if self._minibatch_size:
      # With minibatching, the batch count per epoch is derived from the
      # number of frames recorded above.
      num_total_batches = int(self.num_frames_for_training.numpy() /
                              self._minibatch_size) * self._num_epochs
    else:
      num_total_batches = self._num_batches * self._num_epochs

    iterations = int(num_total_batches / self.num_replicas)

    with self._generic_learner.train_summary_writer.as_default(), \
         common.soft_device_placement(), \
         tf.compat.v2.summary.record_if(_summary_record_if), \
         self._generic_learner.strategy.scope():
      loss_info = self._generic_learner.run(iterations, self._train_iterator)

      train_step_val = self._generic_learner.train_step_numpy
      for trigger in self._generic_learner.triggers:
        trigger(train_step_val)

    return loss_info
Exemplo n.º 5
0
def contrastive_img_summary(episode_tuple, agent, summary_writer, train_step):
    """Writes an image summary built from the agent's contrastive loss output.

    Calls `agent.contrastive_metric_loss(episode_tuple,
    return_representation=True)` and uses the second returned value (called
    the sim matrix here) as the image content.

    Args:
      episode_tuple: Passed through unchanged to
        `agent.contrastive_metric_loss`.
      agent: Object providing `contrastive_metric_loss`.
      summary_writer: Summary writer made default while the image is written.
      train_step: Step value attached to the image summary.
    """
    loss_outputs = agent.contrastive_metric_loss(
        episode_tuple, return_representation=True)
    _, similarity = loss_outputs

    # Add one leading and one trailing axis before handing the value to
    # tf.summary.image (presumably giving [1, H, W, 1] for a 2-D matrix --
    # confirm against the agent).
    with_leading_axis = tf.expand_dims(similarity, axis=0)
    image = tf.expand_dims(with_leading_axis, axis=-1)

    with summary_writer.as_default():
        with common.soft_device_placement():
            # Always record, regardless of any outer recording condition.
            with tf.compat.v2.summary.record_if(True):
                tf.summary.image('Sim matrix', image, step=train_step)
0
    def run(self, iterations, dataset):
        """Runs training until dataset times out, or when num sequences is reached.

        The (optionally truncated) dataset is cached, used to update the
        advantage normalizer, repeated `iterations` times and, when
        `self._minibatch_size` is set, re-batched into shuffled minibatches
        before being handed to `self.multi_train_step`. After training, every
        callable in the wrapped learner's `triggers` is invoked with its
        current train step, and `self._update_normalizers` is called on the
        cached dataset.

        Summaries are only recorded when the wrapped learner's train step is a
        multiple of its `summary_interval`; a falsy interval (0 or None)
        disables them.

        Args:
          iterations: Number of iterations/epochs to repeat over the collected
            sequences. (Schulman,2017) sets this to 10 for Mujoco, 15 for
            Roboschool and 3 for Atari.
          dataset: A 'tf.Dataset' where each sample is shaped
            [sample_batch_size, sequence_length, ...], commonly the output from
            'reverb_replay_buffer.as_dataset(sample_batch_size, preprocess_fn)'.

        Returns:
          The total loss computed before running the final step.
        """
        # TODO(b/160802425): Verify this setup works with distributed.
        if self._max_num_sequences:
            dataset = dataset.take(self._max_num_sequences)
        # Cache so the same elements are seen by the normalizer updates and by
        # every repeated pass below.
        cached_dataset = dataset.cache()
        self._update_advantage_normalizer(cached_dataset)

        new_dataset = cached_dataset.repeat(iterations)
        if self._minibatch_size:

            def squash_dataset_element(sequence, info):
                return tf.nest.map_structure(
                    utils.BatchSquash(2).flatten, (sequence, info))

            # We unbatch the dataset shaped [B, T, ...] to a new dataset that contains
            # individual elements.
            # Note that we unbatch across the time dimension, which could result in
            # mini batches that contain subsets from more than one sequences. The PPO
            # agent can handle mini batches across episode boundaries.
            new_dataset = new_dataset.map(squash_dataset_element).unbatch()
            new_dataset = new_dataset.shuffle(self._shuffle_buffer_size)
            # batch(1) adds a length-1 axis to each element before
            # minibatching (presumably the time dimension the agent expects
            # -- confirm).
            new_dataset = new_dataset.batch(1, drop_remainder=True)
            new_dataset = new_dataset.batch(self._minibatch_size,
                                            drop_remainder=True)

        # TODO(b/161133726): use learner.run once it supports None iterations.
        def _summary_record_if():
            summary_interval = self._generic_learner.summary_interval
            # Guard first: a zero/None interval must not reach tf.constant()
            # or the modulo below.
            if not summary_interval:
                return tf.constant(False)
            return tf.math.equal(
                self._generic_learner.train_step %
                tf.constant(summary_interval), 0)

        with self._generic_learner.train_summary_writer.as_default(), \
             common.soft_device_placement(), \
             tf.compat.v2.summary.record_if(_summary_record_if), \
             self._generic_learner.strategy.scope():
            loss_info = self.multi_train_step(iter(new_dataset))

            train_step_val = self._generic_learner.train_step_numpy
            for trigger in self._generic_learner.triggers:
                trigger(train_step_val)

        # Normalizers are refreshed only after training on this data.
        self._update_normalizers(cached_dataset)

        return loss_info
Exemplo n.º 7
0
 def write_metric_summaries(self):
   """Writes the parent's metric summaries plus per-agent scalars.

   After delegating to the parent implementation, every metric whose name
   contains 'Multiagent' gets one scalar summary per agent index in
   `range(metric.n_agents)`, written against `self._train_step`. Nothing
   further is written when `self._metrics` is None.
   """
   super().write_metric_summaries()
   if self._metrics is None:
     return
   with self._summary_writer.as_default():
     with common.soft_device_placement():
       with tf.summary.record_if(lambda: True):
         for metric in self._metrics:
           metric_name = metric.name
           # Only multi-agent metrics expose per-agent results here.
           if 'Multiagent' not in metric_name:
             continue
           for agent_index in range(metric.n_agents):
             tf.compat.v2.summary.scalar(
                 name='{}_agent{}'.format(metric_name, agent_index),
                 data=metric.result_for_agent(agent_index),
                 step=self._train_step)
Exemplo n.º 8
0
def img_summary(experience, summary_writer, train_step):
  """Writes an image summary of sample observation crops.

  When `experience['augmented_obs']` is truthy, the summary stacks three
  images: the first observation plus the first augmented observation and the
  first augmented next observation. Otherwise only the first observation is
  written. In every case only channels 0:3 of the 'pixels' entry are used.

  Args:
    experience: Mapping with an 'experience' entry exposing
      `.observation['pixels']`, plus 'augmented_obs' and (when that is
      truthy) 'augmented_next_obs'.
    summary_writer: Summary writer made default while the image is written.
    train_step: Step value attached to the image summary.
  """
  pixels = experience['experience'].observation['pixels']
  augmented = experience['augmented_obs']
  if not augmented:
    # No augmentations available: emit the original frame alone.
    images = tf.expand_dims(pixels[0, ..., 0:3], axis=0)
  else:
    aug_pixels = augmented[0]['pixels']
    aug_next_pixels = experience['augmented_next_obs'][0]['pixels']
    frames = [
        pixels[0, :, :, 0:3],
        aug_pixels[:, :, 0:3],
        aug_next_pixels[:, :, 0:3],
    ]
    images = tf.stack(frames, axis=0)
  with summary_writer.as_default():
    with common.soft_device_placement():
      # Always record, regardless of any outer recording condition.
      with tf.compat.v2.summary.record_if(True):
        tf.summary.image(
            'Sample crops', images, max_outputs=10, step=train_step)
Exemplo n.º 9
0
 def write_metric_summaries(self):
     """Writes a scalar summary for each actor metric.

     Each metric in `self._metrics` is written with `self._train_step` as the
     step, then once per entry in `self._reference_metrics` with that
     reference metric's result as the step. Returns immediately when
     `self._metrics` is None.
     """
     if self._metrics is None:
         return
     with self._summary_writer.as_default():
         with common.soft_device_placement():
             with tf.summary.record_if(lambda: True):
                 for metric in self._metrics:
                     # Metric value indexed by the train step.
                     step_tag = os.path.join(self._name, metric.name)
                     tf.summary.scalar(
                         name=step_tag,
                         data=metric.result(),
                         step=self._train_step)
                     # Same metric value indexed by each reference metric.
                     for ref in self._reference_metrics:
                         pair = "{}/{}".format(metric.name, ref.name)
                         ref_tag = os.path.join(self._name, pair)
                         tf.summary.scalar(
                             name=ref_tag,
                             data=metric.result(),
                             step=ref.result())
Exemplo n.º 10
0
    def run(self, iterations=1, iterator=None, parallel_iterations=10):
        """Performs `iterations` training iterations and fires the triggers.

        Training runs with `self.train_summary_writer` as the default summary
        writer and inside `self.strategy.scope()`. Summaries are recorded only
        when `self.summary_interval` is truthy and `self.train_step` is a
        multiple of it.

        Args:
          iterations: Number of train iterations to perform per call to run.
            Must be at least 1. It is passed to the training function as a
            `tf.constant`.
          iterator: The iterator to the dataset to use for training. When
            falsy, `self._experience_iterator` is used instead.
          parallel_iterations: Maximum number of train iterations to allow
            running in parallel; forwarded unchanged to `self._train`.

        Returns:
          The total loss computed before running the final step.

        Raises:
          AssertionError: If `iterations` is smaller than 1.
        """
        assert iterations >= 1, (
            'Iterations must be greater or equal to 1, was %d' % iterations)

        def _should_write_summaries():
            interval = self.summary_interval
            if not interval:
                # No interval configured: never record.
                return tf.constant(False)
            remainder = self.train_step % tf.constant(interval)
            return tf.math.equal(remainder, 0)

        with self.train_summary_writer.as_default():
            with common.soft_device_placement():
                with tf.compat.v2.summary.record_if(_should_write_summaries):
                    with self.strategy.scope():
                        data_iterator = iterator or self._experience_iterator
                        loss_info = self._train(
                            tf.constant(iterations),
                            data_iterator,
                            parallel_iterations)

                        # Triggers get the step as a plain value, read once
                        # after training.
                        current_step = self.train_step.numpy()
                        for trigger in self.triggers:
                            trigger(current_step)

                        return loss_info