コード例 #1
0
 def test_broadcast(self):
     """
     Checks that ``ShapeObject.broadcast`` agrees with numpy broadcasting.

     For every pair of shapes, the shape produced by
     ``ShapeObject.broadcast`` must equal the shape numpy obtains when
     adding two zero arrays with those shapes.
     """
     cases = [
         ((1, 2), (45, 2)),
         ((1, ), (45, 2)),
         ((3, 1), (1, 3)),
         ((3, 1), (1, )),
         ((3, 1), (1, 1)),
         ((1, 3), (3, 1)),
     ]
     for left, right in cases:
         left_shape = ShapeObject(left, dtype=numpy.float32)
         right_shape = ShapeObject(right, dtype=numpy.float32)
         broadcast = left_shape.broadcast(right_shape)
         # numpy acts as the reference implementation of broadcasting
         reference = numpy.zeros(left) + numpy.zeros(right)
         self.assertEqual(broadcast, reference.shape)
コード例 #2
0
 def _infer_shapes(self, x):  # pylint: disable=W0221
     """
     Returns the output shapes inferred from input *x*.

     Every returned ShapeObject reuses ``x.shape`` and ``x.dtype``; only
     the name differs and is derived from the class name.

     :param x: input shape description (must expose ``shape`` and ``dtype``)
     :return: a 1-tuple when ``self.eigv`` is falsy, otherwise a pair whose
         names are suffixed with ``Values`` and ``Vectors`` (presumably
         eigenvalues and eigenvectors -- TODO confirm against the operator)
     """
     base_name = self.__class__.__name__
     if not self.eigv:  # pylint: disable=E1101
         return (ShapeObject(x.shape, dtype=x.dtype, name=base_name), )
     values = ShapeObject(x.shape, dtype=x.dtype, name=base_name + 'Values')
     vectors = ShapeObject(x.shape, dtype=x.dtype, name=base_name + 'Vectors')
     return values, vectors
コード例 #3
0
 def infer_shapes(self, x):
     """
     Infers the output shapes from the input *x*.

     This is the shape-inference hook: if you don't know what to write
     here, just return ``ShapeObject(None)``.

     All returned ShapeObjects copy ``x.shape`` and ``x.dtype``; only their
     names differ. When ``self.eigv`` is truthy, two shapes are returned,
     suffixed ``Values`` and ``Vectors``. Otherwise a 1-tuple is returned.
     """
     base_name = self.__class__.__name__
     if not self.eigv:
         return (ShapeObject(x.shape, dtype=x.dtype, name=base_name), )
     # same shape/dtype for both outputs, only the name suffix changes
     return tuple(
         ShapeObject(x.shape, dtype=x.dtype, name=base_name + suffix)
         for suffix in ('Values', 'Vectors'))
コード例 #4
0
 def test_shape_object_max(self):
     """
     Checks ``max`` on ShapeObjects of different and of equal rank.

     The larger-rank shape must win regardless of argument order, and two
     identical shapes must yield that same shape.
     """
     expected = "ShapeObject((1, 2, 3), dtype=numpy.float32)"
     big = ShapeObject((1, 2, 3), dtype=numpy.float32)
     small = ShapeObject((1, 2), dtype=numpy.float32)
     for left, right in ((big, small), (small, big)):
         self.assertEqual(repr(max(left, right)), expected)
     # two distinct objects describing the same shape
     first = ShapeObject((1, 2, 3), dtype=numpy.float32)
     second = ShapeObject((1, 2, 3), dtype=numpy.float32)
     self.assertEqual(repr(max(second, first)), expected)
コード例 #5
0
 def test_shape_object(self):
     """
     Checks ShapeObject construction, repr and ``reduce``.

     * constructing without a dtype raises ValueError,
     * ``reduce(0)`` on (1, 2, 3) removes the first axis,
     * ``reduce(10)`` on a rank-3 shape raises IndexError,
     * ``reduce(1, True)`` keeps the reduced axis with size 1.
     """
     self.assertRaise(lambda: ShapeObject((1, 2, 3)), ValueError)
     sh = ShapeObject((1, 2, 3), dtype=numpy.float32)
     self.assertEqual(repr(sh),
                      "ShapeObject((1, 2, 3), dtype=numpy.float32)")
     red = sh.reduce(0)
     # assertEqual (rather than assertTrue(a == b)) reports both values on
     # failure; the pass condition is the same ``red == (2, 3)`` check
     self.assertEqual(red, (2, 3))
     self.assertRaise(lambda: sh.reduce(10), IndexError)
     red = sh.reduce(1, True)
     self.assertEqual(red, (1, 1, 3))
コード例 #6
0
 def test_shape_object_reshape(self):
     """
     Checks ``ShapeObject.reshape`` with a compatible and an incompatible
     target shape.

     (1, 2, 3) holds 6 elements, so it can become (6, 1, 1) but not (9, 1, 1).
     """
     source = ShapeObject((1, 2, 3), dtype=numpy.float32)
     self.assertEqual(source.reshape((6, 1, 1)), (6, 1, 1))
     # NOTE(review): no exception type is passed to assertRaise, so presumably
     # any exception satisfies this check -- TODO pin the expected type
     self.assertRaise(lambda: source.reshape((9, 1, 1)))
コード例 #7
0
 def test_max(self):
     """``max`` of shapes (1, 2) and (45, 2) must be (45, 2)."""
     small = ShapeObject((1, 2), dtype=numpy.float32)
     large = ShapeObject((45, 2), dtype=numpy.float32)
     self.assertEqual(max(small, large), (45, 2))
コード例 #8
0
 def test_maximum_none(self):
     """
     ``max`` of a known shape and a ``None`` shape must produce a result
     carrying the name of the ``None``-shaped operand.
     """
     known = ShapeObject((1, ), dtype=numpy.float32, name="A")
     unknown = ShapeObject(None, dtype=numpy.float32, name="B")
     self.assertEqual(max(known, unknown).name, 'B')