예제 #1
0
def local_ConvGradWeights_mkl(node):
    """Replace an ``AbstractConv2d_gradWeights`` node with MKL ops.

    The rewrite converts the image with ``U2IConv``, rebuilds an MKL forward
    ``Conv2D`` from that image and a weight variable recovered from the
    node's shape input, converts the incoming output gradient into the
    layout of that forward result with ``I2UGrad``, and finally computes the
    weight gradient with ``mkl_conv.ConvGradWeights``.

    Parameters
    ----------
    node : Apply
        Candidate node of the function graph.

    Returns
    -------
    list or None
        ``[grad_weight]`` on success.  ``None`` when MKL is unavailable, the
        node is not an ``AbstractConv2d_gradWeights``, input 1 is not 4-D or
        5-D, the filter dilation is not ``(1, 1)``, ``kshp``/``imshp``
        contain ``None``, or building the replacement raises (the exception
        is logged as a warning instead of propagating).
    """
    if not mkl_available():
        return

    if not isinstance(node.op, AbstractConv2d_gradWeights):
        return

    # NOTE(review): input 1 is the output gradient here; this mirrors the
    # 4-D/5-D weight check of the forward rewrite — confirm it is intended.
    if node.inputs[1].type.ndim not in (4, 5):
        return

    if node.op.filter_dilation != (1, 1):
        return

    # The MKL ops need fully static image and kernel shapes.
    if None in node.op.kshp:
        return

    if None in node.op.imshp:
        return

    try:
        image, gz, shape = node.inputs
        # The gradient node does not receive the weights directly; they are
        # recovered from the graph that produced the shape input.  This
        # presumably assumes ``shape = <op>(<op>(weight))`` — any other graph
        # raises here and is reported by the handler below.
        weight = shape.owner.inputs[0].owner.inputs[0]
        image_internal = U2IConv(imshp=node.op.imshp,
                                 kshp=node.op.kshp,
                                 subsample=node.op.subsample,
                                 border_mode=node.op.border_mode,
                                 filter_dilation=node.op.filter_dilation)(image)
        conv_out = mkl_conv.Conv2D(imshp=node.op.imshp,
                                   kshp=node.op.kshp,
                                   border_mode=node.op.border_mode,
                                   subsample=node.op.subsample,
                                   filter_flip=node.op.filter_flip,
                                   filter_dilation=node.op.filter_dilation)(image_internal, weight)
        # Bring the user-layout gradient into the forward output's layout.
        gz_internal = I2UGrad()(conv_out, gz)
        # Pass filter_flip/filter_dilation so the gradient matches the
        # forward convolution above (as the other ConvGradWeights rewrites do).
        grad_weight = mkl_conv.ConvGradWeights(border_mode=node.op.border_mode,
                                               subsample=node.op.subsample,
                                               imshp=node.op.imshp,
                                               kshp=node.op.kshp,
                                               filter_flip=node.op.filter_flip,
                                               filter_dilation=node.op.filter_dilation)(image_internal, weight, gz_internal)
        return [grad_weight]
    except Exception as e:
        # Best effort: a failed rewrite leaves the original node in place.
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
예제 #2
0
def local_ConvGroup_mkl(node):
    """Lower an ``mkl_conv.AbstractConvGroup`` node to concrete MKL ops.

    The image is converted with ``U2IConv`` and convolved by
    ``mkl_conv.Conv2D``, using the optional third input as bias.  The result
    is then converted back with ``I2U``.

    Returns ``[replacement]`` on success.  Returns ``None`` in these cases:
    MKL is unavailable; the node is not an ``AbstractConvGroup``; the image
    is not 4-D; the weight is neither 4-D nor 5-D; or building the
    replacement raises (the exception is logged as a warning).
    """
    if not mkl_available() or not isinstance(node.op, mkl_conv.AbstractConvGroup):
        return

    # Image must be 4-D; weight may be 4-D or 5-D.
    if node.inputs[0].type.ndim != 4 or node.inputs[1].type.ndim not in (4, 5):
        return

    op = node.op
    try:
        arity = len(node.inputs)
        assert arity in (2, 3)
        if arity == 3:
            image, weight, bias = node.inputs
        else:
            (image, weight), bias = node.inputs, None

        # Geometry shared by the layout conversion and the convolution.
        geometry = dict(imshp=op.imshp,
                        kshp=op.kshp,
                        subsample=op.subsample,
                        border_mode=op.border_mode,
                        filter_dilation=op.filter_dilation)
        internal_image = U2IConv(**geometry)(image)
        convolved = mkl_conv.Conv2D(filter_flip=op.filter_flip,
                                    **geometry)(internal_image, weight, bias)
        return [I2U()(convolved)]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
예제 #3
0
def local_Conv2D_mkl(node):
    """Replace an ``AbstractConv2d`` node with an MKL convolution.

    The image is converted with ``U2IConv`` and convolved with the weights by
    ``mkl_conv.Conv2D``.  The result is then converted back with ``I2U``.

    Returns ``[replacement]`` on success.  Returns ``None`` in these cases:
    MKL is unavailable; the node is not an ``AbstractConv2d``; the dilation
    is not ``(1, 1)``; the weight is neither 4-D nor 5-D; ``kshp`` or
    ``imshp`` contains ``None``; or building the replacement raises (the
    exception is logged as a warning).
    """
    if not mkl_available():
        return

    op = node.op
    if not isinstance(op, AbstractConv2d):
        return

    unsupported = (op.filter_dilation != (1, 1)
                   or node.inputs[1].type.ndim not in (4, 5)
                   or None in op.kshp
                   or None in op.imshp)
    if unsupported:
        return

    try:
        image, weight = node.inputs
        internal_image = U2IConv(imshp=op.imshp,
                                 kshp=op.kshp,
                                 subsample=op.subsample,
                                 border_mode=op.border_mode,
                                 filter_dilation=op.filter_dilation)(image)
        conv = mkl_conv.Conv2D(imshp=op.imshp,
                               kshp=op.kshp,
                               border_mode=op.border_mode,
                               subsample=op.subsample,
                               filter_flip=op.filter_flip,
                               filter_dilation=op.filter_dilation)
        user_out = I2U()(conv(internal_image, weight))
        return [user_out]
    except Exception as e:
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
예제 #4
0
def local_ConvGroupGrad_mkl(node):
    """Lower an ``mkl_conv.AbstractConvGroupGrad`` node to MKL gradient ops.

    The forward ``Conv2D`` is rebuilt on the ``U2IConv``-converted image, and
    the output gradient is converted into its layout with ``I2UGrad``.  The
    image gradient comes from ``ConvGradInputs`` followed by ``U2IGrad``.
    The weight gradient (plus a bias gradient when a bias is present) comes
    from ``ConvGradWeights``.

    Inputs are ``(image, gz, weight)`` or ``(image, gz, weight, bias)``.

    Returns ``[grad_image, grad_weight]``, or
    ``[grad_image, grad_weight, grad_bias]`` when the node has three outputs.
    Returns ``None`` in these cases: MKL is unavailable; the node does not
    qualify; or building the replacement fails (the exception is logged as a
    warning).
    """
    if not mkl_available():
        return

    if not isinstance(node.op, mkl_conv.AbstractConvGroupGrad):
        return

    # image
    if node.inputs[0].type.ndim != 4:
        return

    # weight
    if node.inputs[2].type.ndim not in (4, 5):
        return

    try:
        n_inputs = len(node.inputs)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if n_inputs not in (3, 4):
            raise ValueError('expected 3 or 4 inputs, got %d' % n_inputs)
        if n_inputs == 3:
            image, gz, weight = node.inputs
            bias = None
        else:
            image, gz, weight, bias = node.inputs

        op = node.op
        image_internal = U2IConv(imshp=op.imshp,
                                 kshp=op.kshp,
                                 subsample=op.subsample,
                                 border_mode=op.border_mode,
                                 filter_dilation=op.filter_dilation)(image)
        conv_out = mkl_conv.Conv2D(imshp=op.imshp,
                                   kshp=op.kshp,
                                   subsample=op.subsample,
                                   border_mode=op.border_mode,
                                   filter_flip=op.filter_flip,
                                   filter_dilation=op.filter_dilation)(image_internal, weight, bias)
        # Bring the user-layout output gradient into the forward layout.
        gz_internal = I2UGrad()(conv_out, gz)
        grad_image = mkl_conv.ConvGradInputs(imshp=op.imshp,
                                             kshp=op.kshp,
                                             subsample=op.subsample,
                                             border_mode=op.border_mode,
                                             filter_flip=op.filter_flip,
                                             filter_dilation=op.filter_dilation)(image_internal, weight, gz_internal)
        grad_image = U2IGrad()(image, grad_image)

        grad_out = mkl_conv.ConvGradWeights(imshp=op.imshp,
                                            kshp=op.kshp,
                                            subsample=op.subsample,
                                            border_mode=op.border_mode,
                                            filter_flip=op.filter_flip,
                                            filter_dilation=op.filter_dilation)(image_internal, weight, gz_internal, bias)
        # A multi-output op returns a list; otherwise only the weight
        # gradient is produced.
        if isinstance(grad_out, (list, tuple)):
            grad_weight, grad_bias = grad_out
        else:
            grad_weight, grad_bias = grad_out, None

        if len(node.outputs) == 3:
            if grad_bias is None:
                raise ValueError('node expects a bias gradient but '
                                 'ConvGradWeights produced none')
            return [grad_image, grad_weight, grad_bias]
        return [grad_image, grad_weight]
    except Exception as e:
        # Best effort: a failed rewrite leaves the original node in place.
        msg = ('Failed to apply local opt to Op %s. '
               'Exception message: %s\n') % (node.op, str(e))
        _logger.warning(msg)
        return
예제 #5
0
    def apply(self, fgraph):
        """Fuse convolution + bias patterns into MKL ops, editing ``fgraph``.

        Repeats passes over ``fgraph`` in topological order until a pass
        performs no replacement.  Handled nodes:

        * ``AbstractConv2d`` patterns:
          - Match: the single output is not a graph output, and its first
            client is a two-input ``Elemwise`` ``Add``.
          - Rewrite: the add's output is replaced by
            ``I2U(Conv2D(U2IConv(image), weight, bias))``.
          - Bias: the other add operand, if it has no owner.  If that operand
            is a ``DimShuffle`` of an ownerless variable, the input of the
            ``DimShuffle`` is used instead.
        * ``AbstractConv2d_gradWeights`` patterns:
          - Match: the image input feeds a ``U2IConv`` accepted by
            ``_check_attributes_``, which in turn feeds a three-input MKL
            ``Conv2D``.
          - Rewrite: the node is replaced by the first output of MKL
            ``ConvGradWeights``.
          - Bias gradient: a variable located by ``_check_grad_bias_``, if
            any, is replaced by the second output.
        * ``AbstractConv2d_gradInputs`` patterns:
          - Match: the weight input feeds one or two accepted three-input MKL
            ``Conv2D`` nodes.
          - Rewrite: the node is replaced by ``U2IGrad`` of MKL
            ``ConvGradInputs``.

        Exceptions raised while building or validating a replacement
        propagate to the caller.
        """
        if not mkl_available():
            return

        handled_ops = (AbstractConv2d,
                       AbstractConv2d_gradWeights,
                       AbstractConv2d_gradInputs)

        did_something = True
        while did_something:
            did_something = False
            for node in fgraph.toposort():
                # An earlier replacement in this pass may have pruned the node.
                if node not in fgraph.apply_nodes:
                    continue
                if not isinstance(node.op, handled_ops):
                    continue

                inp = node.inputs
                out = node.outputs
                imshp = getattr(node.op, 'imshp', None)
                kshp = getattr(node.op, 'kshp', None)
                border_mode = getattr(node.op, 'border_mode', 'valid')
                subsample = getattr(node.op, 'subsample', (1, 1))
                filter_flip = getattr(node.op, 'filter_flip', False)
                filter_dilation = getattr(node.op, 'filter_dilation', (1, 1))

                if isinstance(node.op, AbstractConv2d):
                    # Look for ``conv(x, w) + bias`` where the add is the
                    # conv output's first client.
                    if (len(out) != 1 or out[0] in fgraph.outputs or
                            not out[0].clients):
                        continue
                    add_node = out[0].clients[0][0]
                    if not (isinstance(add_node.op, tensor.Elemwise) and
                            isinstance(add_node.op.scalar_op, scalar.Add) and
                            len(add_node.inputs) == 2):
                        continue

                    if add_node.inputs[0] is out[0]:
                        bias = add_node.inputs[1]
                    else:
                        bias = add_node.inputs[0]

                    bias_owner = bias.owner
                    if bias_owner is None:
                        conv_bias = bias
                    elif (isinstance(bias_owner.op, tensor.DimShuffle) and
                          bias_owner.inputs[0].owner is None):
                        # Hand Conv2D the variable before the DimShuffle.
                        conv_bias = bias_owner.inputs[0]
                    else:
                        continue

                    inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode, subsample=subsample,
                                    filter_dilation=filter_dilation)(inp[0])
                    out_0 = mkl_conv.Conv2D(imshp=imshp,
                                            kshp=kshp,
                                            border_mode=border_mode,
                                            subsample=subsample,
                                            filter_flip=filter_flip,
                                            filter_dilation=filter_dilation)(image=inp_0, weight=inp[1], bias=conv_bias)
                    # Convert the MKL result back with I2U before substituting
                    # it for the add output (both bias forms need this).
                    fgraph.replace_validate(add_node.outputs[0],
                                            I2U()(out_0),
                                            'ReplaceConvBias')
                    did_something = True

                elif isinstance(node.op, AbstractConv2d_gradWeights):
                    # inputs: 0-image, 1-gz, 2-shape
                    assert len(inp) == 3 and len(out) == 1
                    replaced = False
                    # Iterate over snapshots: replace_validate mutates client
                    # lists, and out[0] can only be replaced once.
                    for u2i_node, _ in list(inp[0].clients):
                        if not (hasattr(u2i_node, 'op') and
                                isinstance(u2i_node.op, U2IConv) and
                                self._check_attributes_(u2i_node, node)):
                            continue
                        for conv_node, _ in list(u2i_node.outputs[0].clients):
                            if not (isinstance(getattr(conv_node, 'op', None), mkl_conv.Conv2D) and
                                    len(conv_node.inputs) == 3):
                                continue
                            weight, bias = conv_node.inputs[1:3]
                            inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode, subsample=subsample,
                                            filter_dilation=filter_dilation)(inp[0])
                            conv_fw = mkl_conv.Conv2D(imshp=imshp,
                                                      kshp=kshp,
                                                      border_mode=border_mode,
                                                      subsample=subsample,
                                                      filter_flip=filter_flip,
                                                      filter_dilation=filter_dilation)(inp_0, weight, bias)
                            gz = I2UGrad()(conv_fw, inp[1])
                            out_0, out_1 = mkl_conv.ConvGradWeights(imshp=imshp,
                                                                    kshp=kshp,
                                                                    border_mode=border_mode,
                                                                    subsample=subsample,
                                                                    filter_flip=filter_flip,
                                                                    filter_dilation=filter_dilation)(image=inp_0, weight=weight, gradz=gz, bias=bias)
                            # Bias gradient in the original function graph, if any.
                            ori_bias_grad = None
                            gz_node = inp[1].owner
                            if gz_node is not None:  # gz may be a graph input
                                for idx, o in enumerate(gz_node.outputs):
                                    if inp[1] is o and len(o.clients) >= 2:
                                        ori_bias_grad = self._check_grad_bias_(gz_node, idx)

                            fgraph.replace_validate(out[0], out_0, 'ReplaceConvBias')
                            if ori_bias_grad:
                                fgraph.replace_validate(ori_bias_grad, out_1, 'ReplaceConvBias')
                            did_something = True
                            replaced = True
                            break
                        if replaced:
                            break

                else:
                    # AbstractConv2d_gradInputs; inputs: 0-weight, 1-gz, 2-shape
                    assert len(inp) == 3 and len(out) == 1
                    list_conv2d = [c[0] for c in inp[0].clients if (hasattr(c[0], 'op') and
                                                                    isinstance(c[0].op, mkl_conv.Conv2D) and
                                                                    len(c[0].inputs) == 3 and
                                                                    self._check_attributes_(c[0], node))]
                    if not 0 < len(list_conv2d) < 3:
                        continue
                    first_conv = list_conv2d[0]
                    # Presumably first_conv's image input is U2IConv(x); take x.
                    x = first_conv.inputs[0].owner.inputs[0]
                    bias = first_conv.inputs[2]
                    inp_0 = U2IConv(imshp=imshp, kshp=kshp, border_mode=border_mode, subsample=subsample,
                                    filter_dilation=filter_dilation)(x)
                    conv_fw = mkl_conv.Conv2D(imshp=imshp,
                                              kshp=kshp,
                                              border_mode=border_mode,
                                              subsample=subsample,
                                              filter_flip=filter_flip,
                                              filter_dilation=filter_dilation)(inp_0, inp[0], bias)
                    gz = I2UGrad()(conv_fw, inp[1])
                    out_0 = mkl_conv.ConvGradInputs(imshp=imshp,
                                                    kshp=kshp,
                                                    border_mode=border_mode,
                                                    subsample=subsample,
                                                    filter_flip=filter_flip,
                                                    filter_dilation=filter_dilation)(inp_0, inp[0], gz)
                    inp_grad = U2IGrad()(x, out_0)
                    fgraph.replace_validate(out[0], inp_grad, 'ReplaceConvBias')
                    did_something = True