def test_python_constants_not_exposed(self):
    """Tests that only TensorFlow values are exposed to users.

    Two encoders are built around `SimpleLinearEncodingStage`. One is
    configured with the Python floats 2.0 and 3.0. The other is configured
    with TensorFlow variables initialized to the same values.

    Only the variable-backed parameters should appear in the encoded
    structure; the Python constants must stay hidden. Both encoders must
    still decode back to the original input.
    """
    encoder_py = simple_encoder.StatefulSimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.SimpleLinearEncodingStage(2.0, 3.0)).make())
    a_var = tf.get_variable('a_var', initializer=2.0)
    b_var = tf.get_variable('b_var', initializer=3.0)
    encoder_tf = simple_encoder.StatefulSimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.SimpleLinearEncodingStage(a_var, b_var)).make())

    x = tf.constant(1.0)
    encoder_py.initialize()
    encoder_tf.initialize()
    encoded_x_py, decode_fn_py = encoder_py.encode(x)
    decoded_x_py = decode_fn_py(encoded_x_py)
    encoded_x_tf, decode_fn_tf = encoder_tf.encode(x)
    decoded_x_tf = decode_fn_tf(encoded_x_tf)

    # encoded_x_tf should have two more elements than encoded_x_py. They
    # correspond to the two variables (a_var, b_var) passed to the
    # constructor of encoder_tf, which are exposed as params. For encoder_py
    # the same params are the Python floats 2.0 and 3.0, so they should be
    # hidden from users.
    self.assertLen(encoded_x_tf, len(encoded_x_py) + 2)

    # Make sure functionality is still the same: both encoders must
    # reconstruct the input. The input and both decodings are evaluated
    # together in one call. Fresh names avoid rebinding the tensor `x`.
    self.evaluate(tf.global_variables_initializer())
    x_val, decoded_x_py_val, decoded_x_tf_val = self.evaluate(
        [x, decoded_x_py, decoded_x_tf])
    self.assertAllClose(x_val, decoded_x_tf_val)
    self.assertAllClose(x_val, decoded_x_py_val)
# Example #2
0
def as_stateful_simple_encoder(encoder):
    """Returns `encoder` as a `StatefulSimpleEncoder`.

    If the argument already is a `StatefulSimpleEncoder`, it is returned
    unchanged. An `Encoder` is wrapped in a new `StatefulSimpleEncoder`.

    Args:
      encoder: A `StatefulSimpleEncoder` or a `core_encoder.Encoder`.

    Returns:
      A `StatefulSimpleEncoder`.

    Raises:
      TypeError: If `encoder` is neither of the accepted types.
    """
    if isinstance(encoder, simple_encoder.StatefulSimpleEncoder):
        # Already in the desired form; pass it through untouched.
        return encoder
    if isinstance(encoder, core_encoder.Encoder):
        return simple_encoder.StatefulSimpleEncoder(encoder)
    raise TypeError('The encoder must be an instance of `Encoder`.')
  def test_composite_encoder(self):
    """Exercises a composite `Encoder` tree through `StatefulSimpleEncoder`.

    The root `SignIntFloatEncodingStage` gets a child stage on each of its
    SIF_SIGNS, SIF_INTS and SIF_FLOATS outputs. The SIF_FLOATS child is itself
    extended with a `PlusOneOverNEncodingStage`.

    The expected value of that last branch depends on the iteration index, so
    the encoded fields are re-checked over several evaluations.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.SignIntFloatEncodingStage())
    composer.add_child(test_utils.TimesTwoEncodingStage(), SIF_SIGNS)
    composer.add_child(test_utils.PlusOneEncodingStage(), SIF_INTS)
    floats_branch = composer.add_child(
        test_utils.TimesTwoEncodingStage(), SIF_FLOATS)
    floats_branch.add_child(test_utils.PlusOneOverNEncodingStage(), T2_VALS)
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.constant(1.2)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    for step in range(1, 10):
      x_val, encoded_val, decoded_val = self.evaluate(
          [x, encoded_x, decoded_x])
      # Round trip must reproduce the input.
      self.assertAllClose(x_val, decoded_val)
      signs_field = _encoded_x_field(
          encoded_val, [TENSORS, SIF_SIGNS, T2_VALS])
      self.assertAllClose(2.0, signs_field)
      ints_field = _encoded_x_field(
          encoded_val, [TENSORS, SIF_INTS, P1_VALS])
      self.assertAllClose(2.0, ints_field)
      floats_field = _encoded_x_field(
          encoded_val, [TENSORS, SIF_FLOATS, T2_VALS, PN_VALS])
      self.assertAllClose(0.4 + 1 / step, floats_field)
 def test_uninitialized_encode_raises(self):
   """Verifies `encode` fails when `initialize` was never called.

   The encoder is built but deliberately left uninitialized. `encode` must
   then raise a `RuntimeError` whose message matches 'not been initialized'.
   """
   composer = core_encoder.EncoderComposer(
       test_utils.PlusOneOverNEncodingStage())
   encoder = simple_encoder.StatefulSimpleEncoder(composer.make())
   value = tf.constant(1.0)
   with self.assertRaisesRegex(RuntimeError, 'not been initialized'):
     encoder.encode(value)
 def test_multiple_initialize_raises(self):
   """Verifies `initialize` may be called at most once per encoder.

   The first call succeeds. The second must raise a `RuntimeError` whose
   message matches 'already initialized'.
   """
   composer = core_encoder.EncoderComposer(
       test_utils.PlusOneOverNEncodingStage())
   encoder = simple_encoder.StatefulSimpleEncoder(composer.make())
   encoder.initialize()  # First call is allowed.
   with self.assertRaisesRegex(RuntimeError, 'already initialized'):
     encoder.initialize()
 def test_multiple_encode_raises(self):
   """Verifies `encode` on a stateful encoder may be called only once.

   After a successful first call, a second call with the same input must
   raise a `RuntimeError` whose message matches 'only once'.
   """
   composer = core_encoder.EncoderComposer(
       test_utils.PlusOneOverNEncodingStage())
   encoder = simple_encoder.StatefulSimpleEncoder(composer.make())
   encoder.initialize()
   value = tf.constant(1.0)
   encoder.encode(value)  # First call is allowed.
   with self.assertRaisesRegex(RuntimeError, 'only once'):
     encoder.encode(value)
  def test_decode_needs_input_shape_dynamic(self):
    """Checks that decoding restores the input shape when it is dynamic.

    The input comes from `test_utils.get_tensor_with_random_shape()`. The
    decoded value must come back with exactly the input's shape.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage())
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = test_utils.get_tensor_with_random_shape()
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    # Input and output are fetched in one call so both reflect the same shape.
    x_val, decoded_val = self.evaluate([x, decoded_x])
    self.assertAllEqual(x_val.shape, decoded_val.shape)
  def test_decode_needs_input_shape_static(self):
    """Checks that decoding restores the input shape when it is static.

    The input is the values 0..14 reshaped to [3, 5]. Their mean is 7, so the
    decoded result is expected to be a 3x5 array filled with 7.0.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage())
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.reshape([i for i in range(15)], [3, 5])
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    decoded_val = self.evaluate(decoded_x)
    expected = [[7.0] * 5 for _ in range(3)]
    self.assertAllEqual(expected, decoded_val)
  def test_basic_encode_decode(self):
    """Checks a single-stage stateful round trip over repeated evaluations.

    On the i-th evaluation (i from 1 to 9), the encoded PN_VALS field is
    expected to equal 1.0 + 1 / i. The decoded value must always equal the
    input.
    """
    composer = core_encoder.EncoderComposer(
        test_utils.PlusOneOverNEncodingStage())
    encoder = simple_encoder.StatefulSimpleEncoder(composer.make())

    x = tf.constant(1.0)
    encoder.initialize()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    self.evaluate(tf.global_variables_initializer())
    for step in range(1, 10):
      x_val, encoded_val, decoded_val = self.evaluate(
          [x, encoded_x, decoded_x])
      self.assertAllClose(x_val, decoded_val)
      expected_field = 1.0 + 1 / step
      actual_field = _encoded_x_field(encoded_val, [TENSORS, PN_VALS])
      self.assertAllClose(expected_field, actual_field)
 def test_initializer_raises(self, not_an_encoder):
   """Verifies the constructor rejects an argument that is not an `Encoder`.

   Args:
     not_an_encoder: A value that is not an `Encoder`.
       NOTE(review): this argument is presumably supplied by a
       parameterization decorator not visible in this excerpt. Confirm.
   """
   expected_pattern = 'Encoder'
   with self.assertRaisesRegex(TypeError, expected_pattern):
     simple_encoder.StatefulSimpleEncoder(not_an_encoder)