예제 #1
0
    def testPadWithOneReflectionIsCorrect(self):
        """Tests pad_reflecting() against tf.pad(..., 'REFLECT') for small pads.

        "Small" means each padding amount is at most n - 1, the largest value
        tf.pad's REFLECT mode accepts for a dimension of size n (a single
        reflection). Each trial draws a random n x n x n array, random
        below/above padding amounts and a random axis, then checks that the
        shape and values of both paddings agree.
        """
        for _ in range(4):
            n = int(np.ceil(np.random.uniform() * 8)) + 1
            x = np.random.uniform(size=(n, n, n))
            # Both amounts lie in [0, n - 1], within tf.pad's REFLECT limit.
            padding_below = int(np.round(np.random.uniform() * (n - 1)))
            padding_above = int(np.round(np.random.uniform() * (n - 1)))
            axis = int(np.floor(np.random.uniform() * 3.))

            # Pad only along `axis`; every other dimension gets [0, 0]. Built
            # generically so `reference` can never be left unassigned.
            paddings = [[0, 0] for _ in range(x.ndim)]
            paddings[axis] = [padding_below, padding_above]
            reference = tf.pad(x, paddings, 'REFLECT')
            result = wavelet.pad_reflecting(x, padding_below, padding_above,
                                            axis)

            # self.evaluate() works under both eager and graph execution,
            # unlike Tensor.eval(), which needs a graph-mode default session.
            result, reference = self.evaluate([result, reference])
            self.assertAllEqual(result.shape, reference.shape)
            self.assertAllEqual(result, reference)
예제 #2
0
 def testPadWithManyReflectionsIsCorrect(self):
   """Checks that padding by k * p equals k successive tf.pad calls by p.

   With p = n - 1 (the largest REFLECT padding tf.pad accepts for the
   original length-n signal), pad_reflecting(x, k * p, k * p, 0) must match
   applying tf.pad(..., mode='REFLECT') with padding p exactly k times, for
   k = 1, 2, 3.
   """
   for _ in range(4):
     n = int(np.random.uniform() * 8.) + 1
     p = n - 1
     x = np.random.uniform(size=n)
     reference = x
     for k in (1, 2, 3):
       # Each round adds one more single-step reflection to the reference.
       reference = tf.pad(reference, [[p, p]], mode='REFLECT')
       result = wavelet.pad_reflecting(x, k * p, k * p, 0)
       self.assertAllEqual(result.shape, reference.shape)
       self.assertAllEqual(result, reference)
예제 #3
0
 def testPadWithManyReflectionsGolden2IsCorrect(self):
     """Tests pad_reflecting() against a hand-built golden example.

     An 11-element ramp is padded by 15 below and 7 above. The padding below
     exceeds n - 1, more than tf.pad's REFLECT mode allows in a single call,
     so the expected output is written out explicitly.
     """
     n = 11
     p0 = 15
     p1 = 7
     x = np.arange(n)
     # Expected output, left to right: 5..10 and 9..1 (the 15 values padded
     # below), the signal 0..10, then 9..3 (the 7 values padded above).
     reference1 = np.concatenate((np.arange(5, n), np.arange(n - 2, 0, -1),
                                  np.arange(n), np.arange(n - 2, 2, -1)))
     # self.evaluate() works under both eager and graph execution, unlike
     # Tensor.eval(), which needs a graph-mode default session.
     result1 = self.evaluate(wavelet.pad_reflecting(x, p0, p1, 0))
     self.assertAllEqual(result1.shape, reference1.shape)
     self.assertAllEqual(result1, reference1)
예제 #4
0
 def testPadWithManyReflectionsGolden1IsCorrect(self):
     """Compares pad_reflecting() with a hand-written golden output.

     An 8-element ramp is padded by 17 below and 13 above. Both amounts
     exceed n - 1, so the expected result is spelled out segment by segment.
     """
     n = 8
     p0, p1 = 17, 13
     signal = np.arange(n)
     descending = np.arange(n - 2, 0, -1)  # 6, 5, ..., 1
     expected = np.concatenate([
         np.arange(3, 0, -1),  # 3, 2, 1
         signal,
         descending,
         signal,
         descending,
         np.arange(7),  # 0, 1, ..., 6
     ])
     result = wavelet.pad_reflecting(signal, p0, p1, 0)
     self.assertAllEqual(result.shape, expected.shape)
     self.assertAllEqual(result, expected)