def local_lrnGrad_mkl(node):
    """Rewrite an AbstractLRNGrad node as the MKL LRN gradient chain.

    Applies only when MKL is available and the input is 4D. The forward
    MKL LRN is rebuilt in internal layout so the gradient op can reuse it;
    the result is converted back to user layout with U2IGrad. Returns the
    one-element replacement list, or None when the rewrite does not apply
    or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, mkl_lrn.AbstractLRNGrad):
        return
    if node.inputs[0].type.ndim != 4:
        return
    try:
        x, gz = node.inputs
        # All four LRN hyper-parameters are carried over unchanged.
        params = dict(alpha=node.op.alpha, beta=node.op.beta,
                      k=node.op.k, n=node.op.n)
        x_internal = U2ILRN(**params)(x)
        lrn_out = mkl_lrn.LRN(**params)(x_internal)
        gz_internal = I2UGrad()(lrn_out, gz)
        grad_internal = mkl_lrn.LRNGrad(**params)(x_internal, gz_internal)
        return [U2IGrad()(x, grad_internal)]
    except Exception as e:
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(e)))
        return
def local_bnGrad_mkl(node):
    """Rewrite an AbstractBatchNormalizationGrad node with MKL batch-norm ops.

    Applies only when MKL is available and the input is 4D. Rebuilds the
    forward MKL BatchNormalization in internal layout, feeds the converted
    gradient through BatchNormalizationGrad, and returns
    [grad_x_user, grad_scale, grad_shift]. Returns None when the rewrite
    does not apply or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, mkl_bn.AbstractBatchNormalizationGrad):
        return
    if node.inputs[0].type.ndim != 4:
        return
    try:
        x, gz, scale, shift = node.inputs
        x_internal = U2IBatchNormalization(eps=node.op.eps)(x)
        bn_out = mkl_bn.BatchNormalization(
            eps=node.op.eps, bias=node.op.bias,
            term=node.op.term)(x_internal, scale, shift)
        gz_internal = I2UGrad()(bn_out, gz)
        grads = mkl_bn.BatchNormalizationGrad(
            eps=node.op.eps, bias=node.op.bias,
            term=node.op.term)(x_internal, gz_internal, scale, shift)
        # Only the x-gradient needs conversion back to user layout.
        return [U2IGrad()(x, grads[0]), grads[1], grads[2]]
    except Exception as e:
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(e)))
        return
def local_reluGrad_mkl(node):
    """Rewrite an AbstractReluGrad node as the MKL ReLU gradient chain.

    Applies only when MKL is available and the input is 4D. The forward
    MKL Relu is rebuilt in internal layout, the incoming gradient is
    converted with I2UGrad, and the final gradient is converted back to
    user layout. Returns the one-element replacement list, or None when
    the rewrite does not apply or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, mkl_relu.AbstractReluGrad):
        return
    if node.inputs[0].type.ndim != 4:
        return
    # Unpacking outside the try mirrors the validation-style guards above.
    x, gz = node.inputs
    try:
        slope = node.op.slope
        x_internal = U2IRelu(slope=slope)(x)
        relu_out = mkl_relu.Relu(slope=slope)(x_internal)
        gz_internal = I2UGrad()(relu_out, gz)
        grad_internal = mkl_relu.ReluGrad(slope=slope)(x_internal, gz_internal)
        return [U2IGrad()(x, grad_internal)]
    except Exception as e:
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(e)))
        return
def local_concatenateGrad_mkl(node):
    """Rewrite a Split node (the gradient of Join) into MKL ConcatenateGrad.

    Recovers the original Join inputs by pattern-matching the graph that
    produced the split sizes, rebuilds the forward MKL Concatenate, and
    returns one user-layout gradient per original input tensor. Returns
    None when the pattern does not match or the rewrite fails (failure is
    logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, Split):
        return
    if node.inputs[0].type.ndim != 4:
        return
    try:
        gz, axis, splits, = node.inputs
        # Normalize a symbolic axis to a Python int when it is constant.
        if not isinstance(axis, integer_types):
            try:
                axis = int(get_scalar_constant_value(axis))
            except NotScalarConstantError:
                return
        if isinstance(axis, integer_types):
            # MKL Concatenate only supports axis=1
            if axis != 1:
                return
        # Retrieve the inputs to Join op
        # inp_0 inp_1 inp
        # | | |
        # Splits <- MakeVector <- [Subtensor...] <- Shape <- inputs
        if not isinstance(splits.owner.op, theano.tensor.opt.MakeVector):
            return
        tensors = []
        for inp_0 in splits.owner.inputs:
            # Each split size must come from Subtensor(Shape(tensor));
            # otherwise we cannot identify the original Join inputs.
            if not isinstance(inp_0.owner.op, theano.tensor.subtensor.Subtensor):
                return
            inp_1 = inp_0.owner.inputs[0]
            if not isinstance(inp_1.owner.op, theano.compile.ops.Shape):
                return
            inp = inp_1.owner.inputs[0]
            tensors.append(inp)
        # Convert every recovered tensor to MKL internal layout and rebuild
        # the forward concatenate so the gradient op can use it.
        tensors_internal = [U2IConcatenate()(x) for x in tensors]
        new_inputs = [axis] + tensors_internal
        z_internal = mkl_concatenate.Concatenate()(*new_inputs)
        gz_internal = I2UGrad()(z_internal, gz)
        concatenateGradOut = mkl_concatenate.ConcatenateGrad()(gz_internal, axis, *tensors_internal)
        # One user-layout gradient per original input, in order.
        gx_user = [U2IGrad()(_x, _gz) for _x, _gz in zip(tensors, concatenateGradOut)]
        rval = gx_user
        return rval
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_poolGrad_mkl(node):
    """Rewrite MaxPoolGrad / AveragePoolGrad as the MKL pooling gradient chain.

    Applies only when MKL is available, the input is 4D, and the pooling
    mode / ignore_border combination is supported by the installed MKL
    version ('average_inc_pad' and ignore_border=True need MKL >= 20170206).
    Returns the one-element replacement list, or None when the rewrite does
    not apply or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if node.inputs[0].type.ndim != 4:
        return
    mkl_ver = theano.sandbox.mkl.mkl_version()
    supported_modes = ['min', 'max', 'average_exc_pad']
    supported_ignore_border = [False]
    # Newer MKL adds inclusive-pad averaging and ignore_border support.
    if isinstance(mkl_ver, integer_types) and (mkl_ver >= 20170206):
        supported_modes.append('average_inc_pad')
        supported_ignore_border.append(True)
    if node.op.mode not in supported_modes:
        return
    if node.op.ignore_border not in supported_ignore_border:
        return
    if isinstance(node.op, pool.MaxPoolGrad):
        x, maxout, gz, ws, stride, pad = node.inputs
    elif isinstance(node.op, pool.AveragePoolGrad):
        x, gz, ws, stride, pad = node.inputs
    else:
        # Other pool mode is not supported
        return
    if stride is None:
        stride = ws
    try:
        kwargs = dict(ignore_border=node.op.ignore_border, mode=node.op.mode)
        x_internal = U2IPool(**kwargs)(x, ws, stride, pad)
        pool_out = mkl_pool.Pool(**kwargs)(x_internal, ws, stride, pad)
        gz_internal = I2UGrad()(pool_out, gz)
        grad_internal = mkl_pool.PoolGrad(**kwargs)(x_internal, gz_internal,
                                                    ws, stride, pad)
        return [U2IGrad()(x, grad_internal)]
    except Exception as e:
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(e)))
        return
def local_ConvGradInputs_mkl(node):
    """Rewrite AbstractConv2d_gradInputs into the MKL ConvGradInputs chain.

    Requires MKL, a 4D or 5D weight tensor, no filter dilation, and fully
    static image/kernel shapes. Returns the one-element replacement list
    (the user-layout image gradient), or None when the rewrite does not
    apply or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, AbstractConv2d_gradInputs):
        return
    if node.inputs[1].type.ndim != 4 and node.inputs[1].type.ndim != 5:
        return
    # MKL path here does not handle dilated filters.
    if node.op.filter_dilation != (1, 1):
        return
    # Both shapes must be fully known at graph-construction time.
    if None in node.op.kshp:
        return
    if None in node.op.imshp:
        return
    try:
        weight, gz, zshp = node.inputs
        # Recover the forward-pass image variable by walking back through the
        # ops that produced the shape input (zshp).
        # NOTE(review): assumes zshp's graph is exactly two owner hops from
        # the image (e.g. Subtensor(Shape(image))) — confirm against the
        # gradient graphs this optimizer is registered for.
        image = node.inputs[2].owner.inputs[0].owner.inputs[0]
        image_internal = U2IConv(imshp=node.op.imshp, kshp=node.op.kshp,
                                 subsample=node.op.subsample,
                                 border_mode=node.op.border_mode,
                                 filter_dilation=node.op.filter_dilation)(image)
        # Rebuild the forward conv in MKL internal layout so the gradient op
        # can consume its internal representation.
        convOut = mkl_conv.Conv2D(imshp=node.op.imshp, kshp=node.op.kshp,
                                  border_mode=node.op.border_mode,
                                  subsample=node.op.subsample,
                                  filter_flip=node.op.filter_flip,
                                  filter_dilation=node.op.filter_dilation)(image_internal, weight)
        gz_internal = I2UGrad()(convOut, gz)
        gradImage = mkl_conv.ConvGradInputs(border_mode=node.op.border_mode,
                                            subsample=node.op.subsample,
                                            imshp=node.op.imshp,
                                            kshp=node.op.kshp)(image_internal, weight, gz_internal)
        # Convert the image gradient back to user layout.
        gradImage_user = U2IGrad()(image, gradImage)
        rval = gradImage_user
        return [rval]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def local_ConvGroupGrad_mkl(node):
    """Rewrite AbstractConvGroupGrad into MKL grouped-convolution gradients.

    Rebuilds the forward MKL Conv2D in internal layout, then emits
    ConvGradInputs and ConvGradWeights. Returns
    [grad_image, grad_weight] or [grad_image, grad_weight, grad_bias]
    to match the original node's output arity, or None when the rewrite
    does not apply or fails (failure is logged, not raised).
    """
    if not mkl_available():
        return
    if not isinstance(node.op, mkl_conv.AbstractConvGroupGrad):
        return
    # image
    if node.inputs[0].type.ndim != 4:
        return
    # weight
    if node.inputs[2].type.ndim not in [4, 5]:
        return
    try:
        # 3 inputs -> no bias; 4 inputs -> bias present.
        assert len(node.inputs) in [3, 4]
        if len(node.inputs) == 3:
            image, gz, weight = node.inputs
            bias = None
        else:
            image, gz, weight, bias = node.inputs
        image_internal = U2IConv(imshp=node.op.imshp, kshp=node.op.kshp,
                                 subsample=node.op.subsample,
                                 border_mode=node.op.border_mode,
                                 filter_dilation=node.op.filter_dilation)(image)
        # Forward conv rebuilt so both gradient ops can reuse its internal
        # layout/buffers.
        conv_out = mkl_conv.Conv2D(imshp=node.op.imshp, kshp=node.op.kshp,
                                   subsample=node.op.subsample,
                                   border_mode=node.op.border_mode,
                                   filter_flip=node.op.filter_flip,
                                   filter_dilation=node.op.filter_dilation)(image_internal, weight, bias)
        gz_internal = I2UGrad()(conv_out, gz)
        grad_image = mkl_conv.ConvGradInputs(imshp=node.op.imshp,
                                             kshp=node.op.kshp,
                                             subsample=node.op.subsample,
                                             border_mode=node.op.border_mode,
                                             filter_flip=node.op.filter_flip,
                                             filter_dilation=node.op.filter_dilation)(image_internal, weight, gz_internal)
        grad_image = U2IGrad()(image, grad_image)
        grad_out = mkl_conv.ConvGradWeights(imshp=node.op.imshp,
                                            kshp=node.op.kshp,
                                            subsample=node.op.subsample,
                                            border_mode=node.op.border_mode,
                                            filter_flip=node.op.filter_flip,
                                            filter_dilation=node.op.filter_dilation)(image_internal, weight, gz_internal, bias)
        # ConvGradWeights yields (grad_weight, grad_bias) when a bias is
        # involved, otherwise a single grad_weight variable.
        if isinstance(grad_out, (list, tuple)):
            grad_weight, grad_bias, = grad_out
        else:
            grad_weight = grad_out
        if len(node.outputs) == 3:
            # grad_bias must exist here; the assert guards the pairing.
            assert len(grad_out) == 2
            rval = [grad_image, grad_weight, grad_bias]
        else:
            rval = [grad_image, grad_weight]
        return rval
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
def apply(self, fgraph):
    """Global pass: fuse convolution + bias patterns into MKL conv ops.

    Sweeps ``fgraph`` repeatedly until no rewrite fires:

    * ``AbstractConv2d`` whose single output feeds an elementwise Add is
      replaced by ``U2IConv -> mkl_conv.Conv2D(bias=...) -> I2U``, folding
      the added term in as the conv bias. The bias operand must be either
      an owner-less variable or a DimShuffle of one.
    * ``AbstractConv2d_gradWeights`` whose image already feeds a matching
      fused conv is replaced by ``mkl_conv.ConvGradWeights``; when the
      original graph also computes a bias gradient, it is replaced by the
      MKL bias gradient.
    * ``AbstractConv2d_gradInputs`` whose weight already feeds a matching
      fused conv is replaced by ``mkl_conv.ConvGradInputs``.

    Bug fixes vs. the previous version:
    * ``fgraph.repalce_validate`` typo corrected to ``replace_validate``
      (the misspelled attribute raised AttributeError at runtime).
    * The owner-less-bias branch now converts the conv output back to user
      layout with ``I2U()`` before replacement, matching the DimShuffle
      branch.
    * Removed a dead assignment to ``inp_0`` that was immediately
      overwritten, and a loop-variable shadowing of ``i``.
    """
    if not mkl_available():
        return
    did_something = True
    while did_something:
        did_something = False
        topo = fgraph.toposort()
        for node in topo:
            if node not in fgraph.apply_nodes:
                # Node was detached by an earlier replacement in this sweep.
                continue
            # Common conv attributes (defaults mirror the original code).
            imshp = getattr(node.op, 'imshp', None)
            kshp = getattr(node.op, 'kshp', None)
            border_mode = getattr(node.op, 'border_mode', 'valid')
            subsample = getattr(node.op, 'subsample', (1, 1))
            filter_flip = getattr(node.op, 'filter_flip', False)
            filter_dilation = getattr(node.op, 'filter_dilation', (1, 1))
            if isinstance(node.op, AbstractConv2d):
                inp = node.inputs
                out = node.outputs
                # Look for an elementwise Add consuming the conv output:
                # that Add is the bias addition we fold into the MKL conv.
                if (len(out) == 1 and (out[0] not in fgraph.outputs) and
                        isinstance(out[0].clients[0][0].op, tensor.Elemwise) and
                        isinstance(out[0].clients[0][0].op.scalar_op, scalar.Add)):
                    add_node = out[0].clients[0][0]
                    if len(add_node.inputs) == 2:
                        # The Add operand that is not the conv output is the bias.
                        if add_node.inputs[0] is out[0]:
                            bias = add_node.inputs[1]
                        else:
                            bias = add_node.inputs[0]
                        bias_owner = bias.owner
                        if bias_owner is None:
                            bias_var = bias
                        elif (isinstance(bias_owner.op, tensor.DimShuffle) and
                              bias_owner.inputs[0].owner is None):
                            # Strip the DimShuffle; MKL conv takes the raw bias.
                            bias_var = bias_owner.inputs[0]
                        else:
                            bias_var = None
                        if bias_var is not None:
                            try:
                                inp_0 = U2IConv(
                                    imshp=imshp, kshp=kshp,
                                    border_mode=border_mode,
                                    subsample=subsample,
                                    filter_dilation=filter_dilation)(inp[0])
                                out_0 = mkl_conv.Conv2D(
                                    imshp=imshp, kshp=kshp,
                                    border_mode=border_mode,
                                    subsample=subsample,
                                    filter_flip=filter_flip,
                                    filter_dilation=filter_dilation)(
                                        image=inp_0, weight=inp[1],
                                        bias=bias_var)
                                # Convert back to user layout before splicing
                                # into the (user-layout) graph.
                                out_1 = I2U()(out_0)
                                fgraph.replace_validate(
                                    add_node.outputs[0], out_1,
                                    'ReplaceConvBias')
                                did_something = True
                            except Exception:
                                raise
            elif isinstance(node.op, AbstractConv2d_gradWeights):
                inp = node.inputs  # 0-image, 1-gz, 2-shape
                out = node.outputs
                assert len(inp) == 3 and len(out) == 1
                # Find a fused forward conv that consumes the same image.
                for c in inp[0].clients:
                    if (hasattr(c[0], 'op') and isinstance(c[0].op, U2IConv) and
                            self._check_attributes_(c[0], node)):
                        for cc in c[0].outputs[0].clients:
                            if (isinstance(cc[0].op, mkl_conv.Conv2D) and
                                    len(cc[0].inputs) == 3):
                                weight, bias = cc[0].inputs[1:3]
                                try:
                                    inp_0 = U2IConv(
                                        imshp=imshp, kshp=kshp,
                                        border_mode=border_mode,
                                        subsample=subsample,
                                        filter_dilation=filter_dilation)(inp[0])
                                    conv_fw = mkl_conv.Conv2D(
                                        imshp=imshp, kshp=kshp,
                                        border_mode=border_mode,
                                        subsample=subsample,
                                        filter_flip=filter_flip,
                                        filter_dilation=filter_dilation)(
                                            inp_0, weight, bias)
                                    gz = I2UGrad()(conv_fw, inp[1])
                                    out_0, out_1 = mkl_conv.ConvGradWeights(
                                        imshp=imshp, kshp=kshp,
                                        border_mode=border_mode,
                                        subsample=subsample,
                                        filter_flip=filter_flip,
                                        filter_dilation=filter_dilation)(
                                            image=inp_0, weight=weight,
                                            gradz=gz, bias=bias)
                                    # Locate the bias gradient computed in the
                                    # original graph (if any) so it can be
                                    # swapped for the MKL one.
                                    oriBiasGrad = None
                                    gz_node = inp[1].owner
                                    for idx, o in enumerate(gz_node.outputs):
                                        if inp[1] is o and len(o.clients) >= 2:
                                            oriBiasGrad = self._check_grad_bias_(gz_node, idx)
                                    fgraph.replace_validate(
                                        out[0], out_0, 'ReplaceConvBias')
                                    if oriBiasGrad:
                                        fgraph.replace_validate(
                                            oriBiasGrad, out_1,
                                            'ReplaceConvBias')
                                    did_something = True
                                except Exception:
                                    raise
            elif isinstance(node.op, AbstractConv2d_gradInputs):
                inp = node.inputs  # 0-weight, 1-gz, 2-shape
                out = node.outputs
                assert len(inp) == 3 and len(out) == 1
                # Fused forward convs (with bias) sharing this weight and
                # matching this node's attributes.
                list_Conv2D = [c[0] for c in inp[0].clients
                               if (hasattr(c[0], 'op') and
                                   isinstance(c[0].op, mkl_conv.Conv2D) and
                                   len(c[0].inputs) == 3 and
                                   self._check_attributes_(c[0], node))]
                if 0 < len(list_Conv2D) < 3:
                    x = list_Conv2D[0].inputs[0].owner.inputs[0]
                    bias = list_Conv2D[0].inputs[2]
                    try:
                        inp_0 = U2IConv(
                            imshp=imshp, kshp=kshp,
                            border_mode=border_mode,
                            subsample=subsample,
                            filter_dilation=filter_dilation)(x)
                        conv_fw = mkl_conv.Conv2D(
                            imshp=imshp, kshp=kshp,
                            border_mode=border_mode,
                            subsample=subsample,
                            filter_flip=filter_flip,
                            filter_dilation=filter_dilation)(
                                inp_0, inp[0], bias)
                        gz = I2UGrad()(conv_fw, inp[1])
                        out_0 = mkl_conv.ConvGradInputs(
                            imshp=imshp, kshp=kshp,
                            border_mode=border_mode,
                            subsample=subsample,
                            filter_flip=filter_flip,
                            filter_dilation=filter_dilation)(inp_0, inp[0], gz)
                        inp_grad = U2IGrad()(x, out_0)
                        fgraph.replace_validate(
                            out[0], inp_grad, 'ReplaceConvBias')
                        did_something = True
                    except Exception:
                        raise