Ejemplo n.º 1
0
    def testUnstackNestedArray(self):
        """Unstacking a batched nest of zero arrays yields one nest per row.

        Every unstacked item must share the nest structure of ``specs``, and
        every leaf must carry the unbatched ``shape``.
        """
        batch_size = 1
        shape = (5, 8)

        specs = self.nest_spec(shape)
        batched = self.zeros_from_spec(specs, outer_dims=[batch_size])
        unbatched = nest_utils.unstack_nested_arrays(batched)
        self.assertEqual(batch_size, len(unbatched))

        for item in unbatched:
            tf.nest.assert_same_structure(specs, item)

        # Visit every leaf across all unstacked items and check its shape.
        for leaf in tf.nest.flatten(unbatched):
            self.assertEqual(leaf.shape, shape)
Ejemplo n.º 2
0
  def testUnstackNestedArraysIntoFlatItems(self):
    """Unstacking into flat items matches flattening the unstacked nests.

    Each flat item must equal the flattened form of the corresponding nest
    from ``unstack_nested_arrays``, must repack into the structure of
    ``specs``, and every leaf must carry the unbatched ``shape``.
    """
    batch_size = 3
    shape = (5, 8)

    specs = self.nest_spec(shape)
    batched_arrays = self.zeros_from_spec(specs, outer_dims=[batch_size])
    flat_items = nest_utils.unstack_nested_arrays_into_flat_items(
        batched_arrays)
    self.assertEqual(batch_size, len(flat_items))

    nested_items = nest_utils.unstack_nested_arrays(batched_arrays)
    for nested, flat in zip(nested_items, flat_items):
      self.assertAllEqual(flat, tf.nest.flatten(nested))
      repacked = tf.nest.pack_sequence_as(specs, flat)
      tf.nest.assert_same_structure(specs, repacked)

    # Visit every leaf across all flat items and check its shape.
    for leaf in tf.nest.flatten(flat_items):
      self.assertEqual(leaf.shape, shape)
Ejemplo n.º 3
0
  def call(self, batched_trajectory: traj.Trajectory):
    """Feeds each trajectory in a batch to its own per-item metric.

    The batch is unstacked into individual trajectories. On the first call
    ``self.build`` is invoked with the observed batch size. After that, the
    number of trajectories must equal ``len(self._metrics)``, and each
    trajectory is passed to the metric at the same position.

    Args:
      batched_trajectory: A Trajectory containing batches of experience.

    Raises:
      ValueError: If the number of unstacked trajectories differs from the
        number of metrics in ``self._metrics``.
    """
    per_item = nest_utils.unstack_nested_arrays(batched_trajectory)
    observed = len(per_item)

    if not self._built:
      # Defer construction until the batch size is known; build() presumably
      # populates self._metrics (not visible here -- confirm).
      self.build(observed)

    expected = len(self._metrics)
    if observed != expected:
      raise ValueError('Batch size {} does not match previously set batch '
                       'size {}. Make sure your batch size is set correctly '
                       'in BatchedPyMetric initialization and that the batch '
                       'size remains constant.'.format(observed, expected))

    # Lengths are equal here, so zip pairs every metric with one trajectory.
    for metric, single_trajectory in zip(self._metrics, per_item):
      metric(single_trajectory)