Example #1
0
    def test_python_constants_not_exposed(self):
        """Tests that only TensorFlow values are exposed to users.

        Builds two otherwise identical encoders around
        `SimpleLinearEncodingStage`: one parameterized with Python floats and
        one with variables. Only the variable-backed parameters should show up
        in the encoded structure.
        """
        x = tf.constant(1.0)
        tensorspec = tf.TensorSpec.from_tensor(x)
        encoder_py = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.SimpleLinearEncodingStage(2.0, 3.0)).make(),
            tensorspec)
        a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
        b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
        encoder_tf = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.SimpleLinearEncodingStage(a_var, b_var)).make(),
            tensorspec)

        state_py = encoder_py.initial_state()
        state_tf = encoder_tf.initial_state()
        iteration_py = _make_iteration_function(encoder_py)
        iteration_tf = _make_iteration_function(encoder_tf)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        _, encoded_x_py, decoded_x_py, _ = self.evaluate(
            iteration_py(x, state_py))
        _, encoded_x_tf, decoded_x_tf, _ = self.evaluate(
            iteration_tf(x, state_tf))

        # encoded_x_tf should have two more elements than encoded_x_py. They
        # correspond to the two variables passed to the constructor of the
        # stage in encoder_tf, which are exposed as params. In encoder_py the
        # same params are Python floats (2.0 and 3.0), so they should be
        # hidden from users.
        self.assertLen(encoded_x_tf, len(encoded_x_py) + 2)

        # Make sure functionality is still the same.
        self.assertAllClose(x, decoded_x_tf)
        self.assertAllClose(x, decoded_x_py)
    def test_python_constants_not_exposed(self):
        """Checks that only TensorFlow-valued params are surfaced to users.

        Two encoders with the same SimpleLinear -> PlusOne -> SimpleLinear
        structure are built: one parameterized by Python floats and one by
        variables. Python-constant params should not appear among the params
        returned by `get_params`.
        """
        tensorspec = tf.TensorSpec.from_tensor(tf.constant(1.0))

        composer_py = core_encoder.EncoderComposer(
            test_utils.SimpleLinearEncodingStage(2.0, 3.0))
        composer_py = composer_py.add_parent(
            test_utils.PlusOneEncodingStage(), P1_VALS)
        composer_py = composer_py.add_parent(
            test_utils.SimpleLinearEncodingStage(2.0, 3.0), SL_VALS)
        encoder_py = gather_encoder.GatherEncoder.from_encoder(
            composer_py.make(), tensorspec)

        a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
        b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
        composer_tf = core_encoder.EncoderComposer(
            test_utils.SimpleLinearEncodingStage(a_var, b_var))
        composer_tf = composer_tf.add_parent(
            test_utils.PlusOneEncodingStage(), P1_VALS)
        composer_tf = composer_tf.add_parent(
            test_utils.SimpleLinearEncodingStage(a_var, b_var), SL_VALS)
        encoder_tf = gather_encoder.GatherEncoder.from_encoder(
            composer_tf.make(), tensorspec)

        encode_py, before_sum_py, after_sum_py = encoder_py.get_params()
        encode_tf, before_sum_tf, after_sum_tf = encoder_tf.get_params()

        # Python-constant params (as opposed to tf.Tensors) should be hidden
        # from the user and made statically available where they are needed.
        self.assertLen(encode_py, 1)
        self.assertLen(before_sum_py, 1)
        self.assertEmpty(after_sum_py)
        self.assertLen(encode_tf, 5)
        self.assertLen(before_sum_tf, 3)
        self.assertLen(after_sum_tf, 2)
    def test_full_commutativity_with_sum(self):
        """Checks `fully_commutes_with_sum` for two chains and for a tree."""
        spec = tf.TensorSpec((2, ), tf.float32)
        from_encoder = gather_encoder.GatherEncoder.from_encoder

        # A single TimesTwo stage.
        single_stage = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).make()
        gathered = from_encoder(single_stage, spec)
        self.assertTrue(gathered.fully_commutes_with_sum)

        # Two chained TimesTwo stages.
        inner = core_encoder.EncoderComposer(test_utils.TimesTwoEncodingStage())
        outer = inner.add_parent(test_utils.TimesTwoEncodingStage(), T2_VALS)
        gathered = from_encoder(outer.make(), spec)
        self.assertTrue(gathered.fully_commutes_with_sum)

        # A tree in which not every stage commutes with sum.
        tree = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        tree.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        tree.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = tree.add_child(test_utils.TimesTwoEncodingStage(),
                                       SIF_FLOATS)
        floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
        gathered = from_encoder(tree.make(), spec)
        self.assertFalse(gathered.fully_commutes_with_sum)
Example #4
0
    def test_add_child_semantics(self):
        """Checks that `add_child` returns the new child composer, not self."""
        root = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage())
        root.add_child(test_utils.PlusOneEncodingStage(), T2_VALS)
        from_root = root.make()

        child = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_child(
                test_utils.PlusOneEncodingStage(), T2_VALS)
        from_child = child.make()

        # `make` on the root builds the whole tree, whereas `make` on the
        # composer returned by `add_child` builds only the child node, so the
        # two encoders must end up with different children.
        self.assertNotEqual(from_root.children.keys(),
                            from_child.children.keys())
  def test_composite_encoder(self):
    """Tests functionality with a general, composite `Encoder`.

    Builds a tree rooted at `SignIntFloatEncodingStage` and checks, over
    several evaluations, that decoding recovers the input and that each leaf
    of the encoded structure holds the expected value.
    """
    encoder = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    encoder.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    encoder.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    encoder.add_child(test_utils.TimesTwoEncodingStage(), SIF_FLOATS).add_child(
        test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.StatefulSimpleEncoder(encoder.make())

    x = tf.constant(1.2)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    # Use the tf.compat.v1 endpoint: the bare `tf.global_variables_initializer`
    # alias exists only in TF1 and raises AttributeError under TF2.
    self.evaluate(tf.compat.v1.global_variables_initializer())
    for i in range(1, 10):
      x_py, encoded_x_py, decoded_x_py = self.evaluate(
          [x, encoded_x, decoded_x])
      self.assertAllClose(x_py, decoded_x_py)
      self.assertAllClose(
          2.0, _encoded_x_field(encoded_x_py, [TENSORS, SIF_SIGNS, T2_VALS]))
      self.assertAllClose(
          2.0, _encoded_x_field(encoded_x_py, [TENSORS, SIF_INTS, P1_VALS]))
      # The float part is doubled (0.2 -> 0.4). The PlusOneOverN stage is
      # expected to add 1/i on the i-th evaluation as its state advances.
      self.assertAllClose(
          0.4 + 1 / i,
          _encoded_x_field(encoded_x_py,
                           [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS]))
Example #6
0
    def test_composite_encoder(self):
        """Runs a tree-shaped composite `Encoder` through several iterations."""
        x = tf.constant(1.2)
        tree = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        tree.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        tree.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = tree.add_child(test_utils.TimesTwoEncodingStage(),
                                       SIF_FLOATS)
        floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
        encoder = simple_encoder.SimpleEncoder(tree.make(),
                                               tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        for step in range(1, 10):
            x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
            self.assertAllClose(x, decoded_x)
            signs_leaf = _encoded_x_field(encoded_x,
                                          [TENSORS, SIF_SIGNS, T2_VALS])
            self.assertAllClose(2.0, signs_leaf)
            ints_leaf = _encoded_x_field(encoded_x,
                                         [TENSORS, SIF_INTS, P1_VALS])
            self.assertAllClose(2.0, ints_leaf)
            # Doubled float part (0.4) plus the expected 1/step increment.
            floats_leaf = _encoded_x_field(
                encoded_x, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS])
            self.assertAllClose(0.4 + 1 / step, floats_leaf)
Example #7
0
 def test_add_child_parent_bad_key_raises(self):
     """Checks that add_child and add_parent reject an unknown key."""
     composer = core_encoder.EncoderComposer(
         test_utils.TimesTwoEncodingStage())
     bad_key = '___bad_key'
     with self.assertRaises(KeyError):
         composer.add_child(test_utils.PlusOneEncodingStage(), bad_key)
     with self.assertRaises(KeyError):
         composer.add_parent(test_utils.PlusOneEncodingStage(), bad_key)
Example #8
0
    def test_decode_split_commutes_with_sum_true_true(self):
        """Tests that splitting decode works as expected with commutes_with_sum.

        This test chains two encoding stages; the first *does* commute with
        sum, and the second *does*, too. Together, everything should commute
        with sum.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.TimesTwoEncodingStage()).add_parent(
                test_utils.TimesTwoEncodingStage(), T2_VALS).make()
        self.assertTrue(encoder.fully_commutes_with_sum)
        data = self._data_for_test_decode_split(encoder, tf.constant(3.0))

        # Test that the encoding is as expected.
        self.assertEqual(data.x, data.decoded_x_after_sum)
        self.assertAllEqual({
            T2_VALS: {
                T2_VALS: data.x * 2.0 * 2.0
            },
        }, data.encoded_x)
        # Everything commutes with sum, so decode_before_sum should leave the
        # encoded values intact.
        self.assertAllEqual(data.encoded_x, data.decoded_x_before_sum)
        self.assertEqual(
            {
                COMMUTE: True,
                CHILDREN: {
                    T2_VALS: {
                        COMMUTE: True,
                        CHILDREN: {}
                    }
                }
            }, encoder.commuting_structure)
 def test_not_a_tensorspec_raises(self, not_a_tensorspec):
     """Checks from_encoder rejects a tensorspec that is not a TensorSpec."""
     inner = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     with self.assertRaisesRegex(TypeError, 'TensorSpec'):
         gather_encoder.GatherEncoder.from_encoder(inner, not_a_tensorspec)
Example #10
0
def uniform_quantization(bits):
    """Returns uniform quantization `Encoder`.

    The `Encoder` first reshapes the input to a rank-1 `Tensor`, then applies
    uniform quantization with the extreme values being the minimum and maximum
    of the vector being encoded. Finally, the quantized values are bitpacked to
    an integer type.

    The `Encoder` is a composition of the following encoding stages:
    * `FlattenEncodingStage`
    * `UniformQuantizationEncodingStage`
    * `BitpackingEncodingStage`

    Args:
      bits: Number of bits to quantize into.

    Returns:
      The quantization `Encoder`.
    """
    # The composer is built from the last-applied stage outward. Each
    # `add_parent` wraps the current chain in a new stage, keyed by the
    # parent's ENCODED_VALUES_KEY, and returns the new parent composer.
    composer = core_encoder.EncoderComposer(
        stages_impl.BitpackingEncodingStage(bits))
    composer = composer.add_parent(
        stages_impl.UniformQuantizationEncodingStage(bits),
        stages_impl.UniformQuantizationEncodingStage.ENCODED_VALUES_KEY)
    composer = composer.add_parent(
        stages_impl.FlattenEncodingStage(),
        stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY)
    return composer.make()
Example #11
0
def hadamard_quantization(bits):
    """Returns Hadamard quantization `Encoder`.

    The `Encoder` first reshapes the input to a rank-1 `Tensor`, and applies
    the Hadamard transform (rotation). It then applies uniform quantization
    with the extreme values being the minimum and maximum of the rotated vector
    being encoded. Finally, the quantized values are bitpacked to an integer
    type.

    The `Encoder` is a composition of the following encoding stages:
    * `FlattenEncodingStage` - reshaping the input to a vector.
    * `HadamardEncodingStage` - applying the Hadamard transform.
    * `UniformQuantizationEncodingStage` - applying uniform quantization.
    * `BitpackingEncodingStage` - bitpacking the result into integer values.

    Args:
      bits: Number of bits to quantize into.

    Returns:
      The Hadamard quantization `Encoder`.
    """
    # The composer is built from the last-applied stage outward. Each
    # `add_parent` wraps the current chain in a new stage, keyed by the
    # parent's ENCODED_VALUES_KEY, and returns the new parent composer.
    composer = core_encoder.EncoderComposer(
        stages_impl.BitpackingEncodingStage(bits))
    composer = composer.add_parent(
        stages_impl.UniformQuantizationEncodingStage(bits),
        stages_impl.UniformQuantizationEncodingStage.ENCODED_VALUES_KEY)
    composer = composer.add_parent(
        stages_impl.HadamardEncodingStage(),
        stages_impl.HadamardEncodingStage.ENCODED_VALUES_KEY)
    composer = composer.add_parent(
        stages_impl.FlattenEncodingStage(),
        stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY)
    return composer.make()
Example #12
0
    def test_commutes_with_sum(self):
        """Checks commutativity with sum when num_summands is supplied."""
        child = core_encoder.EncoderComposer(test_utils.PlusOneEncodingStage())
        encoder = child.add_parent(
            test_utils.SimpleLinearEncodingStage(2.0, 3.0), SL_VALS).make()

        x = tf.constant(3.0)
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        partially_decoded = encoder.decode_before_sum(encoded_x, decode_params,
                                                      input_shapes)
        # Sum three identical copies of the partially decoded values, then
        # finish decoding with the matching number of summands.
        summed = tf.nest.map_structure(lambda t: t + t + t, partially_decoded)
        num_summands = 3
        fully_decoded = encoder.decode_after_sum(summed, decode_params,
                                                 num_summands, input_shapes)
        data = self.evaluate(
            self.commutation_test_data(x, encoded_x, partially_decoded,
                                       fully_decoded))

        self.assertEqual(3.0, data.x)
        linear = data.x * 2.0 + 3.0
        self.assertAllEqual({SL_VALS: {P1_VALS: linear + 1.0}},
                            data.encoded_x)
        self.assertAllEqual({SL_VALS: linear}, data.decoded_x_before_sum)
        self.assertEqual(9.0, data.decoded_x_after_sum)
Example #13
0
    def test_decode_split_commutes_with_sum_true_false(self):
        """Checks split decoding when only the outer stage commutes with sum.

        The parent `TimesTwoEncodingStage` commutes with sum while the child
        `PlusOneEncodingStage` does not, so the encoder as a whole is not
        fully commutative with sum.
        """
        child = core_encoder.EncoderComposer(test_utils.PlusOneEncodingStage())
        encoder = child.add_parent(test_utils.TimesTwoEncodingStage(),
                                   T2_VALS).make()
        self.assertFalse(encoder.fully_commutes_with_sum)
        data = self._data_for_test_decode_split(encoder, tf.constant(3.0))
        doubled = data.x * 2.0

        # Full decoding must still recover the input.
        self.assertEqual(data.x, data.decoded_x_after_sum)
        # Encoding doubles first, then adds one.
        self.assertAllEqual({T2_VALS: {P1_VALS: doubled + 1.0}},
                            data.encoded_x)
        # Decoding before sum undoes only the non-commuting plus-one stage.
        self.assertAllEqual({T2_VALS: doubled}, data.decoded_x_before_sum)
        expected_structure = {
            COMMUTE: True,
            CHILDREN: {T2_VALS: {COMMUTE: False, CHILDREN: {}}},
        }
        self.assertEqual(expected_structure, encoder.commuting_structure)
Example #14
0
 def test_input_tensorspec(self):
     """Checks SimpleEncoder exposes an input_tensorspec matching its input."""
     x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
     inner = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     encoder = simple_encoder.SimpleEncoder(inner, tf.TensorSpec.from_tensor(x))
     self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
Example #15
0
    def test_tree_encoder(self):
        """Checks the encoder handles a branching tree, not only a chain."""
        root = core_encoder.EncoderComposer(
            test_utils.SignIntFloatEncodingStage())
        root.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
        root.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
        floats_branch = root.add_child(test_utils.PlusOneEncodingStage(),
                                       SIF_FLOATS)
        floats_branch.add_child(test_utils.TimesTwoEncodingStage(), P1_VALS)
        encoder = root.make()

        x = tf.constant([0.0, 0.1, -0.1, 0.9, -0.9, 1.6, -2.2])
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        x_val, encoded_val, decoded_val = self.evaluate(
            [x, encoded_x, decoded_x])

        self.assertAllClose(x_val, decoded_val)
        signs = np.array([0.0, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0])
        ints = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0])
        floats = np.array([2.0, 2.2, 2.2, 3.8, 3.8, 3.2, 2.4])
        expected = {
            SIF_SIGNS: {T2_VALS: signs},
            SIF_INTS: {P1_VALS: ints},
            SIF_FLOATS: {P1_VALS: {T2_VALS: floats}},
        }
        self.assertAllClose(expected, encoded_val)
Example #16
0
    def test_adaptive_stage_using_state_update_tensors(self):
        """Checks an adaptive stage whose state evolves via update tensors."""
        encoder = core_encoder.EncoderComposer(
            test_utils.AdaptiveNormalizeEncodingStage()).add_parent(
                test_utils.PlusOneEncodingStage(), P1_VALS).make()
        x = tf.constant(1.0)
        current_state = encoder.initial_state()

        for _ in range(4):
            previous_state = current_state
            encode_params, decode_params = encoder.get_params(previous_state)
            encoded_x, update_tensors, input_shapes = encoder.encode(
                x, encode_params)
            decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
            current_state = encoder.update_state(previous_state,
                                                 update_tensors)
            data = self.evaluate(
                test_utils.TestData(x, encoded_x, decoded_x, previous_state,
                                    update_tensors, current_state))

            self.assertAllClose(data.x, data.decoded_x)
            child_state = data.initial_state[CHILDREN][P1_VALS][STATE]
            self.assertLessEqual(child_state[AN_FACTOR_STATE], 1.0)
            updates = data.state_update_tensors[CHILDREN][P1_VALS][TENSORS]
            self.assertEqual(updates[AN_NORM_UPDATE], 2.0)
            self.assertLessEqual(data.encoded_x[P1_VALS][AN_VALS], 2.0)
 def test_not_fully_defined_shape_raises(self):
     """Checks a tensorspec with an unknown dimension is rejected."""
     inner = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     partial_spec = tf.TensorSpec((None, ), tf.float32)
     with self.assertRaisesRegex(TypeError, 'fully defined'):
         gather_encoder.GatherEncoder.from_encoder(inner, partial_spec)
    def test_param_control_from_outside(self):
        """Tests that behavior can be controlled from outside, if needed.

        The linear stage is parameterized by variables; assigning new values
        to them after graph construction must change the encoded output.
        """
        # Use tf.compat.v1 endpoints: the bare `tf.get_variable`,
        # `tf.global_variables_initializer` and `tf.assign` aliases exist only
        # in TF1 and raise AttributeError under TF2.
        a_var = tf.compat.v1.get_variable('a_var', initializer=2.0)
        b_var = tf.compat.v1.get_variable('b_var', initializer=3.0)
        encoder = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.SimpleLinearEncodingStage(a_var, b_var)).make())

        x = tf.constant(1.0)
        encoded_x, decode_fn = encoder.encode(x)
        decoded_x = decode_fn(encoded_x)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        x_py, encoded_x_py, decoded_x_py = self.evaluate(
            [x, encoded_x, decoded_x])
        self.assertAllClose(x_py, decoded_x_py)
        # With a=2.0 and b=3.0: 2.0 * 1.0 + 3.0 == 5.0.
        self.assertAllClose(5.0,
                            _encoded_x_field(encoded_x_py, [TENSORS, SL_VALS]))

        # Changes to the variables should change the behavior of the encoder.
        self.evaluate([
            tf.compat.v1.assign(a_var, 5.0),
            tf.compat.v1.assign(b_var, -7.0)
        ])
        x_py, encoded_x_py, decoded_x_py = self.evaluate(
            [x, encoded_x, decoded_x])
        self.assertAllClose(x_py, decoded_x_py)
        # With a=5.0 and b=-7.0: 5.0 * 1.0 - 7.0 == -2.0.
        self.assertAllClose(-2.0,
                            _encoded_x_field(encoded_x_py, [TENSORS, SL_VALS]))
 def test_input_tensorspec(self):
     """Checks GatherEncoder exposes the input_tensorspec it was built with."""
     x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
     inner = core_encoder.EncoderComposer(
         test_utils.PlusOneOverNEncodingStage()).make()
     encoder = gather_encoder.GatherEncoder.from_encoder(
         inner, tf.TensorSpec.from_tensor(x))
     self.assertTrue(encoder.input_tensorspec.is_compatible_with(x))
 def test_multiple_initialize_raises(self):
   """Checks a stateful encoder refuses a second `initialize` call."""
   stage = test_utils.PlusOneOverNEncodingStage()
   stateful = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(stage).make())
   stateful.initialize()  # The first initialization succeeds.
   with self.assertRaisesRegex(RuntimeError, 'already initialized'):
     stateful.initialize()
 def test_uninitialized_encode_raises(self):
   """Checks encode fails on a stateful encoder before `initialize`."""
   stage = test_utils.PlusOneOverNEncodingStage()
   stateful = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(stage).make())
   value = tf.constant(1.0)
   # `initialize` is deliberately never called.
   with self.assertRaisesRegex(RuntimeError, 'not been initialized'):
     stateful.encode(value)
 def test_multiple_encode_raises(self):
   """Checks a stateful encoder allows `encode` to be called only once."""
   stage = test_utils.PlusOneOverNEncodingStage()
   stateful = simple_encoder.StatefulSimpleEncoder(
       core_encoder.EncoderComposer(stage).make())
   stateful.initialize()
   value = tf.constant(1.0)
   stateful.encode(value)  # The first call succeeds.
   with self.assertRaisesRegex(RuntimeError, 'only once'):
     stateful.encode(value)
    def test_decode_needs_input_shape(self):
        """Checks the input shape is passed through to decoding."""
        def make_x():
            return tf.reshape(list(range(15)), [3, 5])

        reduce_mean = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        encoder = gather_encoder.GatherEncoder.from_encoder(
            reduce_mean, tf.TensorSpec.from_tensor(make_x()))

        iteration = _make_iteration_function(encoder, make_x, 1)
        data = self.evaluate(iteration(encoder.initial_state()))
        # The mean of 0..14 is 7, restored to the original [3, 5] shape.
        self.assertAllEqual([[7.0] * 5] * 3, data.decoded_x)
  def test_modifying_encoded_x_raises(self):
    """Checks decode_fn rejects an encoded_x whose keys were altered."""
    stage = test_utils.PlusOneEncodingStage()
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(stage).make())

    encoded_x, decode_fn = encoder.encode(tf.constant(1.0))
    # An extra, unexpected key must be detected...
    encoded_x['__NOT_EXPECTED_KEY__'] = None
    with self.assertRaises(ValueError):
      decode_fn(encoded_x)
    # ...and so must an empty structure.
    with self.assertRaises(ValueError):
      decode_fn({})
  def test_decode_needs_input_shape_dynamic(self):
    """Checks decoding restores an input shape that is only known dynamically."""
    mean_stage = test_utils.ReduceMeanEncodingStage()
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(mean_stage).make())

    x = test_utils.get_tensor_with_random_shape()
    encoded_x, decode_fn = encoder.encode(x)
    x_val, decoded_val = self.evaluate([x, decode_fn(encoded_x)])
    self.assertAllEqual(x_val.shape, decoded_val.shape)
Example #26
0
    def test_decode_needs_input_shape_static(self):
        """Checks a statically known input shape is carried through decoding."""
        x = tf.reshape(list(range(15)), [3, 5])
        mean_encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        encoder = simple_encoder.SimpleEncoder(mean_encoder,
                                               tf.TensorSpec.from_tensor(x))

        initial = encoder.initial_state()
        run_iteration = _make_iteration_function(encoder)
        _, _, result, _ = self.evaluate(run_iteration(x, initial))
        # The mean of 0..14 is 7, restored to the original [3, 5] shape.
        self.assertAllEqual([[7.0] * 5] * 3, result)
  def test_decode_needs_input_shape_static(self):
    """Checks a statically known input shape is restored by decoding."""
    mean_stage = test_utils.ReduceMeanEncodingStage()
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(mean_stage).make())

    x = tf.reshape(list(range(15)), [3, 5])
    encoded_x, decode_fn = encoder.encode(x)
    result = self.evaluate(decode_fn(encoded_x))
    # The mean of 0..14 is 7, restored to the original [3, 5] shape.
    self.assertAllEqual([[7.0] * 5] * 3, result)
  def test_basic_encode_decode(self):
    """Checks a single PlusOne stage round-trips its input."""
    composer = core_encoder.EncoderComposer(test_utils.PlusOneEncodingStage())
    encoder = simple_encoder.SimpleEncoder(composer.make())

    x = tf.constant(1.0)
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    x_val, encoded_val, decoded_val = self.evaluate([x, encoded_x, decoded_x])
    self.assertAllClose(x_val, decoded_val)
    # PlusOne maps the input 1.0 to 2.0.
    self.assertAllClose(2.0, _encoded_x_field(encoded_val, [TENSORS, P1_VALS]))
  def test_encode_multiple_objects(self):
    """Checks one encoder can encode inputs of several different shapes."""
    composer = core_encoder.EncoderComposer(test_utils.PlusOneEncodingStage())
    encoder = simple_encoder.SimpleEncoder(composer.make())

    shapes = [(2,), (2, 3), (2, 3, 4)]
    for shape in shapes:
      ones = np.ones(shape, np.float32)
      x = tf.constant(ones)
      encoded_x, decode_fn = encoder.encode(x)
      decoded_x = decode_fn(encoded_x)

      x_val, encoded_val, decoded_val = self.evaluate(
          [x, encoded_x, decoded_x])
      self.assertAllClose(x_val, decoded_val)
      # Every element of the all-ones input becomes 2.0 after PlusOne.
      self.assertAllClose(2.0 * ones,
                          _encoded_x_field(encoded_val, [TENSORS, P1_VALS]))
Example #30
0
    def test_basic_encode_decode(self):
        """Runs PlusOneOverN through several encode/decode iterations."""
        x = tf.constant(1.0, tf.float32)
        inner = core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make()
        encoder = simple_encoder.SimpleEncoder(inner,
                                               tf.TensorSpec.from_tensor(x))

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        for step in range(1, 10):
            x, encoded_x, decoded_x, state = self.evaluate(iteration(x, state))
            self.assertAllClose(x, decoded_x)
            # On the step-th iteration the stage is expected to add 1/step.
            leaf = _encoded_x_field(encoded_x, [TENSORS, PN_VALS])
            self.assertAllClose(1.0 + 1 / step, leaf)