예제 #1
0
    def testSplitNestedTensors(self):
        """Splitting a batch of 7 yields 7 nests, each with leading dim 1."""
        dims = [2, 3]
        num_items = 7

        nested_specs = self.nest_spec(dims)
        batch = self.zeros_from_spec(nested_specs, batch_size=num_items)
        tf.nest.assert_same_structure(batch, nested_specs)

        pieces = nest_utils.split_nested_tensors(batch, nested_specs,
                                                 num_items)
        self.assertEqual(num_items, len(pieces))

        for piece in pieces:
            tf.nest.assert_same_structure(nested_specs, piece)

        # Every leaf of every piece keeps the spec shape with a batch dim of 1.
        expected_shape = [1] + dims

        def check_shape(tensor):
            self.assertEqual(tensor.shape.as_list(), expected_shape)

        tf.nest.map_structure(check_shape, pieces)
예제 #2
0
  def testSplitNestedTensorsSizeSplits(self):
    """Uneven size_splits yield one nest per split with matching batch dims."""
    dims = [2, 3]
    total_rows = 9
    splits = [2, 4, 3]

    nested_specs = self.nest_spec(dims, include_sparse=False)
    batch = self.zeros_from_spec(nested_specs, batch_size=total_rows)
    tf.nest.assert_same_structure(batch, nested_specs)

    pieces = nest_utils.split_nested_tensors(batch, nested_specs, splits)
    self.assertEqual(len(pieces), len(splits))

    # The length check above guarantees pieces and splits pair up one-to-one.
    for piece, expected_rows in zip(pieces, splits):
      tf.nest.assert_same_structure(nested_specs, piece)

      # Bind the expected row count as a default so the checker does not
      # capture the loop variable by reference.
      def check_rows(tensor, rows=expected_rows):
        self.assertEqual(tensor.shape.as_list()[0], rows)

      tf.nest.map_structure(check_rows, piece)

    def check_trailing_dims(tensor):
      self.assertEqual(tensor.shape.as_list()[1:], dims)

    tf.nest.map_structure(check_trailing_dims, pieces)
예제 #3
0
  def testSplitNestedTensors(self):
    """Splitting a 7-row batch (sparse included) yields 7 single-row nests."""
    dims = [2, 3]
    num_items = 7

    nested_specs = self.nest_spec(dims, include_sparse=True)
    batch = self.zeros_from_spec(nested_specs, batch_size=num_items)
    tf.nest.assert_same_structure(batch, nested_specs)

    pieces = nest_utils.split_nested_tensors(batch, nested_specs, num_items)
    self.assertEqual(num_items, len(pieces))

    for piece in pieces:
      tf.nest.assert_same_structure(nested_specs, piece)

    def assert_shapes(tensor):
      graph_mode_sparse = (not tf.executing_eagerly()
                           and isinstance(tensor, tf.SparseTensor))
      if graph_mode_sparse:
        # In graph mode, constant value propagation through SparseTensors
        # only lets us infer the rank of the split output, not its static
        # shape, so only the rank is checked here.
        self.assertEqual(len(tensor.shape), len(dims) + 1)
        return
      self.assertEqual(tensor.shape.as_list(), [1] + dims)

    tf.nest.map_structure(assert_shapes, pieces)
예제 #4
0
파일: particles.py 프로젝트: adak32/bellman
        def action_method_method_wrapper(
                time_step: TimeStep,
                policy_state: NestedTensor = (),
                seed: Optional[Seed] = None) -> PolicyStep:
            """
            Run the wrapped action method on a reduced batch and tile its actions back up.

            The incoming `time_step` has a batch size of `population_size * number_of_particles`.
            It is split into `number_of_particles` chunks and only the first chunk is kept, so
            the policy sees a batch of `population_size`. It does not matter which observations
            are retained because the policy must be state-unconditioned.

            The resulting actions are then tiled `number_of_particles` times to create a batch of
            `population_size * number_of_particles` actions.
            """
            chunks = split_nested_tensors(time_step,
                                          policy.time_step_spec,
                                          number_of_particles)
            first_chunk = chunks[0]

            step = action_method(first_chunk, policy_state, seed)

            # NOTE(review): confirm that tile_batch's output ordering (repeat-each vs.
            # repeat-whole) matches how particles are laid out in the incoming batch.
            expanded_actions = tile_batch(step.action, number_of_particles)
            return step.replace(action=expanded_actions)