Esempio n. 1
0
 def _kashin_forward(self, x, signs, clip_level, clip):
     """Forward step of the algorithm to obtain Kashin's representation.

     Multiplies `x` elementwise by `signs`, pads the product with
     `self._pad`, and applies the fast Walsh-Hadamard transform. When `clip`
     is truthy the transformed values are additionally clipped to the closed
     interval `[-clip_level, clip_level]`.

     Args:
       x: Input tensor. Presumably rank 2, since
         `tf_utils.fast_walsh_hadamard_transform` rejects other ranks --
         TODO confirm against callers.
       signs: Tensor multiplied elementwise with `x` before padding.
       clip_level: Symmetric bound used when clipping.
       clip: Whether to clip the transformed values.

     Returns:
       The transformed tensor, clipped if `clip` is truthy.
     """
     rotated = tf_utils.fast_walsh_hadamard_transform(self._pad(x * signs))
     if not clip:
         return rotated
     return tf.clip_by_value(rotated, -clip_level, clip_level)
 def test_illegal_inputs_dynamic_power_of_two(self):
     """Tests incorrect dynamic shape of the rank 2 input."""
     # `maxval` is exclusive, so the raw sample is in {0, 1, 2}. Adding 1
     # gives exponents {1, 2, 3}. Without the offset, an exponent of 0 yields
     # 3**0 == 1 == 2**0, which IS a power of two, so no error would be raised
     # and the test would fail randomly. With the offset, x has shape (3, 3),
     # (3, 9) or (3, 27), chosen randomly and thus not known statically. None
     # of these widths is a power of two.
     rand = tf.random.uniform((), maxval=3, dtype=tf.int32) + 1
     x = tf.random.normal((3, 3**rand))
     hx = tf_utils.fast_walsh_hadamard_transform(x)
     with self.assertRaisesOpError(
             'The dimension of x must be a power of two.'):
         self.evaluate(hx)
 def test_is_rotation(self, dim):
     """Checks that the Walsh-Hadamard transform behaves like a rotation.

     A rotation must move the vector (the output differs from the input)
     while preserving its Euclidean norm.

     Args:
       dim: Width of the random (1, dim) input row.
     """
     x = tf.random.normal([1, dim])
     hx = tf_utils.fast_walsh_hadamard_transform(x)
     x_val, hx_val = self.evaluate([x, hx])
     # The transform must actually change the vector...
     displacement = np.linalg.norm(x_val - hx_val)
     self.assertGreater(displacement, 1e-3)
     # ...while leaving its length unchanged.
     self.assertAllClose(np.linalg.norm(x_val), np.linalg.norm(hx_val))
 def encode(self, x, encode_params):
   """See base class.

   Validates (and possibly expands) `x`, then multiplies it by random signs
   generated from the seed in `encode_params`. The product is padded with
   `self._pad` and rotated by the fast Walsh-Hadamard transform.

   Returns:
     A dict mapping `self.ENCODED_VALUES_KEY` to the rotated tensor.
   """
   x = self._validate_and_expand_encode_input(x)
   # Assumes `x` is rank 2 after validation, with coordinates along axis 1.
   # TODO confirm against `_validate_and_expand_encode_input`.
   num_columns = x.shape.as_list()[1]
   seed = encode_params[self.SEED_PARAMS_KEY]
   signs = self._random_signs(num_columns, seed, x.dtype)
   rotated = tf_utils.fast_walsh_hadamard_transform(self._pad(x * signs))
   return {self.ENCODED_VALUES_KEY: rotated}
Esempio n. 5
0
 def _kashin_backward(self, x, shape, signs=None):
     """Backward step of the algorithm to obtain Kashin's representation.

     Applies the fast Walsh-Hadamard transform to `x`. It then keeps every
     row but only the first `shape` columns, dropping any padding. Finally,
     if `signs` is given, the result is multiplied by it.

     Args:
       x: Tensor to transform; sliced as rank 2 below.
       shape: Number of leading columns to keep. Despite the name, it is used
         as a scalar size in `tf.slice`.
       signs: Optional tensor multiplied into the sliced result. When None,
         no multiplication happens.

     Returns:
       The transformed, sliced and optionally sign-multiplied tensor.
     """
     transformed = tf_utils.fast_walsh_hadamard_transform(x)
     num_rows = tf.shape(transformed)[0]
     # Keep the slice corresponding to the originally encoded object. Padding
     # and slicing must use the same fixed coordinates; that consistency is
     # what makes the inverse transformation unique.
     result = tf.slice(transformed, [0, 0], [num_rows, shape])
     return result if signs is None else result * signs
 def test_illegal_inputs_dynamic_power_of_two(self):
     """Checks the runtime error for a non-power-of-two dynamic width."""
     # The exponent is 1, 2 or 3 (maxval is exclusive, then offset by one), so
     # x is (3, 3), (3, 9) or (3, 27). The width is picked at run time and is
     # therefore not known statically, and none of the choices is a power of
     # two.
     exponent = tf.random.uniform((), maxval=3, dtype=tf.int32) + 1
     x = tf.random.normal((3, 3**exponent))
     hx = tf_utils.fast_walsh_hadamard_transform(x)
     expected_message = 'The dimension of x must be a power of two.'
     with self.assertRaisesOpError(expected_message):
         self.evaluate(hx)
Esempio n. 7
0
    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class.

        Rotates the encoded values back with the fast Walsh-Hadamard
        transform, then keeps only the first `shape[-1]` columns. The result
        is multiplied by random signs derived from the seed in
        `decode_params`. If `shape` has a single element, the leading axis
        is squeezed away.

        Note: although `shape` defaults to None, this implementation indexes
        it unconditionally, so callers must always supply it.
        """
        del num_summands  # Unused.
        rotated_x = encoded_tensors[self.ENCODED_VALUES_KEY]
        unrotated_x = tf_utils.fast_walsh_hadamard_transform(rotated_x)

        # Keep every row but only as many columns as the input shape has,
        # discarding padding columns.
        num_rows = tf.shape(unrotated_x)[0]
        decoded_x = tf.slice(unrotated_x, [0, 0], [num_rows, shape[-1]])

        num_columns = decoded_x.shape.as_list()[-1]
        seed = decode_params[self.SEED_PARAMS_KEY]
        signs = self._random_signs(num_columns, seed, decoded_x.dtype)
        decoded_x = decoded_x * signs

        # A one-element `shape` presumably describes a rank-1 original input,
        # so the leading axis is removed. TODO confirm against `encode`'s
        # input expansion.
        if shape.shape.num_elements() == 1:
            decoded_x = tf.squeeze(decoded_x, [0])
        return decoded_x
Esempio n. 8
0
 def test_illegal_inputs_static_power_of_two(self, *dims):
     """Tests incorrect static shape of the rank 2 input.

     Args:
       *dims: Static shape of the random input. The last dimension is
         expected not to be a power of two, so a ValueError is raised while
         the graph is being built.
     """
     x = tf.random.normal(dims)
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(
             ValueError, 'The dimension of x must be a power of two.'):
         tf_utils.fast_walsh_hadamard_transform(x)
Esempio n. 9
0
 def test_illegal_inputs_shape(self, *dims):
     """Tests incorrect rank of the input.

     Args:
       *dims: Static shape of the random input. Its length (the rank) is
         expected to differ from 2, so a ValueError is raised.
     """
     x = tf.random.normal(dims)
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError,
                                 'Number of dimensions of x must be 2.'):
         tf_utils.fast_walsh_hadamard_transform(x)