def test_move_dimension_static_shape(self):
    """Checks that move_dimension yields the expected static shape.

    The input tensor has a fully-known static shape, so each result is
    checked through `shape.as_list()` without evaluating anything. The
    cases match those in test_move_dimension_dynamic_shape, including a
    negative source index.
    """
    x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])

    # Each case is (source_idx, dest_idx, expected static shape after the move).
    cases = [
        # Source equals destination: shape is unchanged.
        (1, 1, [200, 30, 4, 1, 6]),
        (0, 3, [30, 4, 1, 200, 6]),
        # Negative destination counts from the end (-2 is index 3 for rank 5).
        (0, -2, [30, 4, 1, 200, 6]),
        (4, 2, [200, 30, 6, 4, 1]),
        # Negative source counts from the end (-1 is index 4 for rank 5).
        (-1, 2, [200, 30, 6, 4, 1]),
    ]
    for source_idx, dest_idx, expected in cases:
        x_perm = distribution_util.move_dimension(x, source_idx, dest_idx)
        # The message names the failing case, since all cases share one loop.
        self.assertAllEqual(
            x_perm.shape.as_list(), expected,
            msg="move_dimension(x, %d, %d)" % (source_idx, dest_idx))
def test_move_dimension_dynamic_shape(self):
    """Checks move_dimension's evaluated output shape for a shapeless input.

    The input goes through placeholder_with_default with `shape=None`, so
    no static shape is attached to it. Each result's shape is therefore
    obtained by evaluating `array_ops.shape` and compared with the expected
    runtime shape.
    """
    concrete = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
    # shape=None: the placeholder carries no static shape information.
    shapeless = array_ops.placeholder_with_default(input=concrete, shape=None)

    # Each entry is (source index, destination index, expected runtime shape).
    expectations = (
        (1, 1, [200, 30, 4, 1, 6]),
        (0, 3, [30, 4, 1, 200, 6]),
        (0, -2, [30, 4, 1, 200, 6]),
        (4, 2, [200, 30, 6, 4, 1]),
        (-1, 2, [200, 30, 6, 4, 1]),
    )
    for src, dst, want in expectations:
        moved = distribution_util.move_dimension(shapeless, src, dst)
        runtime_shape = self.evaluate(array_ops.shape(moved))
        self.assertAllEqual(runtime_shape, want)