Пример #1
0
    def test_decode_needs_input_shape(self):
        """Checks decoding through stages that require the input shape.

        Two `ReduceMeanEncodingStage` instances are chained, so the input
        shapes returned by `encode` must be passed back to `decode`.
        """
        composer = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage())
        composer = composer.add_parent(
            test_utils.ReduceMeanEncodingStage(), RM_VALS)
        encoder = composer.make()

        x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded, _, shapes = encoder.encode(x, encode_params)
        decoded = encoder.decode(encoded, decode_params, shapes)
        encoded_value, decoded_value = self.evaluate([encoded, decoded])

        # The mean of 1..5 is 3: decoding broadcasts it back to five elements,
        # while the encoded structure holds the single reduced value.
        self.assertAllEqual([3.0] * 5, decoded_value)
        self.assertAllEqual({RM_VALS: {RM_VALS: 3.0}}, encoded_value)
    def test_decode_needs_input_shape(self):
        """Verifies that the input shape is passed through to decoding."""

        def x_fn():
            # Integers 0..14 arranged as a 3x5 tensor.
            return tf.reshape(list(range(15)), [3, 5])

        core = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        spec = tf.TensorSpec.from_tensor(x_fn())
        encoder = gather_encoder.GatherEncoder.from_encoder(core, spec)

        iteration = _make_iteration_function(encoder, x_fn, 1)
        initial_state = encoder.initial_state()
        data = self.evaluate(iteration(initial_state))

        # The decoded value must have the original 3x5 shape, filled with the
        # mean of 0..14.
        self.assertAllEqual([[7.0] * 5] * 3, data.decoded_x)
Пример #3
0
    def test_decode_needs_input_shape_unknown_input_shape(self):
        """Tests that encoder works with stages that need input shape for decode.

        This test chains two stages with this property, and provides an input
        with statically unknown shape information.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).add_parent(
                test_utils.ReduceMeanEncodingStage(), RM_VALS).make()
        x = test_utils.get_tensor_with_random_shape()
        encode_params, decode_params = encoder.get_params(
            encoder.initial_state())
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        # Validate the premise of the test: the leading dimension must be
        # statically unknown. A unittest assertion is used rather than a bare
        # `assert` so the check is not stripped when Python runs with -O.
        self.assertIsNone(x.shape.as_list()[0])
        x, decoded_x = self.evaluate([x, decoded_x])

        # Assert shape is correctly recovered, and functionality is as expected.
        self.assertAllEqual(x.shape, decoded_x.shape)
        self.assertAllClose([x.mean()] * len(x), decoded_x)
Пример #4
0
    def test_decode_needs_input_shape_static(self):
        """Verifies input-shape passing when the shape is statically known."""
        x = tf.reshape(list(range(15)), [3, 5])
        core = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        spec = tf.TensorSpec.from_tensor(x)
        encoder = simple_encoder.SimpleEncoder(core, spec)

        initial_state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        results = self.evaluate(iteration(x, initial_state))
        # Only the third element of the iteration output is inspected here.
        _, _, decoded_value, _ = results

        self.assertAllEqual([[7.0] * 5] * 3, decoded_value)
  def test_decode_needs_input_shape_dynamic(self):
    """Verifies input-shape passing for an input produced with a random shape."""
    core = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(core)

    x = test_utils.get_tensor_with_random_shape()
    encoded, decode_fn = encoder.encode(x)
    decoded = decode_fn(encoded)

    # Evaluate both tensors together and compare the shape of the decoded
    # value against the shape of the input.
    x_value, decoded_value = self.evaluate([x, decoded])
    self.assertAllEqual(x_value.shape, decoded_value.shape)
  def test_decode_needs_input_shape_static(self):
    """Verifies input-shape passing when the shape is statically known."""
    core = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(core)

    x = tf.reshape(list(range(15)), [3, 5])
    encoded, decode_fn = encoder.encode(x)
    decoded = decode_fn(encoded)

    # The decoded value must have the original 3x5 shape, filled with the
    # mean of 0..14.
    decoded_value = self.evaluate(decoded)
    self.assertAllEqual([[7.0] * 5] * 3, decoded_value)
Пример #7
0
    def test_add_parent(self):
        """Checks the encoder tree built by successive `add_parent` calls.

        The stage added last is expected at the root, with each earlier stage
        nested one level deeper under the key given to `add_parent`.
        """
        composer = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage())
        composer = composer.add_parent(
            test_utils.PlusOneEncodingStage(), P1_VALS)
        composer = composer.add_parent(
            test_utils.TimesTwoEncodingStage(), T2_VALS)
        encoder = composer.make()

        # Root: the stage added last.
        self.assertIsInstance(encoder, core_encoder.Encoder)
        self.assertIsInstance(encoder.stage._wrapped_stage,
                              test_utils.TimesTwoEncodingStage)

        # First level: the stage added as the first parent.
        middle = encoder.children[T2_VALS]
        self.assertIsInstance(middle, core_encoder.Encoder)
        self.assertIsInstance(middle.stage._wrapped_stage,
                              test_utils.PlusOneEncodingStage)

        # Second level: the stage the composer was created with.
        innermost = encoder.children[T2_VALS].children[P1_VALS]
        self.assertIsInstance(innermost, core_encoder.Encoder)
        self.assertIsInstance(innermost.stage._wrapped_stage,
                              test_utils.ReduceMeanEncodingStage)
Пример #8
0
    def test_decode_needs_input_shape_dynamic(self):
        """Tests that mechanism for passing input shape works with dynamic shape."""
        if tf.executing_eagerly():
            # In eager mode the spec is taken from the traced function's
            # structured outputs rather than from a concrete tensor
            # (presumably because an eager tensor has a fully known shape).
            fn = tf.function(test_utils.get_tensor_with_random_shape)
            tensorspec = tf.TensorSpec.from_tensor(
                fn.get_concrete_function().structured_outputs)
            x = fn()
        else:
            x = test_utils.get_tensor_with_random_shape()
            tensorspec = tf.TensorSpec.from_tensor(x)
        encoder = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.ReduceMeanEncodingStage()).make(), tensorspec)

        # Validate the premise of the test - that encode method expects an
        # unknown shape. This should be true both for graph and eager mode.
        # A unittest assertion is used rather than a bare `assert` so the
        # check is not stripped when Python runs with -O.
        self.assertEqual(
            [None],
            encoder._encode_fn.get_concrete_function().inputs[0].shape.
            as_list())

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        x, _, decoded_x, _ = self.evaluate(iteration(x, state))
        self.assertAllEqual(x.shape, decoded_x.shape)
 def default_encoding_stage(self):
   """Returns a new `ReduceMeanEncodingStage`; see base class for the contract."""
   stage = test_utils.ReduceMeanEncodingStage()
   return stage