コード例 #1
0
    def test_decode_needs_input_shape(self):
        """Checks decoding through two chained stages that need the input shape.

        Both the parent and the child are `ReduceMeanEncodingStage`s; the shapes
        returned by `encode` are handed back to `decode` for reconstruction.
        """
        inner_stage = test_utils.ReduceMeanEncodingStage()
        outer_stage = test_utils.ReduceMeanEncodingStage()
        encoder = core_encoder.EncoderComposer(inner_stage).add_parent(
            outer_stage, RM_VALS).make()

        x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
        state = encoder.initial_state()
        encode_params, decode_params = encoder.get_params(state)
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        encoded_val, decoded_val = self.evaluate([encoded_x, decoded_x])

        # The mean of 1..5 is 3.0, so every level of the nested encoding holds
        # 3.0 and the decoded tensor is 3.0 broadcast to the input's 5 elements.
        self.assertAllEqual([3.0] * 5, decoded_val)
        self.assertAllEqual({RM_VALS: {RM_VALS: 3.0}}, encoded_val)
コード例 #2
0
    def test_decode_needs_input_shape(self):
        """Checks that the input shape is forwarded so decode can restore it."""
        def make_x():
            return tf.reshape(list(range(15)), [3, 5])

        base_encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        spec = tf.TensorSpec.from_tensor(make_x())
        encoder = gather_encoder.GatherEncoder.from_encoder(base_encoder, spec)

        iteration = _make_iteration_function(encoder, make_x, 1)
        result = self.evaluate(iteration(encoder.initial_state()))
        # The mean of 0..14 is 7.0, broadcast back to the [3, 5] input shape.
        self.assertAllEqual([[7.0] * 5] * 3, result.decoded_x)
コード例 #3
0
    def test_decode_needs_input_shape_unknown_input_shape(self):
        """Tests that encoder works with stages that need input shape for decode.

        This test chains two stages with this property, and provides an input
        with statically unknown shape information.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).add_parent(
                test_utils.ReduceMeanEncodingStage(), RM_VALS).make()
        x = test_utils.get_tensor_with_random_shape()
        encode_params, decode_params = encoder.get_params(
            encoder.initial_state())
        encoded_x, _, input_shapes = encoder.encode(x, encode_params)
        decoded_x = encoder.decode(encoded_x, decode_params, input_shapes)
        # Validate the premise of the test: the leading dimension is not known
        # statically. A test assertion is used instead of a bare `assert`,
        # which would be stripped (and the premise silently unchecked) when
        # Python runs with -O.
        self.assertIsNone(x.shape.as_list()[0])
        x, decoded_x = self.evaluate([x, decoded_x])

        # Assert shape is correctly recovered, and functionality is as expected.
        self.assertAllEqual(x.shape, decoded_x.shape)
        self.assertAllClose([x.mean()] * len(x), decoded_x)
コード例 #4
0
    def test_decode_needs_input_shape_static(self):
        """Checks input-shape passing when the input shape is fully static."""
        x = tf.reshape(list(range(15)), [3, 5])
        mean_encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).make()
        spec = tf.TensorSpec.from_tensor(x)
        encoder = simple_encoder.SimpleEncoder(mean_encoder, spec)

        initial_state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        _, _, decoded, _ = self.evaluate(iteration(x, initial_state))
        # The mean of 0..14 is 7.0, broadcast back to the [3, 5] input shape.
        self.assertAllEqual([[7.0] * 5] * 3, decoded)
コード例 #5
0
  def test_decode_needs_input_shape_dynamic(self):
    """Checks input-shape passing when the shape is only known at runtime."""
    mean_encoder = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(mean_encoder)

    x = test_utils.get_tensor_with_random_shape()
    encoded_x, decode_fn = encoder.encode(x)
    decoded_x = decode_fn(encoded_x)

    x_val, decoded_val = self.evaluate([x, decoded_x])
    # Decoding must restore the runtime shape of the input.
    self.assertAllEqual(x_val.shape, decoded_val.shape)
コード例 #6
0
  def test_decode_needs_input_shape_static(self):
    """Checks input-shape passing when the input shape is fully static."""
    mean_encoder = core_encoder.EncoderComposer(
        test_utils.ReduceMeanEncodingStage()).make()
    encoder = simple_encoder.SimpleEncoder(mean_encoder)

    x = tf.reshape(list(range(15)), [3, 5])
    encoded_x, decode_fn = encoder.encode(x)
    decoded = self.evaluate(decode_fn(encoded_x))
    # The mean of 0..14 is 7.0, broadcast back to the [3, 5] input shape.
    self.assertAllEqual([[7.0] * 5] * 3, decoded)
コード例 #7
0
    def test_add_parent(self):
        """Checks that `add_parent` nests encoders with the last parent on top.

        The composer starts from ReduceMean, PlusOne is added as its parent and
        TimesTwo as the parent of that, so the resulting tree must be
        TimesTwo -> PlusOne -> ReduceMean.
        """
        encoder = core_encoder.EncoderComposer(
            test_utils.ReduceMeanEncodingStage()).add_parent(
                test_utils.PlusOneEncodingStage(), P1_VALS).add_parent(
                    test_utils.TimesTwoEncodingStage(), T2_VALS).make()

        # Root: the most recently added parent.
        self.assertIsInstance(encoder, core_encoder.Encoder)
        self.assertIsInstance(encoder.stage._wrapped_stage,
                              test_utils.TimesTwoEncodingStage)

        # Middle level, reached through the TimesTwo output key.
        middle = encoder.children[T2_VALS]
        self.assertIsInstance(middle, core_encoder.Encoder)
        self.assertIsInstance(middle.stage._wrapped_stage,
                              test_utils.PlusOneEncodingStage)

        # Leaf: the stage the composer was originally created with.
        leaf = middle.children[P1_VALS]
        self.assertIsInstance(leaf, core_encoder.Encoder)
        self.assertIsInstance(leaf.stage._wrapped_stage,
                              test_utils.ReduceMeanEncodingStage)
コード例 #8
0
    def test_decode_needs_input_shape_dynamic(self):
        """Tests that mechanism for passing input shape works with dynamic shape."""
        if tf.executing_eagerly():
            # Eager tensors carry concrete shapes, so the spec is taken from the
            # symbolic output of a traced tf.function instead.
            fn = tf.function(test_utils.get_tensor_with_random_shape)
            tensorspec = tf.TensorSpec.from_tensor(
                fn.get_concrete_function().structured_outputs)
            x = fn()
        else:
            x = test_utils.get_tensor_with_random_shape()
            tensorspec = tf.TensorSpec.from_tensor(x)
        encoder = simple_encoder.SimpleEncoder(
            core_encoder.EncoderComposer(
                test_utils.ReduceMeanEncodingStage()).make(), tensorspec)

        # Validate the premise of the test - that the encode method expects an
        # unknown shape. This should be true both for graph and eager mode.
        # A test assertion is used instead of a bare `assert`, which would be
        # stripped (and the premise silently unchecked) under `python -O`.
        encode_input_shape = (
            encoder._encode_fn.get_concrete_function().inputs[0].shape)
        self.assertEqual([None], encode_input_shape.as_list())

        state = encoder.initial_state()
        iteration = _make_iteration_function(encoder)
        x, _, decoded_x, _ = self.evaluate(iteration(x, state))
        self.assertAllEqual(x.shape, decoded_x.shape)
コード例 #9
0
 def default_encoding_stage(self):
   """Creates the encoding stage exercised by this test case (see base class)."""
   stage = test_utils.ReduceMeanEncodingStage()
   return stage