コード例 #1
0
    def test_input_types(self, x_dtype):
        """Checks the encode/decode round trip for an input of dtype `x_dtype`.

        With a threshold of 0.05, only the first two entries (1.0 and 0.1) are
        expected to appear in the encoding, at indices 0 and 1. The remaining
        entries are expected to decode as zeros.
        """
        np_dtype = x_dtype.as_numpy_dtype
        x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
        stage = misc.SplitBySmallValueEncodingStage(threshold=0.05)

        enc_params, dec_params = stage.get_params()
        encoded_x, decoded_x = self.encode_decode_x(
            stage, x, enc_params, dec_params)
        evaluated = self.evaluate_test_data(
            test_utils.TestData(x, encoded_x, decoded_x))

        encoded = evaluated.encoded_x
        indices_key = stage.ENCODED_INDICES_KEY
        values_key = stage.ENCODED_VALUES_KEY

        self._assert_is_integer(encoded[indices_key])

        # Expected arrays are built with the same dtypes as the evaluated
        # results so the comparisons are like-for-like.
        self.assertAllEqual(encoded[values_key],
                            np.array([1.0, 0.1], dtype=np_dtype))
        self.assertAllEqual(encoded[indices_key],
                            np.array([[0], [1]], dtype=np.int32))
        self.assertAllEqual(
            evaluated.decoded_x,
            np.array([1.0, 0.1, 0.0, 0.0, 0.0], dtype=np_dtype))
コード例 #2
0
    def test_all_zero_input_works(self):
        """Checks that encoding does not blow up with all-zero input.

        With all-zero input (and the stage's default threshold), both encoded
        tensors are expected to be empty and the decoded output all zeros.
        """
        stage = misc.SplitBySmallValueEncodingStage()
        test_data = self.run_one_to_many_encode_decode(stage,
                                                       lambda: tf.zeros([50]))

        # Verify the encoded side is empty, not just that decoding succeeds.
        # The empty indices keep their trailing dimension of size 1.
        expected_encoded_indices = np.array([], dtype=np.int32).reshape([0, 1])
        self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
        self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                            expected_encoded_indices)
        # float32 matches the default dtype of tf.zeros.
        self.assertAllEqual(test_data.decoded_x,
                            np.zeros([50], dtype=np.float32))
コード例 #3
0
    def test_all_below_threshold_works(self):
        """Checks that encoding copes with an input that is entirely sub-threshold.

        Every element is drawn from [-0.01, 0.01), well under the 0.1
        threshold. Both encoded tensors are therefore expected to be empty, and
        the decoded result all zeros.
        """
        stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
        x = tf.random.uniform([50], minval=-0.01, maxval=0.01)

        enc_params, dec_params = stage.get_params()
        encoded_x, decoded_x = self.encode_decode_x(
            stage, x, enc_params, dec_params)
        evaluated = self.evaluate_test_data(
            test_utils.TestData(x, encoded_x, decoded_x))
        encoded = evaluated.encoded_x

        # An empty int32 index matrix that keeps its trailing dimension of 1.
        no_indices = np.zeros([0, 1], dtype=np.int32)
        self.assertAllEqual(encoded[stage.ENCODED_VALUES_KEY], [])
        self.assertAllEqual(encoded[stage.ENCODED_INDICES_KEY], no_indices)

        all_zeros = np.zeros([50], dtype=x.dtype.as_numpy_dtype)
        self.assertAllEqual(evaluated.decoded_x, all_zeros)
コード例 #4
0
 def default_encoding_stage(self):
     """Overrides the base-class hook; see the base class for the contract.

     Returns a `SplitBySmallValueEncodingStage` constructed with no arguments,
     i.e. using its default settings.
     """
     stage = misc.SplitBySmallValueEncodingStage()
     return stage