def test_broadcast(self):
    """Check that ``ShapeObject.broadcast`` matches numpy broadcasting.

    For each pair of shapes, the broadcast ``ShapeObject`` must compare
    equal to the shape numpy produces when adding two zero arrays of
    those shapes.
    """
    cases = [
        [(1, 2), (45, 2)],
        [(1, ), (45, 2)],
        [(3, 1), (1, 3)],
        [(3, 1), (1, )],
        [(3, 1), (1, 1)],
        [(1, 3), (3, 1)],
    ]
    for left, right in cases:
        left_shape = ShapeObject(left, dtype=numpy.float32)
        right_shape = ShapeObject(right, dtype=numpy.float32)
        left_array = numpy.zeros(left)
        right_array = numpy.zeros(right)
        broadcast = left_shape.broadcast(right_shape)
        # numpy's own broadcasting is the reference result.
        reference = (left_array + right_array).shape
        self.assertEqual(broadcast, reference)
def _infer_shapes(self, x):  # pylint: disable=W0221
    """Infer the output shape(s) from input ``x``.

    Every output reuses ``x.shape`` and ``x.dtype``.

    :param x: input shape description; it must expose ``shape`` and ``dtype``
    :return: a 1-tuple named after the class when ``self.eigv`` is falsy;
        otherwise a 2-tuple named ``<Class>Values`` and ``<Class>Vectors``
        (presumably eigenvalues / eigenvectors -- TODO confirm)
    """
    if not self.eigv:  # pylint: disable=E1101
        return (ShapeObject(x.shape, dtype=x.dtype,
                            name=self.__class__.__name__), )
    prefix = self.__class__.__name__
    values = ShapeObject(x.shape, dtype=x.dtype, name=prefix + 'Values')
    vectors = ShapeObject(x.shape, dtype=x.dtype, name=prefix + 'Vectors')
    return (values, vectors)
def infer_shapes(self, x):
    """Infer the output shape(s) from input ``x``.

    Each output copies ``x.shape`` and ``x.dtype``. Outputs are named after
    the class, with the suffixes ``Values`` and ``Vectors`` when
    ``self.eigv`` is truthy; otherwise there is one output with the bare
    class name.

    :param x: input shape description; it must expose ``shape`` and ``dtype``
    :return: tuple of ``ShapeObject`` (length 2 if ``self.eigv``, else 1)
    """
    # shape inference, if you don't know what to
    # write, just return `ShapeObject(None)`
    suffixes = ('Values', 'Vectors') if self.eigv else ('', )
    base_name = self.__class__.__name__
    return tuple(
        ShapeObject(x.shape, dtype=x.dtype, name=base_name + suffix)
        for suffix in suffixes)
def test_shape_object_max(self):
    """Check ``max`` over ShapeObjects, in both argument orders.

    The result must be the 3-D shape, and equal shapes must also work.
    """
    expected_repr = "ShapeObject((1, 2, 3), dtype=numpy.float32)"

    three_d = ShapeObject((1, 2, 3), dtype=numpy.float32)
    two_d = ShapeObject((1, 2), dtype=numpy.float32)
    self.assertEqual(repr(max(three_d, two_d)), expected_repr)
    self.assertEqual(repr(max(two_d, three_d)), expected_repr)

    # Two distinct objects with identical shapes.
    first = ShapeObject((1, 2, 3), dtype=numpy.float32)
    second = ShapeObject((1, 2, 3), dtype=numpy.float32)
    self.assertEqual(repr(max(second, first)), expected_repr)
def test_shape_object(self):
    """Check ShapeObject construction, repr and ``reduce``.

    - Building a shape without a dtype raises ``ValueError``.
    - ``reduce(axis)`` drops that axis.
    - An out-of-range axis raises ``IndexError``.
    - ``reduce(axis, True)`` keeps the reduced axis as a size-1 dimension.
    """
    self.assertRaise(lambda: ShapeObject((1, 2, 3)), ValueError)
    sh = ShapeObject((1, 2, 3), dtype=numpy.float32)
    self.assertEqual(repr(sh), "ShapeObject((1, 2, 3), dtype=numpy.float32)")
    # assertEqual (rather than assertTrue(a == b)) reports both values on
    # failure; comparing a ShapeObject with a tuple is supported.
    red = sh.reduce(0)
    self.assertEqual(red, (2, 3))
    self.assertRaise(lambda: sh.reduce(10), IndexError)
    red = sh.reduce(1, True)
    self.assertEqual(red, (1, 1, 3))
def test_shape_object_reshape(self):
    """Check ``reshape`` to a compatible shape and rejection of one that is not.

    (1, 2, 3) reshapes to (6, 1, 1), both with 6 elements; (9, 1, 1) must
    raise.
    """
    source = ShapeObject((1, 2, 3), dtype=numpy.float32)
    self.assertEqual(source.reshape((6, 1, 1)), (6, 1, 1))
    # NOTE(review): no exception type is given to assertRaise here;
    # consider pinning it once the type raised by reshape is confirmed.
    self.assertRaise(lambda: source.reshape((9, 1, 1)))
def test_max(self):
    """Check that ``max`` of shapes (1, 2) and (45, 2) equals (45, 2)."""
    candidates = [
        ShapeObject((1, 2), dtype=numpy.float32),
        ShapeObject((45, 2), dtype=numpy.float32),
    ]
    self.assertEqual(max(candidates), (45, 2))
def test_maximum_none(self):
    """Check ``max`` when one operand has an unknown shape.

    The ``None``-shaped operand (name "B") is the one returned.
    """
    known = ShapeObject((1, ), dtype=numpy.float32, name="A")
    unknown = ShapeObject(None, dtype=numpy.float32, name="B")
    self.assertEqual(max(known, unknown).name, 'B')