예제 #1
0
    def test_numpy_extended_dot_2_a(self):
        """Check both extended-dot implementations against ``@`` on 2x2 data.

        First asserts that calling ``numpy_extended_dot`` directly on the
        2-D matrices with ``[0], [1], [2]`` raises ``ValueError``. Then, for
        two different 3-D reshapings of the same matrices, asserts that the
        squeezed output of ``numpy_extended_dot`` and of
        ``numpy_extended_dot_python`` equals the product computed with ``@``.
        """
        lhs = numpy.arange(4).reshape((2, 2)).astype(numpy.float32) + 10
        rhs = lhs + 90

        def call_on_2d_inputs():
            # NOTE(review): presumably rejected because the inputs only have
            # two dimensions while axis 2 is requested -- confirm.
            return numpy_extended_dot(lhs, rhs.T, [0], [1], [2])

        self.assertRaise(call_on_2d_inputs, ValueError)

        implementations = (numpy_extended_dot, numpy_extended_dot_python)

        # Case 1: contraction over axis 1 is compared with ``lhs @ rhs``.
        left3d = lhs.reshape((2, 2, 1))
        right3d = rhs.reshape((1, 2, 2))
        expected = lhs @ rhs
        for impl in implementations:
            got = impl(left3d, right3d, axes=[1], left=[0], right=[2])
            self.assertEqualArray(expected, numpy.squeeze(got))

        # Case 2: contraction over axis 2 is compared with ``lhs @ rhs.T``.
        left3d = lhs.reshape((2, 1, 2))
        right3d = rhs.reshape((1, 2, 2))
        expected = lhs @ rhs.T
        for impl in implementations:
            got = impl(left3d, right3d, axes=[2], left=[0], right=[1])
            self.assertEqualArray(expected, numpy.squeeze(got))
예제 #2
0
    def test_numpy_extended_dot_3(self):
        """Check that both implementations agree on 2x2x2 integer inputs.

        For each of the two (left, right) axis assignments, with axis 1 as
        the ``axes`` argument, the result of ``numpy_extended_dot`` must
        equal the result of ``numpy_extended_dot_python``.
        """
        first = numpy.arange(8).reshape((2, 2, 2)) + 10
        second = first + 90

        # The same contraction axis is used with left/right swapped.
        for left_axis, right_axis in ((0, 2), (2, 0)):
            fast = numpy_extended_dot(
                first, second, [1], [left_axis], [right_axis])
            reference = numpy_extended_dot_python(
                first, second, [1], [left_axis], [right_axis])
            self.assertEqualArray(fast, reference)
예제 #3
0
 def test_numpy_extended_dot_2_b2(self):
     """Check both implementations agree when ``left`` names two axes.

     The 2x2 float32 matrices are reshaped to (2, 2, 1) and (1, 2, 2) and
     combined with ``axes=[2]``, ``left=[0, 1]``, ``right=[2]``. The output
     of ``numpy_extended_dot`` is compared with the squeezed output of
     ``numpy_extended_dot_python``.
     """
     base = numpy.arange(4).reshape((2, 2)).astype(numpy.float32) + 10
     shifted = base + 90
     left3d = base.reshape((2, 2, 1))
     right3d = shifted.reshape((1, 2, 2))

     fast = numpy_extended_dot(
         left3d, right3d, axes=[2], left=[0, 1], right=[2])
     reference = numpy_extended_dot_python(
         left3d, right3d, axes=[2], left=[0, 1], right=[2])

     # Only the pure-Python result is squeezed before the comparison.
     squeezed_reference = numpy.squeeze(reference)
     self.assertEqualArray(fast, squeezed_reference)