def test_input_types(self, x_dtype):
  """Tests encoding and decoding for the parameterized input dtype.

  With a threshold of 0.05, only the first two entries (1.0 and 0.1) are
  expected to appear in the encoded values; every other entry decodes to
  zero.
  """
  x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
  stage = misc.SplitBySmallValueEncodingStage(threshold=0.05)
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(
      stage, x, encode_params, decode_params)
  result = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))

  encoded_values = result.encoded_x[stage.ENCODED_VALUES_KEY]
  encoded_indices = result.encoded_x[stage.ENCODED_INDICES_KEY]
  self._assert_is_integer(encoded_indices)

  # Expected arrays use the input's numpy dtype so that the comparison
  # is made against arrays of the same dtype as the evaluated results.
  np_dtype = x_dtype.as_numpy_dtype
  self.assertAllEqual(encoded_values, np.array([1.0, 0.1], dtype=np_dtype))
  self.assertAllEqual(encoded_indices, np.array([[0], [1]], dtype=np.int32))
  self.assertAllEqual(result.decoded_x,
                      np.array([1.0, 0.1, 0., 0., 0.], dtype=np_dtype))
def test_all_zero_input_works(self):
  """Tests that encoding does not blow up with all-zero input.

  With all-zero input, both of the encoded values are expected to be empty
  arrays, and the input should decode back to all zeros.
  """
  stage = misc.SplitBySmallValueEncodingStage()
  test_data = self.run_one_to_many_encode_decode(stage,
                                                 lambda: tf.zeros([50]))

  # Verify the empty-encoding claim explicitly rather than only checking
  # the decoded output. The empty indices keep their trailing dimension
  # of 1, i.e. shape [0, 1].
  self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
  self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                      np.zeros([0, 1], dtype=np.int32))
  # The shape is written as a list: `(50)` is just the int 50, not a tuple.
  self.assertAllEqual(test_data.decoded_x, np.zeros([50], dtype=np.float32))
def test_all_below_threshold_works(self):
  """Tests that encoding does not blow up with all-below-threshold input.

  Every entry is drawn from [-0.01, 0.01), well inside the 0.1 threshold,
  so both encoded values are expected to be empty arrays and the decoded
  output to be all zeros.
  """
  stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
  x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(
      stage, x, encode_params, decode_params)
  result = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))

  # No entries are kept: indices are an empty stack of 1-element rows.
  empty_indices = np.zeros([0, 1], dtype=np.int32)
  self.assertAllEqual(result.encoded_x[stage.ENCODED_VALUES_KEY], [])
  self.assertAllEqual(result.encoded_x[stage.ENCODED_INDICES_KEY],
                      empty_indices)
  self.assertAllEqual(result.decoded_x,
                      np.zeros([50], dtype=x.dtype.as_numpy_dtype))
def default_encoding_stage(self):
  """See base class. Builds the stage with its default constructor args."""
  stage = misc.SplitBySmallValueEncodingStage()
  return stage