Ejemplo n.º 1
0
def flatten_up_to(*args, expand_composites=False, **kwargs):
  """Wraps `dm_tree.flatten_up_to`, rejecting `expand_composites=True`.

  Args:
    *args: Positional arguments forwarded unchanged to
      `dm_tree.flatten_up_to`.
    expand_composites: Must be falsy; composite expansion is not supported
      in JAX.
    **kwargs: Keyword arguments forwarded unchanged to
      `dm_tree.flatten_up_to`.

  Returns:
    Whatever `dm_tree.flatten_up_to(*args, **kwargs)` returns.

  Raises:
    NotImplementedError: If `expand_composites` is truthy.
  """
  # Internal `tree` rejects `check_types` here (see b/198436438), although the
  # open-source release apparently still accepts it. The `DisableOnExport`
  # marker presumably lets export tooling toggle the line below.
#   kwargs.pop('check_types', None)  # DisableOnExport
  if not expand_composites:
    return dm_tree.flatten_up_to(*args, **kwargs)
  raise NotImplementedError(
      '`expand_composites=True` is not supported in JAX.')
Ejemplo n.º 2
0
def policy_gradient_loss(policies, actions, action_values, policy_vars=None,
                         name="policy_gradient_loss"):
  """Computes policy gradient losses for a batch of trajectories.

  This wraps `policy_gradient` to accept a possibly nested array of `policies`
  and `actions` in order to allow for multiple action distribution types or
  independent multivariate distributions if not directly available. It also sums
  up losses along the time dimension, and is more restrictive about shapes,
  assuming a [T, B] layout for the `batch_shape` of the policies and a
  concatenate(`[T, B]`, `event_shape` of the policies) shape for the actions.

  Args:
    policies: A (possibly nested structure of) distribution(s) supporting
        `batch_shape` and `event_shape` properties along with a `log_prob`
        method (e.g. an instance of `tfp.distributions.Distribution`),
        with `batch_shape` equal to `[T, B]`.
    actions: A (possibly nested structure of) N-D Tensor(s) with shape
        `[T, B, ...]` where the final dimensions are the `event_shape` of the
        corresponding distribution in the nested structure (the shape can be
        just `[T, B]` if the `event_shape` is scalar).
    action_values: Tensor of shape `[T, B]` containing an estimate of the value
        of the selected `actions`.
    policy_vars: An optional (possibly nested structure of) iterable(s) of
        Tensors used by `policies`. If provided, it is flattened up to the
        structure of `policies`, included in the name-scope values, and each
        entry is passed to the matching `policy_gradient` call.
    name: Customises the name_scope for this op.

  Returns:
    loss: Tensor of shape `[B]` containing the total loss for each sequence
    in the batch. Differentiable w.r.t `policies` only.
  """
  actions = nest.flatten(actions)
  if policy_vars:
    policy_vars = nest.flatten_up_to(policies, policy_vars)
  else:
    # One distinct empty list per action. `[list()] * n` would make every
    # entry alias the same list object.
    policy_vars = [[] for _ in actions]
  policies = nest.flatten(policies)

  # Check happens after flatten so that we can be more flexible on nest
  # structures. This is equivalent to asserting that `len(policies) ==
  # len(actions)`, which is sufficient for what we're doing here.
  nest.assert_same_structure(policies, actions)

  # Each policy must have a rank-2 ([T, B]) batch shape, and each action must
  # be shape-compatible with batch_shape + event_shape of its policy.
  for policies_, actions_ in zip(policies, actions):
    policies_.batch_shape.assert_has_rank(2)
    actions_.get_shape().assert_is_compatible_with(
        policies_.batch_shape.concatenate(policies_.event_shape))

  # NOTE(review): `policy_vars` entries are iterables of Tensors rather than
  # Tensors, so confirm that `tf.name_scope` actually inspects them here.
  scoped_values = policy_vars + actions + [action_values]
  with tf.name_scope(name, values=scoped_values):
    # Loss for the policy gradient. Doesn't push additional gradients through
    # the action_values.
    policy_gradient_loss_sequence = tf.add_n([
        policy_gradient(policies_, actions_, action_values, pvars)
        for policies_, actions_, pvars in zip(policies, actions, policy_vars)])

    # Sum over the time axis (axis 0) to get one loss per batch element.
    return tf.reduce_sum(
        policy_gradient_loss_sequence, axis=[0],
        name="policy_gradient_loss")
Ejemplo n.º 3
0
    def get_noised_result(self, sample_state, global_state):
        """Gets the noised result of every subquery.

        Delegates to `self._map_to_queries('get_noised_result', ...)`, which
        presumably calls each subquery's `get_noised_result`. The per-subquery
        outputs are then regrouped into two structures matching
        `self._queries`.

        Args:
          sample_state: Sample state forwarded to `self._map_to_queries`.
          global_state: Global state forwarded to `self._map_to_queries`.

        Returns:
          A tuple `(result, new_global_state)`, each packed into the structure
          of `self._queries`.
        """
        estimates_and_new_global_states = self._map_to_queries(
            'get_noised_result', sample_state, global_state)

        # One `(estimate, new_global_state)` pair per subquery, in the
        # flattened order of `self._queries`.
        flat_pairs = tree.flatten_up_to(
            self._queries, estimates_and_new_global_states)
        # Unzip explicitly instead of with `zip(*flat_pairs)`. With no
        # subqueries that `zip` is empty and cannot be unpacked into two names.
        flat_estimates = tuple(estimate for estimate, _ in flat_pairs)
        flat_new_global_states = tuple(state for _, state in flat_pairs)
        return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
                tf.nest.pack_sequence_as(self._queries,
                                         flat_new_global_states))
Ejemplo n.º 4
0
  def get_noised_result(self, sample_state, global_state):
    """Implements `tensorflow_privacy.DPQuery.get_noised_result`.

    Delegates to `self._map_to_queries('get_noised_result', ...)` and regroups
    the per-subquery outputs.

    Args:
      sample_state: Sample state forwarded to `self._map_to_queries`.
      global_state: Global state forwarded to `self._map_to_queries`.

    Returns:
      A tuple `(result, new_global_state, event)`. `result` and
      `new_global_state` are packed into the structure of `self._queries`.
      `event` is a `dp_event.ComposedDpEvent` over the per-subquery events.
    """
    mapped_query_results = self._map_to_queries('get_noised_result',
                                                sample_state, global_state)

    # One `(estimate, new_global_state, event)` triple per subquery, in the
    # flattened order of `self._queries`.
    flat_results = tree.flatten_up_to(self._queries, mapped_query_results)
    # Unzip explicitly instead of with `zip(*flat_results)`. With no subqueries
    # that `zip` is empty and cannot be unpacked into three names. Tuples
    # keep the same sequence types the original `zip` produced.
    flat_estimates = tuple(estimate for estimate, _, _ in flat_results)
    flat_new_global_states = tuple(state for _, state, _ in flat_results)
    flat_events = tuple(event for _, _, event in flat_results)

    return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
            tf.nest.pack_sequence_as(self._queries, flat_new_global_states),
            dp_event.ComposedDpEvent(events=flat_events))
Ejemplo n.º 5
0
    def get_noised_result(self, sample_state, global_state):
        """Gets query result after all records of sample have been accumulated.

        Args:
          sample_state: The sample state after all records have been
            accumulated.
          global_state: The global state.

        Returns:
          A tuple (result, new_global_state) where "result" is a structure
          matching the query structure containing the results of the
          subqueries and "new_global_state" is a structure containing the
          updated global states for the subqueries.
        """
        estimates_and_new_global_states = self._map_to_queries(
            'get_noised_result', sample_state, global_state)

        # One `(estimate, new_global_state)` pair per subquery, in the
        # flattened order of `self._queries`.
        flat_pairs = tree.flatten_up_to(
            self._queries, estimates_and_new_global_states)
        # Unzip explicitly instead of with `zip(*flat_pairs)`. With no
        # subqueries that `zip` is empty and cannot be unpacked into two names.
        flat_estimates = tuple(estimate for estimate, _ in flat_pairs)
        flat_new_global_states = tuple(state for _, state in flat_pairs)
        return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
                tf.nest.pack_sequence_as(self._queries,
                                         flat_new_global_states))
Ejemplo n.º 6
0
def flatten_up_to(*args, **kwargs):
  """Thin wrapper around `dm_tree.flatten_up_to`.

  All positional and keyword arguments are forwarded unchanged, and the
  wrapped call's result is returned as-is.
  """
  # Internal `tree` rejects `check_types` here (see b/198436438), although the
  # open-source release apparently still accepts it. The `DisableOnExport`
  # marker presumably lets export tooling toggle the line below.
#   kwargs.pop('check_types', None)  # DisableOnExport
  flattened = dm_tree.flatten_up_to(*args, **kwargs)
  return flattened
Ejemplo n.º 7
0
 def testByteStringsNotTreatedAsIterable(self):
     """Unicode and byte strings are leaves: flattening returns them intact."""
     mixed_strings = [u"unicode string", b"byte string"]
     self.assertEqual(mixed_strings,
                      tree.flatten_up_to(mixed_strings, mixed_strings))
Ejemplo n.º 8
0
    def testFlattenUpTo(self):
        """Exercises `tree.flatten_up_to` over many shallow/input tree pairs.

        Covers nested lists, strings as shallow-tree leaves, dicts,
        namedtuples, attrs classes, nested OrderedDict/dict/namedtuple mixes,
        scalar and string edge cases, and the `TypeError` raised when the
        shallow tree is a sequence but the input tree is not.
        """
        # Shallow tree ends at scalar.
        input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow_tree = [[True, True], [False, True]]
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree,
                         [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        # Shallow tree ends at string.
        # The strings are leaves of the shallow tree, so the input subtrees at
        # those positions (the tuples) come back whole. A full `tree.flatten`,
        # by contrast, descends into the tuples.
        input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        input_tree_flattened = tree.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1),
                                                                ("b", 2),
                                                                ("c", 3),
                                                                ("d", 4)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        # Make sure dicts are correctly flattened, yielding values, not keys.
        input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [1, {
                             "c": 2
                         }, 3, (4, 5)])

        # Namedtuples.
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        input_tree = ab_tuple(a=[0, 1], b=2)
        shallow_tree = ab_tuple(a=0, b=1)
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

        # Attrs.
        # An attrs class is traversed field by field, like the namedtuple above.
        @attr.s
        class ABAttr(object):
            a = attr.ib()
            b = attr.ib()

        input_tree = ABAttr(a=[0, 1], b=2)
        shallow_tree = ABAttr(a=0, b=1)
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

        # Nested dicts, OrderedDicts and namedtuples.
        # With the shallow tree identical to the input, flattening goes all the
        # way down to the scalars.
        input_tree = collections.OrderedDict([
            ("a", ab_tuple(a=[0, {
                "b": 1
            }], b=2)), ("c", {
                "d": 3,
                "e": collections.OrderedDict([("f", 4)])
            })
        ])
        shallow_tree = input_tree
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
        # A scalar leaf in the shallow tree keeps the whole corresponding
        # input subtree as one element of the result.
        shallow_tree = collections.OrderedDict([("a", 0),
                                                ("c", {
                                                    "d": 3,
                                                    "e": 1
                                                })])
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {
                "b": 1
            }], b=2), 3,
            collections.OrderedDict([("f", 4)])
        ])
        shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
        input_tree_flattened_as_shallow_tree = tree.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {
                "b": 1
            }], b=2), {
                "d": 3,
                "e": collections.OrderedDict([("f", 4)])
            }
        ])

        ## Shallow non-list edge-case.
        # A non-sequence shallow tree (string or scalar) treats the whole input
        # tree as a single leaf, so the result is `[input_tree]`.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ["input_tree_0", "input_tree_1"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = [0]
        shallow_tree = 9
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = [0, 1]
        shallow_tree = 9
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Each side is a single leaf, so each result is a one-element list.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case.
        # A sequence shallow tree with a non-sequence input (a string counts as
        # non-sequence here) raises TypeError. The message is
        # `_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ` formatted with the input's
        # type. NOTE(review): the assignments inside the `assertRaises` blocks
        # are never used because the call raises.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = ["shallow_tree"]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = "input_tree"
        shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = [9]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = 0
        shallow_tree = [9, 8]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)