def mkl_pool_func(*inputs):
    mkl_ver = theano.sandbox.mkl.mkl_version()

    # inputs[2] is ignore_border; older MKL cannot honor ignore_border=True
    if inputs[2] and isinstance(mkl_ver, integer_types) and (mkl_ver < 20170206):
        raise SkipTest("Need newer MKL to support 'ignore_border=True'.")

    if len(inputs) == 5:
        # self, images, ignore_border, mode, ds
        _, images, ignore_border, mode, ds = inputs
        x_internal = U2IPool(ignore_border=ignore_border, mode=mode)(images, ds)
        poolOut = Pool(ignore_border=ignore_border, mode=mode)(x_internal, ds)
        output = I2U()(poolOut)
    elif len(inputs) == 6:
        # self, images, ignore_border, mode, ds, st
        _, images, ignore_border, mode, ds, st = inputs
        x_internal = U2IPool(ignore_border=ignore_border, mode=mode)(images, ds, st)
        poolOut = Pool(ignore_border=ignore_border, mode=mode)(x_internal, ds, st)
        output = I2U()(poolOut)
    elif len(inputs) == 7:
        # self, images, ignore_border, mode, ds, st, pad
        _, images, ignore_border, mode, ds, st, pad = inputs
        x_internal = U2IPool(ignore_border=ignore_border, mode=mode)(images, ds, st, pad)
        poolOut = Pool(ignore_border=ignore_border, mode=mode)(x_internal, ds, st, pad)
        output = I2U()(poolOut)
    else:
        raise ValueError('incorrect inputs list, should be 5 ~ 7 parameters!')

    return output
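
# Usage sketch (not part of the original suite; assumes the module-level
# names used above: numpy, theano, theano.tensor as T, mode_with_mkl, and
# the MKL ops). It shows the 5-argument form, which composes
# U2IPool -> Pool -> I2U; None stands in for the unused `self` slot.
def _mkl_pool_func_usage_sketch():
    images = T.dtensor4('images')
    # ds=(2, 2) pooling window, ignore_border=False, max pooling
    out = mkl_pool_func(None, images, False, 'max', (2, 2))
    f = theano.function([images], out, mode=mode_with_mkl)
    ival = numpy.random.rand(4, 3, 8, 8).astype(numpy.float64)
    return f(ival)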
def test_conv_with_bias(self):
    images = T.dtensor4('inputs')
    weights = T.dtensor4('weights')
    bias = T.dvector('bias')

    ishape = [(8, 3, 256, 256), (16, 3, 256, 256),
              (32, 3, 256, 256), (64, 3, 256, 256)]
    wshape = [(8, 3, 3, 3), (16, 3, 3, 3),
              (32, 3, 3, 3), (64, 3, 3, 3)]

    for i, ish in enumerate(ishape):
        wsh = wshape[i]
        images_internal = U2IConv(imshp=ish, kshp=wsh)(images)
        convOutBias_internal = Conv2D(imshp=ish, kshp=wsh, filter_flip=False)(images_internal, weights, bias)
        convOutBias_user = I2U()(convOutBias_internal)

        ival = numpy.random.rand(*ish).astype(numpy.float64)
        wval = numpy.random.rand(*wsh).astype(numpy.float64)
        bval = numpy.random.rand(wsh[0]).astype(numpy.float64)

        fopt = theano.function(inputs=[images, weights, bias], outputs=convOutBias_user, mode=mode_with_mkl)
        new_out = fopt(ival, wval, bval)

        convOut = conv2d(images, weights, input_shape=ish, filter_shape=wsh, filter_flip=False)
        convOutBias = convOut + bias.dimshuffle('x', 0, 'x', 'x')
        fori = theano.function(inputs=[images, weights, bias], outputs=convOutBias, mode=mode_without_mkl)
        old_out = fori(ival, wval, bval)

        assert str(fopt.maker.fgraph.toposort()) != str(fori.maker.fgraph.toposort())
        assert numpy.allclose(old_out, new_out)
def local_lrn_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, mkl_lrn.AbstractLRN):
        return

    if node.inputs[0].type.ndim != 4:
        return

    try:
        x, = node.inputs
        x_u2i = U2ILRN(alpha=node.op.alpha, beta=node.op.beta,
                       k=node.op.k, n=node.op.n)(x)
        lrnout = mkl_lrn.LRN(alpha=node.op.alpha, beta=node.op.beta,
                             k=node.op.k, n=node.op.n)(x_u2i)
        z_i2u = I2U()(lrnout)
        rval = z_i2u
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
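
# The optimizer above follows the layout-conversion sandwich used throughout
# this file: a U2I* op converts the user-layout tensor to MKL's internal
# layout, the MKL op runs on that layout, and I2U() converts back. A generic
# skeleton of the rewrite (illustration only; `U2IOp` and `MklOp` are
# placeholders for concrete pairs such as U2ILRN / mkl_lrn.LRN):
#
#     x_internal = U2IOp(**params)(x)        # user layout -> MKL layout
#     y_internal = MklOp(**params)(x_internal)
#     return [I2U()(y_internal)]             # MKL layout -> user layout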
def test_conv_no_bias(self):
    images = T.dtensor4('input_conv')
    weights = T.dtensor4('weights')

    images_internal = U2IConv(imshp=(12, 3, 256, 256), kshp=(12, 3, 3, 3))(images)
    convOut = Conv2D(imshp=(12, 3, 256, 256), kshp=(12, 3, 3, 3), filter_flip=False)(images_internal, weights)
    convOut_user = I2U()(convOut)
    convOutLoss = T.mean(convOut_user)
    conv_op_di = T.grad(convOutLoss, images)
    conv_op_dk = T.grad(convOutLoss, weights)
    convOutBack = [conv_op_di, conv_op_dk]

    ival = numpy.random.rand(12, 3, 256, 256).astype(numpy.float64)
    wval = numpy.random.rand(12, 3, 3, 3).astype(numpy.float64)

    fopt = theano.function(inputs=[images, weights], outputs=convOutBack, mode=mode_with_mkl)
    new_out = fopt(ival, wval)

    convOut = conv2d(images, weights, input_shape=(12, 3, 256, 256), filter_shape=(12, 3, 3, 3), filter_flip=False)
    convOutLoss = T.mean(convOut)
    conv_op_di = T.grad(convOutLoss, images)
    conv_op_dk = T.grad(convOutLoss, weights)
    convOutBack = [conv_op_di, conv_op_dk]

    fori = theano.function(inputs=[images, weights], outputs=convOutBack, mode=mode_without_mkl)
    old_out = fori(ival, wval)

    assert len(fopt.maker.fgraph.toposort()) != len(fori.maker.fgraph.toposort())
    assert numpy.allclose(old_out[0], new_out[0])
    assert new_out[0].dtype == 'float64'
def test_conv_no_bias(self):
    images = T.dtensor4('inputs')
    weights = T.dtensor4('weights')

    images_internal = U2IConv(imshp=(12, 3, 256, 256), kshp=(12, 3, 3, 3))(images)
    convOut_internal = Conv2D(imshp=(12, 3, 256, 256), kshp=(12, 3, 3, 3), filter_flip=False)(images_internal, weights)
    convOut_user = I2U()(convOut_internal)

    ival = numpy.random.rand(12, 3, 256, 256).astype(numpy.float64)
    wval = numpy.random.rand(12, 3, 3, 3).astype(numpy.float64)

    fopt = theano.function(inputs=[images, weights], outputs=convOut_user, mode=mode_with_mkl)
    new_out = fopt(ival, wval)

    convOut = conv2d(images, weights, input_shape=(12, 3, 256, 256), filter_shape=(12, 3, 3, 3), filter_flip=False)
    fori = theano.function(inputs=[images, weights], outputs=convOut, mode=mode_without_mkl)
    old_out = fori(ival, wval)

    assert str(fopt.maker.fgraph.toposort()) != str(fori.maker.fgraph.toposort())
    assert numpy.allclose(old_out, new_out)
def test_bn_U2I(self):
    x = T.ftensor4('x')
    x_internal = U2IBatchNormalization(eps=1e-5)(x)
    x_out = I2U()(x_internal)
    fopt = theano.function([x], x_out, mode=with_mkl)
    ival = numpy.random.rand(64, 5, 128, 128).astype(numpy.float32)
    # The U2I -> I2U round trip must be the identity on values
    assert numpy.allclose(fopt(ival), ival)
def test_conv_U2I(self):
    images = T.dtensor4('inputs')
    a_internal = U2IConv(imshp=(12, 3, 256, 256), kshp=(12, 3, 3, 3))(images)
    out = I2U()(a_internal)
    fopt = theano.function([images], out, mode=mode_with_mkl)
    ival = numpy.random.rand(12, 3, 256, 256).astype(numpy.float64)
    assert numpy.allclose(fopt(ival), ival)
def test_conv_with_bias(self):
    images = T.dtensor4('input_conv')
    weights = T.dtensor4('weights')
    bias = T.dvector('bias')

    ishape = [(8, 3, 256, 256), (16, 3, 256, 256),
              (32, 3, 256, 256), (64, 3, 256, 256)]
    wshape = [(8, 3, 3, 3), (16, 3, 3, 3),
              (32, 3, 3, 3), (64, 3, 3, 3)]

    for i, ish in enumerate(ishape):
        wsh = wshape[i]
        images_internal = U2IConv(imshp=ish, kshp=wsh)(images)
        convOut = Conv2D(imshp=ish, kshp=wsh, filter_flip=False)(images_internal, weights, bias)
        convOut_user = I2U()(convOut)
        convOutLoss = T.mean(convOut_user)
        conv_op_di = theano.grad(convOutLoss, images)
        conv_op_dk = theano.grad(convOutLoss, weights)
        conv_op_db = theano.grad(convOutLoss, bias)
        convOutBack = [conv_op_di, conv_op_dk, conv_op_db]

        ival = numpy.random.rand(*ish).astype(numpy.float64)
        wval = numpy.random.rand(*wsh).astype(numpy.float64)
        bval = numpy.random.rand(wsh[0]).astype(numpy.float64) - \
            numpy.random.rand(wsh[0]).astype(numpy.float64)

        fopt = theano.function(inputs=[images, weights, bias], outputs=convOutBack, mode=mode_with_mkl)
        new_out = fopt(ival, wval, bval)

        convOut = conv2d(images, weights, input_shape=ish, filter_shape=wsh, filter_flip=False)
        convOutLoss = T.mean(convOut + bias.dimshuffle('x', 0, 'x', 'x'))
        conv_op_di = theano.grad(convOutLoss, images)
        conv_op_dk = theano.grad(convOutLoss, weights)
        conv_op_db = theano.grad(convOutLoss, bias)
        convOutBack = [conv_op_di, conv_op_dk, conv_op_db]

        fori = theano.function(inputs=[images, weights, bias], outputs=convOutBack, mode=mode_without_mkl)
        old_out = fori(ival, wval, bval)

        assert len(fopt.maker.fgraph.toposort()) != len(fori.maker.fgraph.toposort())
        assert numpy.allclose(old_out[0], new_out[0])
        # assert numpy.allclose(old_out[1], new_out[1])
        assert numpy.allclose(old_out[2], new_out[2])
        assert new_out[0].dtype == 'float64'
        assert new_out[2].dtype == 'float64'
def mkl_concatenate_func(*inputs):
    # _, axis, tensors = inputs
    axis = inputs[1]
    tensors = inputs[2:]
    tensors_internal = [U2IConcatenate()(x) for x in tensors]
    new_inputs = [axis] + tensors_internal
    out = Concatenate()(*new_inputs)
    output = I2U()(out)
    return output
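
# Usage sketch (assumption, mirroring the pool helper above): concatenates
# two 4D tensors along axis 1, the only axis MKL Concatenate supports (see
# local_concatenate_mkl below); None fills the unused `self` slot.
def _mkl_concatenate_func_usage_sketch():
    a = T.dtensor4('a')
    b = T.dtensor4('b')
    out = mkl_concatenate_func(None, 1, a, b)
    f = theano.function([a, b], out, mode=mode_with_mkl)
    aval = numpy.random.rand(2, 3, 4, 4).astype(numpy.float64)
    bval = numpy.random.rand(2, 5, 4, 4).astype(numpy.float64)
    return f(aval, bval)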
def mkl_relu_func(*inputs):
    if len(inputs) == 2:
        # self, image
        _, x = inputs
        x_internal = U2IRelu()(x)
        reluOut = Relu()(x_internal)
        output = I2U()(reluOut)
    elif len(inputs) == 3:
        # self, image, slope
        _, x, slope = inputs
        x_internal = U2IRelu(slope=slope)(x)
        reluOut = Relu(slope=slope)(x_internal)
        output = I2U()(reluOut)
    else:
        raise ValueError('incorrect inputs list, should be 2 ~ 3 parameters!')

    return output
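
# Usage sketch (assumption): the 3-argument form builds a leaky ReLU with the
# given negative-side slope; the plain 2-argument form would be
# mkl_relu_func(None, x).
def _mkl_relu_func_usage_sketch():
    x = T.dtensor4('x')
    out = mkl_relu_func(None, x, 0.1)
    f = theano.function([x], out, mode=mode_with_mkl)
    xval = numpy.random.rand(2, 3, 4, 4).astype(numpy.float64) - 0.5
    return f(xval)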
def test_lrn_float64(self):
    old_floatX = theano.config.floatX
    theano.config.floatX = 'float64'
    x = tensor.dtensor4('x')
    x_internal = U2ILRN()(x)
    z_internal = mkl_lrn.LRN()(x_internal)
    z = I2U()(z_internal)
    f = theano.function([x], z, mode=mode_with_mkl)
    imval = numpy.random.rand(4, 2, 4, 4).astype(theano.config.floatX)
    f(imval)
    assert f(imval).dtype == 'float64'
    theano.config.floatX = old_floatX
def test_bn_value(self):
    X = T.ftensor4('x')
    Scale = T.vector('scale')
    Shift = T.vector('shift')
    x_internal = U2IBatchNormalization(eps=1e-5)(X)
    z_bn = mkl_bn.BatchNormalization(eps=1e-5, bias=1, term=1)(x_internal, Scale, Shift)
    z_out = I2U()(z_bn)
    z_sum = T.sum(z_out)
    z_grad = T.grad(z_sum, [X])
    fgrad = theano.function([X, Scale, Shift], z_grad, mode=with_mkl)
    ival = numpy.random.rand(64, 5, 128, 128).astype(numpy.float32)
    sval = numpy.random.rand(5).astype(numpy.float32)
    tval = numpy.random.rand(5).astype(numpy.float32)
    fgrad(ival, sval, tval)
def test_lrn_grad_float32(self):
    old_floatX = theano.config.floatX
    theano.config.floatX = 'float32'
    x = tensor.ftensor4('x')
    x_internal = U2ILRN()(x)
    z_internal = mkl_lrn.LRN()(x_internal)
    z = I2U()(z_internal)
    z_sum = tensor.sum(z)
    g = tensor.grad(z_sum, [x])
    f = theano.function([x], g, mode=mode_with_mkl)
    imval = numpy.random.rand(4, 2, 4, 4).astype(theano.config.floatX)
    f(imval)
    assert f(imval)[0].dtype == 'float32'
    theano.config.floatX = old_floatX
def local_pool_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, pool.Pool):
        return

    if node.inputs[0].type.ndim != 4:
        return

    mkl_ver = theano.sandbox.mkl.mkl_version()

    mkl_pool_modes = ['min', 'max', 'average_exc_pad']
    mkl_ignore_border = [False]
    if isinstance(mkl_ver, integer_types) and (mkl_ver >= 20170206):
        mkl_pool_modes.append('average_inc_pad')
        mkl_ignore_border.append(True)

    if node.op.mode not in mkl_pool_modes:
        return

    if node.op.ignore_border not in mkl_ignore_border:
        return

    x, ws, stride, pad = node.inputs

    if stride is None:
        stride = ws

    try:
        x_internal = U2IPool(ignore_border=node.op.ignore_border,
                             mode=node.op.mode)(x, ws, stride, pad)
        poolOut = mkl_pool.Pool(ignore_border=node.op.ignore_border,
                                mode=node.op.mode)(x_internal, ws, stride, pad)
        z_user = I2U()(poolOut)
        rval = z_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
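
# Note (assumption about the surrounding module, not shown in this excerpt):
# local optimizers such as local_pool_mkl are typically wrapped with Theano's
# @local_optimizer decorator and registered on the MKL optimizer database so
# they fire when mode_with_mkl is selected, e.g.:
#
#     @local_optimizer([pool.Pool])
#     def local_pool_mkl(node):
#         ...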
def local_ConvGroup_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, mkl_conv.AbstractConvGroup):
        return

    # image
    if node.inputs[0].type.ndim != 4:
        return

    # weight
    if node.inputs[1].type.ndim not in [4, 5]:
        return

    try:
        assert len(node.inputs) in [2, 3]
        if len(node.inputs) == 2:
            image, weight = node.inputs
            bias = None
        else:
            image, weight, bias = node.inputs

        image_internal = U2IConv(imshp=node.op.imshp,
                                 kshp=node.op.kshp,
                                 subsample=node.op.subsample,
                                 border_mode=node.op.border_mode,
                                 filter_dilation=node.op.filter_dilation)(image)
        conv_out = mkl_conv.Conv2D(imshp=node.op.imshp,
                                   kshp=node.op.kshp,
                                   subsample=node.op.subsample,
                                   border_mode=node.op.border_mode,
                                   filter_flip=node.op.filter_flip,
                                   filter_dilation=node.op.filter_dilation)(image_internal, weight, bias)
        conv_out = I2U()(conv_out)
        rval = conv_out
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_Conv2D_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, AbstractConv2d):
        return

    if node.op.filter_dilation != (1, 1):
        return

    if node.inputs[1].type.ndim != 4 and node.inputs[1].type.ndim != 5:
        return

    if None in node.op.kshp:
        return

    if None in node.op.imshp:
        return

    try:
        image, weight = node.inputs
        image_internal = U2IConv(imshp=node.op.imshp,
                                 kshp=node.op.kshp,
                                 subsample=node.op.subsample,
                                 border_mode=node.op.border_mode,
                                 filter_dilation=node.op.filter_dilation)(image)
        convOut = mkl_conv.Conv2D(imshp=node.op.imshp,
                                  kshp=node.op.kshp,
                                  border_mode=node.op.border_mode,
                                  subsample=node.op.subsample,
                                  filter_flip=node.op.filter_flip,
                                  filter_dilation=node.op.filter_dilation)(image_internal, weight)
        z_user = I2U()(convOut)
        rval = z_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def test_mkl_lrn_value(self):
    shape = [(2, 15, 3, 4), (256, 256, 27, 27)]  # NCHW
    n = 5
    k = 2
    alpha = 0.0001
    beta = 0.75
    x = tensor.dtensor4('x')
    x_internal = U2ILRN()(x)
    z_internal = mkl_lrn.LRN(alpha, beta, k, n)(x_internal)
    z = I2U()(z_internal)
    fz = theano.function([x], z, mode=mode_with_mkl)

    # for shape[0]
    input_data = numpy.random.rand(*shape[0]).astype(theano.config.floatX)
    t = self.ground_truth_normalizer(input_data, k=k, n=n, alpha=alpha, beta=beta)
    assert (fz(input_data)).shape == t.shape
    assert numpy.allclose(fz(input_data), t)
def local_concatenate_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, Join):
        return

    if node.inputs[1].type.ndim != 4:
        return

    try:
        axis, tensors = node.inputs[0], node.inputs[1:]
        tensors_internal = [U2IConcatenate()(x) for x in tensors]
        new_inputs = [axis] + tensors_internal
        concatenateOut = mkl_concatenate.Concatenate()(*new_inputs)
        z_user = I2U()(concatenateOut)
        rval = z_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_concatenate_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, Join):
        return

    if node.inputs[1].type.ndim != 4:
        return

    try:
        axis, tensors = node.inputs[0], node.inputs[1:]
        if not isinstance(axis, integer_types):
            try:
                axis = int(get_scalar_constant_value(axis))
            except NotScalarConstantError:
                return

        if isinstance(axis, integer_types):
            # MKL Concatenate only supports axis=1
            if axis != 1:
                return

        tensors_internal = [U2IConcatenate()(x) for x in tensors]
        new_inputs = [axis] + tensors_internal
        concatenateOut = mkl_concatenate.Concatenate()(*new_inputs)
        z_user = I2U()(concatenateOut)
        rval = z_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_relu_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, mkl_relu.AbstractRelu):
        return

    if node.inputs[0].type.ndim != 4:
        return

    x, = node.inputs

    try:
        x_internal = U2IRelu(slope=node.op.slope)(x)
        reluOut = mkl_relu.Relu(slope=node.op.slope)(x_internal)
        z_user = I2U()(reluOut)
        rval = z_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_bn_mkl(node):
    if not mkl_available():
        return

    if not isinstance(node.op, mkl_bn.AbstractBatchNormalization):
        return

    if node.inputs[0].type.ndim != 4:
        return

    try:
        x, scale, shift = node.inputs[0:3]
        x_u2i = U2IBatchNormalization(eps=node.op.eps)(x)
        bn_out = mkl_bn.BatchNormalization(eps=node.op.eps,
                                           bias=node.op.bias,
                                           term=node.op.term)(x_u2i, scale, shift)
        z_i2u = I2U()(bn_out)
        rval = z_i2u
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def apply(self, fgraph):
    if not mkl_available():
        return

    did_something = True
    while did_something:
        did_something = False
        topo = fgraph.toposort()
        for node in topo:
            if (node in fgraph.apply_nodes) and isinstance(node.op, AbstractConv2d):
                inp = node.inputs
                out = node.outputs
                imshp = getattr(node.op, 'imshp', None)
                kshp = getattr(node.op, 'kshp', None)
                border_mode = getattr(node.op, 'border_mode', 'valid')
                subsample = getattr(node.op, 'subsample', (1, 1))
                filter_flip = getattr(node.op, 'filter_flip', False)
                filter_dilation = getattr(node.op, 'filter_dilation', (1, 1))

                # Get Elemwise node
                if (len(out) == 1 and (not out[0] in fgraph.outputs) and
                        isinstance(out[0].clients[0][0].op, tensor.Elemwise) and
                        isinstance(out[0].clients[0][0].op.scalar_op, scalar.Add)):
                    if len(out[0].clients[0][0].inputs) == 2:
                        if out[0].clients[0][0].inputs[0] is out[0]:
                            bias = out[0].clients[0][0].inputs[1]
                        else:
                            bias = out[0].clients[0][0].inputs[0]

                        # Get DimShuffle node
                        bias_owner = bias.owner
                        if bias_owner is None:
                            try:
                                inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                subsample=subsample,
                                                filter_dilation=filter_dilation)(inp[0])
                                out_0 = mkl_conv.Conv2D(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                        subsample=subsample, filter_flip=filter_flip,
                                                        filter_dilation=filter_dilation)(image=inp_0,
                                                                                         weight=inp[1],
                                                                                         bias=bias)
                                # Convert back to user layout before replacing
                                # (mirrors the DimShuffle branch below)
                                out_1 = I2U()(out_0)
                                fgraph.replace_validate(out[0].clients[0][0].outputs[0],
                                                        out_1, 'ReplaceConvBias')
                                did_something = True
                            except Exception as e:
                                raise
                        elif isinstance(bias_owner.op, tensor.DimShuffle) and (bias_owner.inputs[0].owner is None):
                            try:
                                inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                subsample=subsample,
                                                filter_dilation=filter_dilation)(inp[0])
                                out_0 = mkl_conv.Conv2D(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                        subsample=subsample, filter_flip=filter_flip,
                                                        filter_dilation=filter_dilation)(image=inp_0,
                                                                                         weight=inp[1],
                                                                                         bias=bias_owner.inputs[0])
                                out_1 = I2U()(out_0)
                                fgraph.replace_validate(out[0].clients[0][0].outputs[0],
                                                        out_1, 'ReplaceConvBias')
                                did_something = True
                            except Exception as e:
                                raise
                        else:
                            pass
            elif (node in fgraph.apply_nodes) and isinstance(node.op, AbstractConv2d_gradWeights):
                inp = node.inputs  # 0-image, 1-gz, 2-shape
                out = node.outputs
                imshp = getattr(node.op, 'imshp', None)
                kshp = getattr(node.op, 'kshp', None)
                border_mode = getattr(node.op, 'border_mode', 'valid')
                subsample = getattr(node.op, 'subsample', (1, 1))
                filter_flip = getattr(node.op, 'filter_flip', False)
                filter_dilation = getattr(node.op, 'filter_dilation', (1, 1))

                assert len(inp) == 3 and len(out) == 1
                for i, c in enumerate(inp[0].clients):
                    if hasattr(c[0], 'op') and isinstance(c[0].op, U2IConv) and self._check_attributes_(c[0], node):
                        for cc in c[0].outputs[0].clients:
                            if isinstance(cc[0].op, mkl_conv.Conv2D) and len(cc[0].inputs) == 3:
                                weight, bias = cc[0].inputs[1:3]
                                try:
                                    inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                    subsample=subsample,
                                                    filter_dilation=filter_dilation)(inp[0])
                                    conv_fw = mkl_conv.Conv2D(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                              subsample=subsample, filter_flip=filter_flip,
                                                              filter_dilation=filter_dilation)(inp_0, weight, bias)
                                    gz = I2UGrad()(conv_fw, inp[1])
                                    out_0, out_1 = mkl_conv.ConvGradWeights(imshp=imshp, kshp=kshp,
                                                                            border_mode=border_mode,
                                                                            subsample=subsample,
                                                                            filter_flip=filter_flip,
                                                                            filter_dilation=filter_dilation)(image=inp_0,
                                                                                                             weight=weight,
                                                                                                             gradz=gz,
                                                                                                             bias=bias)
                                    # Get BiasGrad
                                    oriBiasGrad = None  # BiasGrad in original function graph
                                    gz_node = inp[1].owner
                                    for i, o in enumerate(gz_node.outputs):
                                        if inp[1] is o and len(o.clients) >= 2:
                                            oriBiasGrad = self._check_grad_bias_(gz_node, i)

                                    fgraph.replace_validate(out[0], out_0, 'ReplaceConvBias')
                                    if oriBiasGrad:
                                        fgraph.replace_validate(oriBiasGrad, out_1, 'ReplaceConvBias')
                                    did_something = True
                                except Exception as e:
                                    raise
            elif (node in fgraph.apply_nodes) and isinstance(node.op, AbstractConv2d_gradInputs):
                inp = node.inputs  # 0-weight, 1-gz, 2-shape
                out = node.outputs
                imshp = getattr(node.op, 'imshp', None)
                kshp = getattr(node.op, 'kshp', None)
                border_mode = getattr(node.op, 'border_mode', 'valid')
                subsample = getattr(node.op, 'subsample', (1, 1))
                filter_flip = getattr(node.op, 'filter_flip', False)
                filter_dilation = getattr(node.op, 'filter_dilation', (1, 1))

                assert len(inp) == 3 and len(out) == 1
                list_Conv2D = [c[0] for c in inp[0].clients
                               if (hasattr(c[0], 'op') and
                                   isinstance(c[0].op, mkl_conv.Conv2D) and
                                   len(c[0].inputs) == 3 and
                                   self._check_attributes_(c[0], node))]
                if 3 > len(list_Conv2D) > 0:
                    x = list_Conv2D[0].inputs[0].owner.inputs[0]
                    bias = list_Conv2D[0].inputs[2]
                    inp_0 = list_Conv2D[0].inputs[0]
                    try:
                        inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                        subsample=subsample,
                                        filter_dilation=filter_dilation)(x)
                        conv_fw = mkl_conv.Conv2D(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                  subsample=subsample, filter_flip=filter_flip,
                                                  filter_dilation=filter_dilation)(inp_0, inp[0], bias)
                        gz = I2UGrad()(conv_fw, inp[1])
                        out_0 = mkl_conv.ConvGradInputs(imshp=imshp, kshp=kshp, border_mode=border_mode,
                                                        subsample=subsample, filter_flip=filter_flip,
                                                        filter_dilation=filter_dilation)(inp_0, inp[0], gz)
                        inp_grad = U2IGrad()(x, out_0)
                        fgraph.replace_validate(out[0], inp_grad, 'ReplaceConvBias')
                        did_something = True
                    except Exception as e:
                        raise e
                else:
                    pass
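
# Illustration (assumption, not original code): the AbstractConv2d branch of
# the pass above matches the user-level pattern
#
#     conv2d(x, w) + b.dimshuffle('x', 0, 'x', 'x')
#
# and folds the broadcasted bias add into a single MKL
# Conv2D(image, weight, bias), converting the result back to user layout
# with I2U(). The two gradient branches rewrite the corresponding
# AbstractConv2d_gradWeights / AbstractConv2d_gradInputs nodes against the
# same fused forward op.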
def apply(self, fgraph):
    if not mkl_available():
        return

    def getElemwiseInput(node, inputs, coeffs, co):
        inp = inputs
        coe = coeffs

        # Elemwise_add
        if (isinstance(node.op, tensor.Elemwise) and
                isinstance(node.op.scalar_op, scalar.Add)):
            for i in node.inputs:
                n = i.owner
                if (n is not None and
                        isinstance(n.op, tensor.Elemwise) and
                        isinstance(n.op.scalar_op, scalar.Add)):
                    getElemwiseInput(n, inp, coe, co)
                else:
                    inp.append(i)
                    coe.append(co)
        # Elemwise_mul: this case has been removed. We only process
        # Elemwise{Add} here to avoid disturbing the Elemwise{Composite}
        # fusion.
        else:
            raise TypeError('The op of the input node should be an instance of Elemwise{Add}')

    did_something = True
    while did_something:
        did_something = False
        topo = fgraph.toposort()
        topo.reverse()
        for node in topo:
            if node in fgraph.apply_nodes:
                if (isinstance(node.op, tensor.Elemwise) and
                        isinstance(node.op.scalar_op, scalar.Add)):
                    out = node.outputs
                    inputs = []
                    coeffs = []
                    co = 1.0  # For now, all the coeffs are 1.0 since Elemwise{Mul} is not processed
                    getElemwiseInput(node, inputs, coeffs, co)
                    inp_len = len(inputs)
                    assert len(inputs) == len(coeffs)
                    if inp_len >= 2:
                        # Check that all inputs come from I2U or U2IGrad
                        if all([(i.owner and isinstance(i.owner.op, (I2U, U2IGrad))) for i in inputs]):
                            try:
                                inputs_t = []
                                for i in inputs:
                                    inputs_t.append(U2IElemwiseSum(inp_num=inp_len, coeff=coeffs)(i))
                                out_t = mkl_elemwise.ElemwiseSum(inp_num=inp_len, coeff=coeffs)(*inputs_t)
                                new_out = I2U()(out_t)
                                fgraph.replace_validate(out[0], new_out, 'ReplaceElemwise')
                                did_something = True
                            except Exception as e:
                                raise e
                        else:
                            pass
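
# Illustration (assumption, not original code): the pass above flattens a
# chain of Elemwise{Add} nodes whose operands all come from I2U / U2IGrad,
#
#     z = I2U()(a_int) + I2U()(b_int) + I2U()(c_int)
#
# into a single MKL ElemwiseSum over internal-layout operands:
#
#     z = I2U()(ElemwiseSum(inp_num=3, coeff=[1.0, 1.0, 1.0])(a_t, b_t, c_t))
#
# where a_t/b_t/c_t are the U2IElemwiseSum conversions of the original
# operands, so the sum runs entirely in MKL layout with one conversion back.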