コード例 #1
0
    def test_elemwise_float64(self):
        """Check ElemwiseSum on three float64 4-D tensors.

        Temporarily sets ``theano.config.floatX`` to ``'float64'``. Each
        input is passed through ``U2IElemwiseSum`` and the three are summed
        with ``mkl_elemwise.ElemwiseSum`` (unit coefficients). The result is
        converted back with ``I2U``. The test asserts the values match a
        plain NumPy sum and the output dtype is float64.

        The previous ``floatX`` value is always restored, even when an
        assertion fails, so later tests do not inherit the float64 setting.
        """
        old_floatX = theano.config.floatX
        theano.config.floatX = 'float64'
        try:
            a = theano.tensor.dtensor4('a')
            b = theano.tensor.dtensor4('b')
            c = theano.tensor.dtensor4('c')

            a_internal = basic_ops.U2IElemwiseSum(inp_num=3, coeff=[1.0, 1.0,
                                                                    1.0])(a)
            b_internal = basic_ops.U2IElemwiseSum(inp_num=3, coeff=[1.0, 1.0,
                                                                    1.0])(b)
            c_internal = basic_ops.U2IElemwiseSum(inp_num=3, coeff=[1.0, 1.0,
                                                                    1.0])(c)

            z_internal = mkl_elemwise.ElemwiseSum(inp_num=3,
                                                  coeff=[1.0, 1.0,
                                                         1.0])(a_internal,
                                                               b_internal,
                                                               c_internal)
            z = basic_ops.I2U()(z_internal)
            f = theano.function([a, b, c], z)

            ival0 = numpy.random.rand(4, 4, 4, 4).astype(theano.config.floatX)
            ival1 = numpy.random.rand(4, 4, 4, 4).astype(theano.config.floatX)
            ival2 = numpy.random.rand(4, 4, 4, 4).astype(theano.config.floatX)
            # Evaluate once and check both value and dtype on the same result.
            out = f(ival0, ival1, ival2)
            assert numpy.allclose(out, ival0 + ival1 + ival2)
            assert out.dtype == 'float64'
        finally:
            # Global config must not leak into other tests on failure.
            theano.config.floatX = old_floatX
コード例 #2
0
 def test_elemwise_U2I(self):
     """Check that a U2IElemwiseSum -> I2U round trip of a float32 4-D tensor is identity.

     The single-input ``U2IElemwiseSum`` uses coefficient 1.0. Its output is
     fed straight into ``I2U``, and the compiled function must return
     (numerically) its input unchanged.
     """
     x = theano.tensor.ftensor4('a')
     internal = basic_ops.U2IElemwiseSum(inp_num=1, coeff=[1.0])(x)
     roundtrip = theano.function([x], basic_ops.I2U()(internal))

     data = numpy.random.rand(4, 4, 4, 4).astype(numpy.float32)
     assert numpy.allclose(roundtrip(data), data)
コード例 #3
0
 def test_elemwise_wrong_dim(self):
     """Check that U2IElemwiseSum rejects a 2-D (matrix) input with TypeError.

     Raises:
         AssertionError: if applying the op to a matrix raises nothing.
         Exception: wrapping any unexpected non-TypeError error, prefixed
             with the test name.
     """
     a = theano.tensor.fmatrix('a')
     try:
         basic_ops.U2IElemwiseSum(inp_num=1, coeff=[
             1.0,
         ])(a)
     except TypeError:
         pass  # expected rejection of the 2-D input
     except Exception as e:
         raise Exception('test_elemwise_wrong_dim ' + str(e))
     else:
         # Kept outside the try so the generic handler above cannot
         # swallow and re-wrap the "no exception" failure.
         raise AssertionError('No Exception when ndim is 2.')
コード例 #4
0
    def test_elemwise_value(self):
        """Check that ElemwiseSum over three float32 4-D inputs equals their plain sum.

        Each symbolic input is passed through its own ``U2IElemwiseSum``
        (inp_num=3, unit coefficients). The three results are summed by
        ``mkl_elemwise.ElemwiseSum`` with unit coefficients and converted
        back with ``I2U``. The compiled output is compared against
        ``ival0 + ival1 + ival2`` computed in NumPy.
        """
        names = ('a', 'b', 'c')
        inputs = [theano.tensor.ftensor4(name) for name in names]

        # A fresh coefficient list per op, mirroring the separate literals.
        internals = [basic_ops.U2IElemwiseSum(inp_num=3, coeff=[1.0] * 3)(var)
                     for var in inputs]
        total = mkl_elemwise.ElemwiseSum(inp_num=3,
                                         coeff=[1.0] * 3)(*internals)
        f = theano.function(inputs, basic_ops.I2U()(total))

        data = [numpy.random.rand(4, 4, 4, 4).astype(numpy.float32)
                for _ in names]
        expected = data[0] + data[1] + data[2]
        assert numpy.allclose(f(*data), expected)
コード例 #5
0
    def test_elemwise_input_num(self):
        """Check that both elemwise-sum ops reject ``len(coeff) != inp_num`` with ValueError.

        The check is done for ``basic_ops.U2IElemwiseSum`` and
        ``mkl_elemwise.ElemwiseSum``, each constructed with ``inp_num=3`` and
        only two coefficients.

        Raises:
            AssertionError: if either constructor accepts the mismatched
                arguments without raising.
            Exception: wrapping any unexpected non-ValueError error,
                prefixed with the test name.
        """
        cases = (
            ('U2IElemwiseSum', basic_ops.U2IElemwiseSum),
            ('ElemwiseSum', mkl_elemwise.ElemwiseSum),
        )
        for op_name, op_class in cases:
            try:
                op_class(inp_num=3, coeff=[1.0, 1.0])
            except ValueError:
                pass  # expected: coeff length does not match inp_num
            except Exception as e:
                raise Exception('test_elemwise_input_num ' + str(e))
            else:
                # Kept outside the try so the generic handler above cannot
                # swallow and re-wrap the "no exception" failure.
                raise AssertionError(
                    op_name + ' No Exception when inp_num != len(coeff)')