def _kashin_forward(self, x, signs, clip_level, clip):
  """Forward step of the algorithm to obtain Kashin's representation.

  Applies the random signs, pads the result via `self._pad`, rotates it with
  the fast Walsh-Hadamard transform and, optionally, clips the coefficients.

  Args:
    x: The tensor to transform.
    signs: Random +/-1 multipliers applied element-wise before padding.
    clip_level: Symmetric bound used when `clip` is truthy; values are clipped
      to `[-clip_level, clip_level]`.
    clip: Whether to clip the rotated coefficients.

  Returns:
    The rotated (and possibly clipped) tensor.
  """
  # NOTE(review): `_pad` presumably widens x to the power-of-two width the
  # Hadamard transform requires - confirm against its definition.
  rotated = tf_utils.fast_walsh_hadamard_transform(self._pad(x * signs))
  if not clip:
    return rotated
  return tf.clip_by_value(rotated, -clip_level, clip_level)
def test_illegal_inputs_dynamic_power_of_two(self):
  """Tests incorrect dynamic shape of the rank 2 input."""
  # Offset by one so the exponent is drawn from {1, 2, 3}. Without the offset
  # the exponent can be 0, giving a dimension of 3**0 == 1, which IS a power
  # of two. No error would then be raised, making the test flaky.
  # The created x has shape (3, 3), (3, 9) or (3, 27), chosen randomly and
  # thus not known statically. In all cases it is not a power of two.
  rand = tf.random.uniform((), maxval=3, dtype=tf.int32) + 1
  x = tf.random.normal((3, 3**rand))
  hx = tf_utils.fast_walsh_hadamard_transform(x)
  with self.assertRaisesOpError(
      'The dimension of x must be a power of two.'):
    self.evaluate(hx)
def test_is_rotation(self, dim):
  """Checks that the Walsh-Hadamard transform behaves like a rotation.

  A rotation must move the vector (non-negligible difference from the input)
  while leaving its Euclidean norm unchanged.

  Args:
    dim: Width of the random `[1, dim]` input, supplied by the test
      parameterization.
  """
  x = tf.random.normal([1, dim])
  x_val, hx_val = self.evaluate(
      [x, tf_utils.fast_walsh_hadamard_transform(x)])
  displacement = np.linalg.norm(x_val - hx_val)
  self.assertGreater(displacement, 1e-3)
  self.assertAllClose(np.linalg.norm(x_val), np.linalg.norm(hx_val))
def encode(self, x, encode_params):
  """See base class."""
  x = self._validate_and_expand_encode_input(x)
  # One random sign per column of the (validated, expanded) input.
  num_columns = x.shape.as_list()[1]
  seed = encode_params[self.SEED_PARAMS_KEY]
  signs = self._random_signs(num_columns, seed, x.dtype)
  padded = self._pad(x * signs)
  encoded = tf_utils.fast_walsh_hadamard_transform(padded)
  return {self.ENCODED_VALUES_KEY: encoded}
def _kashin_backward(self, x, shape, signs=None):
  """Backward step of the algorithm to obtain Kashin's representation.

  Args:
    x: Rank-2 tensor to transform back (sliced along both axes below).
    shape: Number of leading columns to keep after the inverse transform.
    signs: Optional random +/-1 multipliers to undo; skipped when None.

  Returns:
    The inverse-transformed tensor restricted to its first `shape` columns,
    multiplied by `signs` if they were provided.
  """
  transformed = tf_utils.fast_walsh_hadamard_transform(x)
  # Take the slice corresponding to the original object that was encoded.
  # Consistency in specific coordinates for padding and slicing is what makes
  # inverse transformation unique.
  num_rows = tf.shape(transformed)[0]
  restored = tf.slice(transformed, begin=[0, 0], size=[num_rows, shape])
  return restored if signs is None else restored * signs
def test_illegal_inputs_dynamic_power_of_two(self):
  """Tests incorrect dynamic shape of the rank 2 input."""
  # The exponent is drawn from {1, 2, 3}, so the second dimension is 3, 9 or
  # 27. It is never a power of two and is unknown until the graph runs, so
  # the check must fire as a runtime op error rather than a static one.
  exponent = tf.random.uniform((), maxval=3, dtype=tf.int32) + 1
  inputs = tf.random.normal((3, 3**exponent))
  transformed = tf_utils.fast_walsh_hadamard_transform(inputs)
  with self.assertRaisesOpError('The dimension of x must be a power of two.'):
    self.evaluate(transformed)
def decode(self, encoded_tensors, decode_params, num_summands=None,
           shape=None):
  """See base class."""
  del num_summands  # Unused.
  # NOTE(review): despite the None default, `shape` must be supplied - it is
  # indexed (`shape[-1]`) and its own `.shape` is read below.
  unrotated_x = tf_utils.fast_walsh_hadamard_transform(
      encoded_tensors[self.ENCODED_VALUES_KEY])
  # Drop padding columns: keep every row, but only the first shape[-1]
  # columns.
  num_rows = tf.shape(unrotated_x)[0]
  decoded_x = tf.slice(unrotated_x, [0, 0], [num_rows, shape[-1]])
  # Regenerate one sign per kept column from the same seed parameter.
  signs = self._random_signs(
      decoded_x.shape.as_list()[-1],
      decode_params[self.SEED_PARAMS_KEY],
      decoded_x.dtype)
  decoded_x *= signs
  # A single-element `shape` presumably means the original input was rank 1,
  # so the leading axis added at encode time is removed - confirm.
  if shape.shape.num_elements() != 1:
    return decoded_x
  return tf.squeeze(decoded_x, [0])
def test_illegal_inputs_static_power_of_two(self, *dims):
  """Tests incorrect static shape of the rank 2 input.

  Args:
    *dims: Static shape of the random input, supplied by the test
      parameterization (presumably with a non-power-of-two dimension).
  """
  x = tf.random.normal(dims)
  # `assertRaisesRegex` is the canonical name; the `assertRaisesRegexp` alias
  # was deprecated in Python 3.2 and removed from unittest in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, 'The dimension of x must be a power of two.'):
    tf_utils.fast_walsh_hadamard_transform(x)
def test_illegal_inputs_shape(self, *dims):
  """Tests incorrect rank of the input.

  Args:
    *dims: Shape of the random input, supplied by the test parameterization
      (presumably of a rank other than 2).
  """
  x = tf.random.normal(dims)
  # `assertRaisesRegex` is the canonical name; the `assertRaisesRegexp` alias
  # was deprecated in Python 3.2 and removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'Number of dimensions of x must be 2.'):
    tf_utils.fast_walsh_hadamard_transform(x)