Пример #1
0
 def encode(state, value):
     """Encode tf_computation.

     Calls `encode(value, state)` on each encoder in `encoders`, pairing every
     encoder with the corresponding entries of `value` and `state`, and then
     splits the per-encoder results by position.

     Args:
       state: Collection of states, mapped up to the structure of `encoders`.
       value: Collection of values, mapped up to the structure of `encoders`.

     Returns:
       A `(new_state, encoded_value)` tuple, where `encoded_value` gathers
       element 0 and `new_state` gathers element 1 of each encoder's result.
     """

     def _encode_one(encoder_state, encoder_value, encoder):
         return encoder.encode(encoder_value, encoder_state)

     def _first(result):
         return result[0]

     def _second(result):
         return result[1]

     per_encoder_results = nest_contrib.map_structure_up_to(
         encoders, _encode_one, state, value, encoders)
     encoded_value = nest_contrib.map_structure_up_to(
         encoders, _first, per_encoder_results)
     new_state = nest_contrib.map_structure_up_to(
         encoders, _second, per_encoder_results)
     return new_state, encoded_value
Пример #2
0
 def encode_fn(x, encode_params, decode_before_sum_params):
     """Encodes `x` with every encoder in `encoders`.

     Args:
       x: Collection of values, mapped up to the structure of `encoders`.
       encode_params: Collection of parameters handed to each encoder's
         `encode` method alongside the matching entry of `x`.
       decode_before_sum_params: Returned unchanged as the second output.

     Returns:
       A tuple `(encoded_x, decode_before_sum_params, state_update_tensors)`,
       where `encoded_x` and `state_update_tensors` are elements 0 and 1 of
       each encoder's `encode` result.
     """

     def _encode_one(encoder, value, params):
         return encoder.encode(value, params)

     per_encoder_results = nest_contrib.map_structure_up_to(
         encoders, _encode_one, encoders, x, encode_params)
     encoded_x = _slice(encoders, per_encoder_results, 0)
     state_update_tensors = _slice(encoders, per_encoder_results, 1)
     return encoded_x, decode_before_sum_params, state_update_tensors
Пример #3
0
  def testMapStructureUpTo(self):
    """Exercises `nest.map_structure_up_to` on tuples, lists and mappings."""

    def apply_attr_ops(val, ops):
      # Operands are exposed as attributes (namedtuple case).
      return (val + ops.add) * ops.mul

    def apply_key_ops(val, ops):
      # Operands are exposed as keys (dict / mapping cases).
      return (val + ops["add"]) * ops["mul"]

    def describe(name, sec):
      return "first_{}_{}".format(len(sec), name)

    # Named tuples.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    values = ab_tuple(a=2, b=3)
    operations = ab_tuple(
        a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    result = nest.map_structure_up_to(
        values, apply_attr_ops, values, operations)
    self.assertEqual(result.a, 6)
    self.assertEqual(result.b, 15)

    # Lists.
    sections = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    labels = ["evens", ["odds", "primes"]]
    result = nest.map_structure_up_to(labels, describe, labels, sections)
    self.assertEqual(
        result, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

    # Dicts.
    values = dict(a=2, b=3)
    operations = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
    result = nest.map_structure_up_to(
        values, apply_key_ops, values, operations)
    self.assertEqual(result["a"], 6)
    self.assertEqual(result["b"], 15)

    # Non-equal dicts.
    values = dict(a=2, b=3)
    operations = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
    with self.assertRaisesWithLiteralMatch(
        ValueError, nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
      nest.map_structure_up_to(values, apply_key_ops, values, operations)

    # Dict+custom mapping.
    values = dict(a=2, b=3)
    operations = _CustomMapping(
        a=dict(add=1, mul=2), b=dict(add=2, mul=3))
    result = nest.map_structure_up_to(
        values, apply_key_ops, values, operations)
    self.assertEqual(result["a"], 6)
    self.assertEqual(result["b"], 15)

    # Non-equal dict/mapping.
    values = dict(a=2, b=3)
    operations = _CustomMapping(
        a=dict(add=1, mul=2), c=dict(add=2, mul=3))
    with self.assertRaisesWithLiteralMatch(
        ValueError, nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
      nest.map_structure_up_to(values, apply_key_ops, values, operations)
Пример #4
0
 def get_params_fn(state):
     """Collects the parameters of every encoder for the given `state`.

     Args:
       state: Collection of states, mapped up to the structure of `encoders`.

     Returns:
       A tuple `(encode_params, decode_before_sum_params,
       decode_after_sum_params)`, holding elements 0, 1 and 2 of each
       encoder's `get_params` result, respectively.
     """

     def _params_of(encoder, encoder_state):
         return encoder.get_params(encoder_state)

     all_params = nest_contrib.map_structure_up_to(
         encoders, _params_of, encoders, state)
     # Positions 0, 1, 2 hold the encode, decode-before-sum and
     # decode-after-sum parameters, in that order.
     return tuple(
         _slice(encoders, all_params, position) for position in range(3))
Пример #5
0
def _slice(encoders, nested_value, idx):
    """Picks the `idx`-th component out of every entry of `nested_value`.

    A collection of encoders is applied to a collection of values. When an
    encoder method returns a tuple (for instance the encode / decode params
    produced by `get_params`), the per-encoder tuples have to be regrouped
    into one collection per tuple position. This helper does that regrouping
    for a single position.

    Args:
      encoders: A collection of encoders.
      nested_value: A collection of indexable values with the same structure
        as `encoders`.
      idx: An integer. Position to extract from each value in `nested_value`.

    Returns:
      A collection of values with the same structure as `encoders`.
    """

    def _pick(entry):
        return entry[idx]

    return nest_contrib.map_structure_up_to(encoders, _pick, nested_value)
Пример #6
0
 def update_state_fn(state, state_update_tensors):
     """Applies `update_state` of every encoder in `encoders`.

     Args:
       state: Collection of states, mapped up to the structure of `encoders`.
       state_update_tensors: Collection of update tensors, mapped up to the
         structure of `encoders`.

     Returns:
       The collection of `update_state` results, one per encoder.
     """

     def _update_one(encoder, encoder_state, update_tensors):
         return encoder.update_state(encoder_state, update_tensors)

     updated_state = nest_contrib.map_structure_up_to(
         encoders, _update_one, encoders, state, state_update_tensors)
     return updated_state
Пример #7
0
 def decode_after_sum_fn(summed_values, decode_after_sum_params):
     """Runs `decode_after_sum` of every encoder on the summed values.

     Args:
       summed_values: A pair `(part_decoded_aggregated_x, num_summands)`.
       decode_after_sum_params: Collection of parameters, mapped up to the
         structure of `encoders`.

     Returns:
       The collection of `decode_after_sum` results, one per encoder.
     """
     part_decoded_aggregated_x, num_summands = summed_values

     def _decode_one(encoder, aggregated, params):
         # The same `num_summands` is shared by all encoders.
         return encoder.decode_after_sum(aggregated, params, num_summands)

     decoded = nest_contrib.map_structure_up_to(
         encoders, _decode_one, encoders, part_decoded_aggregated_x,
         decode_after_sum_params)
     return decoded
Пример #8
0
 def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
     """Runs `decode_before_sum` of every encoder in `encoders`.

     Args:
       encoded_x: Collection of encoded values, mapped up to the structure
         of `encoders`.
       decode_before_sum_params: Collection of parameters, mapped up to the
         structure of `encoders`.

     Returns:
       A tuple `(part_decoded_x, one)`, where `one` is a one-element int32
       constant tensor holding the value 1.
     """

     def _decode_one(encoder, encoded, params):
         return encoder.decode_before_sum(encoded, params)

     part_decoded_x = nest_contrib.map_structure_up_to(
         encoders, _decode_one, encoders, encoded_x,
         decode_before_sum_params)
     return part_decoded_x, tf.constant((1,), tf.int32)
Пример #9
0
 def decode(encoded_value):
     """Decode tf_computation.

     Args:
       encoded_value: Collection of encoded values, mapped up to the
         structure of `encoders`.

     Returns:
       The collection of `decode` results, one per encoder in `encoders`.
     """

     def _decode_one(encoder, encoded):
         return encoder.decode(encoded)

     decoded_value = nest_contrib.map_structure_up_to(
         encoders, _decode_one, encoders, encoded_value)
     return decoded_value