def testUnstackNestedArray(self):
  """Unstacking a batch of one yields a single nest matching the spec."""
  element_shape = (5, 8)
  num_items = 1
  nested_spec = self.nest_spec(element_shape)
  stacked = self.zeros_from_spec(nested_spec, outer_dims=[num_items])

  pieces = nest_utils.unstack_nested_arrays(stacked)
  self.assertEqual(num_items, len(pieces))

  # Every unstacked element keeps the structure of the spec...
  for piece in pieces:
    tf.nest.assert_same_structure(nested_spec, piece)
  # ...and every leaf has had the outer batch dimension removed.
  for leaf in tf.nest.flatten(pieces):
    self.assertEqual(leaf.shape, element_shape)
def testUnstackNestedArraysIntoFlatItems(self):
  """Flat unstacked items agree with flattening each unstacked nest."""
  element_shape = (5, 8)
  num_items = 3
  nested_spec = self.nest_spec(element_shape)
  stacked = self.zeros_from_spec(nested_spec, outer_dims=[num_items])

  flat_pieces = nest_utils.unstack_nested_arrays_into_flat_items(stacked)
  self.assertEqual(num_items, len(flat_pieces))

  # Compare against the nested unstack as a reference implementation.
  nested_pieces = nest_utils.unstack_nested_arrays(stacked)
  for nested_piece, flat_piece in zip(nested_pieces, flat_pieces):
    self.assertAllEqual(flat_piece, tf.nest.flatten(nested_piece))
    # A flat item must be re-packable into the spec's structure.
    repacked = tf.nest.pack_sequence_as(nested_spec, flat_piece)
    tf.nest.assert_same_structure(nested_spec, repacked)

  # Each leaf array has had the outer batch dimension removed.
  for leaf in tf.nest.flatten(flat_pieces):
    self.assertEqual(leaf.shape, element_shape)
def call(self, batched_trajectory: traj.Trajectory):
  """Feeds each element of a batched trajectory to its own metric.

  The batch is split along its outer dimension with
  `nest_utils.unstack_nested_arrays`. The i-th resulting trajectory is passed
  to the i-th entry of `self._metrics`. If this object has not been built
  yet, `self.build` is called first with the observed batch size.

  Args:
    batched_trajectory: A Trajectory containing batches of experience.

  Raises:
    ValueError: If the number of unstacked trajectories differs from the
      number of metrics in `self._metrics`.
  """
  per_item_trajectories = nest_utils.unstack_nested_arrays(batched_trajectory)
  observed_size = len(per_item_trajectories)
  if not self._built:
    # NOTE(review): build() is presumably what sizes self._metrics to the
    # batch; its body is not visible here.
    self.build(observed_size)

  expected_size = len(self._metrics)
  if observed_size != expected_size:
    raise ValueError('Batch size {} does not match previously set batch '
                     'size {}. Make sure your batch size is set correctly '
                     'in BatchedPyMetric initialization and that the batch '
                     'size remains constant.'.format(observed_size,
                                                     expected_size))

  for single_metric, single_trajectory in zip(self._metrics,
                                              per_item_trajectories):
    single_metric(single_trajectory)