def testPadWithOneReflectionIsCorrect(self):
    """Tests that pad_reflecting(p) matches tf.pad(p) when p is small.

    Runs four random trials. Each trial draws a random cube with side length
    `n` (at most 9), picks one of its three axes at random, and pads only
    that axis with random amounts between 0 and `n - 1`. The output of
    `wavelet.pad_reflecting` must match a single `tf.pad(..., 'REFLECT')`
    call in both shape and values.
    """
    for _ in range(4):
        n = int(np.ceil(np.random.uniform() * 8)) + 1
        x = np.random.uniform(size=(n, n, n))
        padding_below = int(np.round(np.random.uniform() * (n - 1)))
        padding_above = int(np.round(np.random.uniform() * (n - 1)))
        # `uniform()` is drawn from [0, 1), so `axis` is always 0, 1 or 2.
        axis = int(np.floor(np.random.uniform() * 3.))

        # Pad only the selected axis; the other two axes stay unpadded.
        paddings = [[0, 0], [0, 0], [0, 0]]
        paddings[axis] = [padding_below, padding_above]
        reference = tf.pad(x, paddings, 'REFLECT')

        result = wavelet.pad_reflecting(
            x, padding_below, padding_above, axis)

        # self.evaluate() returns NumPy values in both graph and eager mode,
        # whereas Tensor.eval() requires a graph-mode default session.
        result = self.evaluate(result)
        reference = self.evaluate(reference)
        self.assertAllEqual(result.shape, reference.shape)
        self.assertAllEqual(result, reference)
def testPadWithManyReflectionsIsCorrect(self):
    """Checks pad_reflecting(k * p) against k nested tf.pad(p) reflections.

    For four random 1-D signals of length `length` (1 through 8), padding by
    `k * (length - 1)` on both sides in one `wavelet.pad_reflecting` call
    must equal applying `tf.pad` in 'REFLECT' mode with `length - 1` on both
    sides `k` times in a row, for k = 1, 2, 3.
    """
    multiples = (1, 2, 3)
    for _ in range(4):
        length = int(np.random.uniform() * 8.) + 1
        pad = length - 1
        signal = np.random.uniform(size=length)

        # Single-call paddings by 1x, 2x and 3x the base amount.
        results = [
            wavelet.pad_reflecting(signal, k * pad, k * pad, 0)
            for k in multiples
        ]

        # Each reference reflects the previous reference once more.
        references = []
        padded = signal
        for _unused in multiples:
            padded = tf.pad(padded, [[pad, pad]], mode='REFLECT')
            references.append(padded)

        for actual, expected in zip(results, references):
            self.assertAllEqual(actual.shape, expected.shape)
            self.assertAllEqual(actual, expected)
def testPadWithManyReflectionsGolden2IsCorrect(self):
    """Tests pad_reflecting() against a golden example.

    The input is `np.arange(11)`. The hand-written expected output places 15
    elements before the original signal and 7 after it, so the leading
    padding is longer than the signal itself and spans more than one
    reflection.
    """
    n = 11
    p0 = 15
    p1 = 7
    x = np.arange(n)
    reference1 = np.concatenate((
        np.arange(5, n),          # 5..10  (6 elements)
        np.arange(n - 2, 0, -1),  # 9..1   (9 elements)
        np.arange(n),             # the original signal, 0..10
        np.arange(n - 2, 2, -1),  # 9..3   (7 elements)
    ))
    # self.evaluate() returns NumPy values in both graph and eager mode,
    # whereas Tensor.eval() requires a graph-mode default session.
    result1 = self.evaluate(wavelet.pad_reflecting(x, p0, p1, 0))
    self.assertAllEqual(result1.shape, reference1.shape)
    self.assertAllEqual(result1, reference1)
def testPadWithManyReflectionsGolden1IsCorrect(self):
    """Checks pad_reflecting() against a hand-built expected signal.

    The input is `np.arange(8)`. The expected output has 17 elements before
    the original signal and 13 after it, assembled from alternating
    ascending (0..7) and descending (6..1) runs.
    """
    size = 8
    pad_before, pad_after = 17, 13
    signal = np.arange(size)

    ascending = np.arange(size)              # 0..7
    descending = np.arange(size - 2, 0, -1)  # 6..1
    expected = np.concatenate([
        np.arange(3, 0, -1),  # 3, 2, 1: tail end of a descending run
        ascending,
        descending,
        ascending,            # the original signal
        descending,
        np.arange(7),         # 0..6: a truncated ascending run
    ])

    actual = wavelet.pad_reflecting(signal, pad_before, pad_after, 0)
    self.assertAllEqual(actual.shape, expected.shape)
    self.assertAllEqual(actual, expected)