def local_ConvGradWeights_mkl(node):
    """Local optimizer: rebuild an ``AbstractConv2d_gradWeights`` node with MKL ops.

    The replacement graph is ``U2IConv`` -> ``mkl_conv.Conv2D`` ->
    ``I2UGrad`` -> ``mkl_conv.ConvGradWeights``.

    Parameters
    ----------
    node
        Apply node under consideration.

    Returns
    -------
    list or None
        A one-element list holding the ``ConvGradWeights`` output, or ``None``
        when MKL is unavailable, the node does not qualify, or building the
        replacement raised (a warning is logged in that last case).
    """
    if not mkl_available():
        return None

    grad_op = node.op
    if not isinstance(grad_op, AbstractConv2d_gradWeights):
        return None

    # The second input has to be 4-D or 5-D.
    if node.inputs[1].type.ndim not in (4, 5):
        return None

    # Dilated filters are not handled by this rewrite.
    if grad_op.filter_dilation != (1, 1):
        return None

    # Both the kernel shape and the image shape must be fully specified.
    if None in grad_op.kshp or None in grad_op.imshp:
        return None

    try:
        image, gz, zshp = node.inputs
        # The weight variable is recovered by walking two owners up from the
        # shape input (presumably shape <- subtensor <- shape-of(weight);
        # TODO confirm against the graph that produces this input).
        weight = zshp.owner.inputs[0].owner.inputs[0]

        to_internal = U2IConv(imshp=grad_op.imshp,
                              kshp=grad_op.kshp,
                              subsample=grad_op.subsample,
                              border_mode=grad_op.border_mode,
                              filter_dilation=grad_op.filter_dilation)
        image_internal = to_internal(image)

        forward_conv = mkl_conv.Conv2D(imshp=grad_op.imshp,
                                       kshp=grad_op.kshp,
                                       border_mode=grad_op.border_mode,
                                       subsample=grad_op.subsample,
                                       filter_flip=grad_op.filter_flip,
                                       filter_dilation=grad_op.filter_dilation)
        conv_out = forward_conv(image_internal, weight)

        gz_internal = I2UGrad()(conv_out, gz)

        weight_grad_op = mkl_conv.ConvGradWeights(border_mode=grad_op.border_mode,
                                                  subsample=grad_op.subsample,
                                                  imshp=grad_op.imshp,
                                                  kshp=grad_op.kshp)
        return [weight_grad_op(image_internal, weight, gz_internal)]
    except Exception as exc:
        # Best effort: report the failure and leave the node untouched.
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(exc)))
        return None
def local_ConvGroup_mkl(node):
    """Local optimizer: rebuild an ``mkl_conv.AbstractConvGroup`` node with MKL ops.

    The replacement graph is ``U2IConv`` -> ``mkl_conv.Conv2D`` -> ``I2U``.
    The node may carry two inputs (image, weight) or three (image, weight,
    bias); with two inputs ``None`` is passed as the bias.

    Parameters
    ----------
    node
        Apply node under consideration.

    Returns
    -------
    list or None
        A one-element list holding the ``I2U`` output, or ``None`` when MKL is
        unavailable, the node does not qualify, or building the replacement
        raised (a warning is logged in that last case).
    """
    if not mkl_available():
        return None

    group_op = node.op
    if not isinstance(group_op, mkl_conv.AbstractConvGroup):
        return None

    # image
    if node.inputs[0].type.ndim != 4:
        return None

    # weight
    if node.inputs[1].type.ndim not in (4, 5):
        return None

    try:
        n_inputs = len(node.inputs)
        assert n_inputs in (2, 3)
        if n_inputs == 2:
            image, weight = node.inputs
            bias = None
        else:
            image, weight, bias = node.inputs

        layout_kwargs = dict(imshp=group_op.imshp,
                             kshp=group_op.kshp,
                             subsample=group_op.subsample,
                             border_mode=group_op.border_mode,
                             filter_dilation=group_op.filter_dilation)
        image_internal = U2IConv(**layout_kwargs)(image)

        # Conv2D takes the same settings plus filter_flip.
        conv_kwargs = dict(layout_kwargs, filter_flip=group_op.filter_flip)
        internal_out = mkl_conv.Conv2D(**conv_kwargs)(image_internal, weight, bias)

        return [I2U()(internal_out)]
    except Exception as exc:
        # Best effort: report the failure and leave the node untouched.
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(exc)))
        return None
def local_Conv2D_mkl(node):
    """Local optimizer: rebuild an ``AbstractConv2d`` node with MKL ops.

    The replacement graph is ``U2IConv`` -> ``mkl_conv.Conv2D`` -> ``I2U``.

    Parameters
    ----------
    node
        Apply node under consideration; it is expected to have exactly two
        inputs (image, weight).

    Returns
    -------
    list or None
        A one-element list holding the ``I2U`` output, or ``None`` when MKL is
        unavailable, the node does not qualify, or building the replacement
        raised (a warning is logged in that last case).
    """
    if not mkl_available():
        return None

    conv_op = node.op
    if not isinstance(conv_op, AbstractConv2d):
        return None

    # Dilated filters are not handled by this rewrite.
    if conv_op.filter_dilation != (1, 1):
        return None

    # The second input has to be 4-D or 5-D.
    if node.inputs[1].type.ndim not in (4, 5):
        return None

    # Both the kernel shape and the image shape must be fully specified.
    if None in conv_op.kshp or None in conv_op.imshp:
        return None

    try:
        image, weight = node.inputs

        to_internal = U2IConv(imshp=conv_op.imshp,
                              kshp=conv_op.kshp,
                              subsample=conv_op.subsample,
                              border_mode=conv_op.border_mode,
                              filter_dilation=conv_op.filter_dilation)
        image_internal = to_internal(image)

        mkl_op = mkl_conv.Conv2D(imshp=conv_op.imshp,
                                 kshp=conv_op.kshp,
                                 border_mode=conv_op.border_mode,
                                 subsample=conv_op.subsample,
                                 filter_flip=conv_op.filter_flip,
                                 filter_dilation=conv_op.filter_dilation)
        internal_out = mkl_op(image_internal, weight)

        return [I2U()(internal_out)]
    except Exception as exc:
        # Best effort: report the failure and leave the node untouched.
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(exc)))
        return None
def local_ConvGroupGrad_mkl(node):
    """Local optimizer: rebuild an ``mkl_conv.AbstractConvGroupGrad`` node with MKL ops.

    The forward convolution is re-created (``U2IConv`` -> ``mkl_conv.Conv2D``)
    so that ``I2UGrad`` can be applied to the incoming gradient; the image
    gradient then comes from ``mkl_conv.ConvGradInputs`` followed by
    ``U2IGrad``, and the weight (and optionally bias) gradient from
    ``mkl_conv.ConvGradWeights``.

    Parameters
    ----------
    node
        Apply node under consideration, with inputs (image, gz, weight) or
        (image, gz, weight, bias).

    Returns
    -------
    list or None
        ``[grad_image, grad_weight, grad_bias]`` when the node has three
        outputs, otherwise ``[grad_image, grad_weight]``.  ``None`` when MKL
        is unavailable, the node does not qualify, or building the
        replacement raised (a warning is logged in that last case).
    """
    if not mkl_available():
        return None

    group_op = node.op
    if not isinstance(group_op, mkl_conv.AbstractConvGroupGrad):
        return None

    # image
    if node.inputs[0].type.ndim != 4:
        return None

    # weight
    if node.inputs[2].type.ndim not in (4, 5):
        return None

    try:
        n_inputs = len(node.inputs)
        assert n_inputs in (3, 4)
        if n_inputs == 3:
            image, gz, weight = node.inputs
            bias = None
        else:
            image, gz, weight, bias = node.inputs

        layout_kwargs = dict(imshp=group_op.imshp,
                             kshp=group_op.kshp,
                             subsample=group_op.subsample,
                             border_mode=group_op.border_mode,
                             filter_dilation=group_op.filter_dilation)
        image_internal = U2IConv(**layout_kwargs)(image)

        # The three conv ops below share these settings (layout + filter_flip).
        conv_kwargs = dict(layout_kwargs, filter_flip=group_op.filter_flip)

        forward_out = mkl_conv.Conv2D(**conv_kwargs)(image_internal, weight, bias)
        gz_internal = I2UGrad()(forward_out, gz)

        internal_image_grad = mkl_conv.ConvGradInputs(**conv_kwargs)(
            image_internal, weight, gz_internal)
        grad_image = U2IGrad()(image, internal_image_grad)

        grad_out = mkl_conv.ConvGradWeights(**conv_kwargs)(
            image_internal, weight, gz_internal, bias)

        # ConvGradWeights may hand back a (weight, bias) pair or a single
        # variable.
        if isinstance(grad_out, (list, tuple)):
            grad_weight, grad_bias = grad_out
        else:
            grad_weight = grad_out

        if len(node.outputs) == 3:
            assert len(grad_out) == 2
            return [grad_image, grad_weight, grad_bias]
        return [grad_image, grad_weight]
    except Exception as exc:
        # Best effort: report the failure and leave the node untouched.
        _logger.warning(('Failed to apply local opt to Op %s. '
                         'Exception message: %s\n') % (node.op, str(exc)))
        return None
def apply(self, fgraph):
    """Merge bias handling into MKL convolution ops throughout `fgraph`.

    The graph is scanned in topological order, and rescanned until a full
    pass performs no replacement.  Three kinds of node are rewritten:

    * ``AbstractConv2d`` with a single output that is not a graph output and
      whose first client is a two-input ``Elemwise`` add: the add's output is
      replaced by an ``mkl_conv.Conv2D`` that receives the other addend as
      bias.  Only biases with no owner, or a ``DimShuffle`` of a variable
      with no owner, are handled.
    * ``AbstractConv2d_gradWeights`` whose image input feeds a matching
      ``U2IConv`` followed by a three-input ``mkl_conv.Conv2D``: the weight
      gradient is replaced by ``mkl_conv.ConvGradWeights`` with bias, and the
      bias gradient found by ``self._check_grad_bias_`` (if any) is replaced
      by that op's second output.
    * ``AbstractConv2d_gradInputs`` whose weight input feeds one or two
      matching three-input ``mkl_conv.Conv2D`` nodes: the input gradient is
      replaced by ``mkl_conv.ConvGradInputs`` followed by ``U2IGrad``.

    Parameters
    ----------
    fgraph
        Function graph to rewrite in place; it must provide
        ``replace_validate``.

    Returns
    -------
    None

    Notes
    -----
    Does nothing when ``mkl_available()`` is false.  Any exception raised
    while building or validating a replacement propagates to the caller.
    """
    if not mkl_available():
        return

    def conv_params(op):
        """Read the conv settings of `op`, falling back to fixed defaults.

        Returns ``(layout_kwargs, conv_kwargs)``: the keyword arguments for
        ``U2IConv`` and the same set extended with ``filter_flip`` for the
        ``mkl_conv`` ops.
        """
        layout_kwargs = dict(
            imshp=getattr(op, 'imshp', None),
            kshp=getattr(op, 'kshp', None),
            border_mode=getattr(op, 'border_mode', 'valid'),
            subsample=getattr(op, 'subsample', (1, 1)),
            filter_dilation=getattr(op, 'filter_dilation', (1, 1)))
        conv_kwargs = dict(layout_kwargs,
                           filter_flip=getattr(op, 'filter_flip', False))
        return layout_kwargs, conv_kwargs

    did_something = True
    while did_something:
        did_something = False
        for node in fgraph.toposort():
            # Earlier replacements in this pass may have removed the node.
            if node not in fgraph.apply_nodes:
                continue

            if isinstance(node.op, AbstractConv2d):
                inp = node.inputs
                out = node.outputs

                # Get Elemwise node
                if len(out) != 1 or out[0] in fgraph.outputs:
                    continue
                add_node = out[0].clients[0][0]
                if not (isinstance(add_node.op, tensor.Elemwise) and
                        isinstance(add_node.op.scalar_op, scalar.Add)):
                    continue
                if len(add_node.inputs) != 2:
                    continue

                # The bias is whichever addend is not the conv output.
                if add_node.inputs[0] is out[0]:
                    bias = add_node.inputs[1]
                else:
                    bias = add_node.inputs[0]

                layout_kwargs, conv_kwargs = conv_params(node.op)

                # Get DimShuffle node
                bias_owner = bias.owner
                if bias_owner is None:
                    inp_0 = U2IConv(**layout_kwargs)(inp[0])
                    out_0 = mkl_conv.Conv2D(**conv_kwargs)(
                        image=inp_0, weight=inp[1], bias=bias)
                    # NOTE(review): unlike the DimShuffle branch below, the
                    # Conv2D output is not passed through I2U() before the
                    # replacement -- confirm whether that is intended.
                    # Fixed: this call was misspelled `repalce_validate`,
                    # which raised AttributeError whenever this branch ran.
                    fgraph.replace_validate(add_node.outputs[0],
                                            out_0,
                                            'ReplaceConvBias')
                    did_something = True
                elif (isinstance(bias_owner.op, tensor.DimShuffle) and
                        bias_owner.inputs[0].owner is None):
                    inp_0 = U2IConv(**layout_kwargs)(inp[0])
                    # Use the variable underneath the DimShuffle as bias.
                    out_0 = mkl_conv.Conv2D(**conv_kwargs)(
                        image=inp_0, weight=inp[1], bias=bias_owner.inputs[0])
                    out_1 = I2U()(out_0)
                    fgraph.replace_validate(add_node.outputs[0],
                                            out_1,
                                            'ReplaceConvBias')
                    did_something = True

            elif isinstance(node.op, AbstractConv2d_gradWeights):
                inp = node.inputs  # 0-image, 1-gz, 2-shape
                out = node.outputs
                assert len(inp) == 3 and len(out) == 1
                layout_kwargs, conv_kwargs = conv_params(node.op)

                # NOTE(review): as in the original, the live client lists are
                # iterated while replace_validate may modify the graph.
                for client in inp[0].clients:
                    u2i_node = client[0]
                    if not (hasattr(u2i_node, 'op') and
                            isinstance(u2i_node.op, U2IConv) and
                            self._check_attributes_(u2i_node, node)):
                        continue
                    for conv_client in u2i_node.outputs[0].clients:
                        conv_node = conv_client[0]
                        if not (isinstance(conv_node.op, mkl_conv.Conv2D) and
                                len(conv_node.inputs) == 3):
                            continue
                        weight, bias = conv_node.inputs[1:3]

                        inp_0 = U2IConv(**layout_kwargs)(inp[0])
                        conv_fw = mkl_conv.Conv2D(**conv_kwargs)(
                            inp_0, weight, bias)
                        gz = I2UGrad()(conv_fw, inp[1])
                        out_0, out_1 = mkl_conv.ConvGradWeights(**conv_kwargs)(
                            image=inp_0, weight=weight, gradz=gz, bias=bias)

                        # Get BiasGrad
                        oriBiasGrad = None  # BiasGrad in original function graph
                        gz_node = inp[1].owner
                        # gz may have no owner (e.g. a graph input); then
                        # there is nothing to search and only the weight
                        # gradient is replaced.
                        if gz_node is not None:
                            for idx, gz_out in enumerate(gz_node.outputs):
                                if inp[1] is gz_out and len(gz_out.clients) >= 2:
                                    oriBiasGrad = self._check_grad_bias_(gz_node, idx)

                        fgraph.replace_validate(out[0], out_0, 'ReplaceConvBias')
                        if oriBiasGrad:
                            fgraph.replace_validate(oriBiasGrad, out_1,
                                                    'ReplaceConvBias')
                        did_something = True

            elif isinstance(node.op, AbstractConv2d_gradInputs):
                inp = node.inputs  # 0-weight, 1-gz, 2-shape
                out = node.outputs
                assert len(inp) == 3 and len(out) == 1
                layout_kwargs, conv_kwargs = conv_params(node.op)

                # Forward Conv2D-with-bias nodes that consume the same weight.
                list_Conv2D = [c[0] for c in inp[0].clients
                               if (hasattr(c[0], 'op') and
                                   isinstance(c[0].op, mkl_conv.Conv2D) and
                                   len(c[0].inputs) == 3 and
                                   self._check_attributes_(c[0], node))]
                if 0 < len(list_Conv2D) < 3:
                    fw_node = list_Conv2D[0]
                    # First input of the node that produced the forward
                    # conv's image input.
                    x = fw_node.inputs[0].owner.inputs[0]
                    bias = fw_node.inputs[2]

                    inp_0 = U2IConv(**layout_kwargs)(x)
                    conv_fw = mkl_conv.Conv2D(**conv_kwargs)(inp_0, inp[0], bias)
                    gz = I2UGrad()(conv_fw, inp[1])
                    out_0 = mkl_conv.ConvGradInputs(**conv_kwargs)(
                        inp_0, inp[0], gz)
                    inp_grad = U2IGrad()(x, out_0)
                    fgraph.replace_validate(out[0], inp_grad, 'ReplaceConvBias')
                    did_something = True