def testSplitNestedTensors(self):
    """Check that an integer split yields `batch_size` unit-batch nests.

    A zero-filled batched nest is built from the spec, split with
    `batch_size` as the split argument, and every resulting piece is
    verified to keep the spec's structure with leaf shape `[1] + shape`.
    """
    shape = [2, 3]
    batch_size = 7

    specs = self.nest_spec(shape)
    batched = self.zeros_from_spec(specs, batch_size=batch_size)
    tf.nest.assert_same_structure(batched, specs)

    pieces = nest_utils.split_nested_tensors(batched, specs, batch_size)
    self.assertEqual(batch_size, len(pieces))

    for piece in pieces:
        tf.nest.assert_same_structure(specs, piece)

    expected_leaf_shape = [1] + shape

    def check_leaf(leaf):
        # Every leaf of every piece must carry a leading dimension of 1.
        self.assertEqual(leaf.shape.as_list(), expected_leaf_shape)

    tf.nest.map_structure(check_leaf, pieces)
def testSplitNestedTensorsSizeSplits(self):
    """Check splitting a batched nest with an explicit list of split sizes.

    The batched nest (built without sparse entries) is split according to
    `size_splits`; piece `i` must keep the spec's structure, have leading
    dimension `size_splits[i]`, and keep the trailing dimensions `shape`.
    """
    shape = [2, 3]
    batch_size = 9
    size_splits = [2, 4, 3]

    specs = self.nest_spec(shape, include_sparse=False)
    batched = self.zeros_from_spec(specs, batch_size=batch_size)
    tf.nest.assert_same_structure(batched, specs)

    pieces = nest_utils.split_nested_tensors(batched, specs, size_splits)
    self.assertEqual(len(pieces), len(size_splits))

    def check_leading_dim(leaf, expected_size):
        self.assertEqual(leaf.shape.as_list()[0], expected_size)

    # The length assertion above guarantees `pieces` and `size_splits`
    # line up one-to-one, so they can be walked in lockstep.
    for piece, expected_size in zip(pieces, size_splits):
        tf.nest.assert_same_structure(specs, piece)
        for leaf in tf.nest.flatten(piece):
            check_leading_dim(leaf, expected_size)

    def check_trailing_dims(leaf):
        self.assertEqual(leaf.shape.as_list()[1:], shape)

    tf.nest.map_structure(check_trailing_dims, pieces)
def testSplitNestedTensors(self):
    """Check an integer split of a nest that also contains sparse entries.

    Each of the `batch_size` pieces must keep the spec's structure. Leaf
    shapes are compared against `[1] + shape`, except for sparse tensors
    outside eager execution, where only the rank is checked.
    """
    shape = [2, 3]
    batch_size = 7

    specs = self.nest_spec(shape, include_sparse=True)
    batched = self.zeros_from_spec(specs, batch_size=batch_size)
    tf.nest.assert_same_structure(batched, specs)

    pieces = nest_utils.split_nested_tensors(batched, specs, batch_size)
    self.assertEqual(batch_size, len(pieces))

    for piece in pieces:
        tf.nest.assert_same_structure(specs, piece)

    def check_leaf(leaf):
        if tf.executing_eagerly() or not isinstance(leaf, tf.SparseTensor):
            self.assertEqual(leaf.shape.as_list(), [1] + shape)
            return
        # Constant value propagation in SparseTensors does not allow us to
        # infer the value of the output shape from the input's shape; only
        # its rank.
        self.assertEqual(len(leaf.shape), 1 + len(shape))

    tf.nest.map_structure(check_leaf, pieces)
def action_method_method_wrapper(
        time_step: TimeStep,
        policy_state: NestedTensor = (),
        seed: Optional[Seed] = None) -> PolicyStep:
    """Call the wrapped action method on a reduced batch and tile its actions.

    The incoming `time_step` carries a batch of
    `population_size * number_of_particles` entries. It is split into
    `number_of_particles` chunks and only the first chunk (batch size
    `population_size`) is kept. Which observations survive is irrelevant
    because the policy must be state-unconditioned.

    The reduced time step is handed to the policy, and each resulting action
    is then repeated `number_of_particles` times so the returned step again
    holds `population_size * number_of_particles` actions.
    """
    time_step_chunks = split_nested_tensors(
        time_step, policy.time_step_spec, number_of_particles)
    step = action_method(time_step_chunks[0], policy_state, seed)
    # Only the action is tiled; the other fields of the step are passed
    # through unchanged by `replace`.
    return step.replace(action=tile_batch(step.action, number_of_particles))