示例#1
0
    def test_params_type_with_enums(self):
        """Exercise enum support in ParamsType.

        Covers rejection of clashing enum names/aliases, direct access to
        enum constants on the wrapper, alias lookup, and the availability of
        the regular wrapper attributes.
        """
        # Two enum types declaring the same constant names must be refused.
        try:
            first_enum = EnumList("A", "B", "C")
            second_enum = EnumList("A", "B", "F")
            ParamsType(enum1=first_enum, enum2=second_enum)
        except AttributeError:
            clash_rejected = True
        else:
            clash_rejected = False
        if not clash_rejected:
            raise Exception(
                "ParamsType should fail with common enum names inside different enum types."
            )

        # The alias "a" is declared by both enum types, which must be refused.
        try:
            first_enum = EnumList(("A", "a"), ("B", "b"))
            second_enum = EnumList(("ONE", "a"), ("TWO", "two"))
            ParamsType(enum1=first_enum, enum2=second_enum)
        except AttributeError:
            # The same declaration with the duplicated alias renamed has to
            # build; any error raised here makes the test fail.
            first_enum = EnumList(("A", "a"), ("B", "b"))
            second_enum = EnumList(("ONE", "one"), ("TWO", "two"))
            ParamsType(enum1=first_enum, enum2=second_enum)
        else:
            raise Exception(
                "ParamsType should fail when there are aliases with same names as some constants."
            )

        # Enum constants are reachable as attributes of the wrapper itself.
        wrapper = ParamsType(
            enum1=EnumList("A", ("B", "beta"), "C"),
            enum2=EnumList(("D", "delta"), "E", "F"),
        )
        assert wrapper.A == 0
        assert wrapper.B == 1
        assert wrapper.C == 2
        assert wrapper.D == 0
        assert wrapper.E == 1
        assert wrapper.F == 2

        # Aliases resolve to their constants. "C" is not an alias, so the
        # lookup should return the constant named C.
        for alias, constant_name in (("beta", "B"), ("delta", "D"), ("C", "C")):
            assert wrapper.enum_from_alias(alias) == getattr(wrapper, constant_name)

        # The regular wrapper attributes are still available.
        field_count = len(wrapper.fields)
        type_count = len(wrapper.types)
        assert field_count == type_count == wrapper.length
        assert wrapper.name
示例#2
0
    def test_params_type_with_enums(self):
        """Check how ParamsType handles enum types.

        Name clashes between enum types must raise AttributeError; enum
        constants and aliases must be usable through the wrapper; the usual
        wrapper attributes must remain accessible.
        """
        # Constant names shared by two enum types are not allowed.
        try:
            ParamsType(
                enum1=EnumList('A', 'B', 'C'),
                enum2=EnumList('A', 'B', 'F'),
            )
        except AttributeError:
            pass
        else:
            message = 'ParamsType should fail with common enum names inside different enum types.'
            raise Exception(message)

        # Alias 'a' appears in both enum types, which is not allowed either.
        try:
            ParamsType(
                enum1=EnumList(('A', 'a'), ('B', 'b')),
                enum2=EnumList(('ONE', 'a'), ('TWO', 'two')),
            )
        except AttributeError:
            # With the duplicated alias renamed the construction must succeed;
            # an error raised here propagates and fails the test.
            ParamsType(
                enum1=EnumList(('A', 'a'), ('B', 'b')),
                enum2=EnumList(('ONE', 'one'), ('TWO', 'two')),
            )
        else:
            message = 'ParamsType should fail when there are aliases with same names as some constants.'
            raise Exception(message)

        # Enum values can be read directly from the wrapper.
        w = ParamsType(
            enum1=EnumList('A', ('B', 'beta'), 'C'),
            enum2=EnumList(('D', 'delta'), 'E', 'F'),
        )
        assert w.A == 0
        assert w.B == 1
        assert w.C == 2
        assert w.D == 0
        assert w.E == 1
        assert w.F == 2

        # Constants can be fetched through their aliases.
        from_beta = w.enum_from_alias('beta')
        assert from_beta == w.B
        from_delta = w.enum_from_alias('delta')
        assert from_delta == w.D
        # C is not an alias, so it should return a constant named C.
        from_c = w.enum_from_alias('C')
        assert from_c == w.C

        # Other regular wrapper attributes are still available.
        n_fields = len(w.fields)
        n_types = len(w.types)
        assert n_fields == n_types == w.length
        assert w.name
示例#3
0
class BaseCorrMM(gof.OpenMPOp):
    """
    Base class for `CorrMM`, `CorrMM_gradWeights` and
    `CorrMM_gradInputs`. Cannot be used directly.

    Every sub-class must define internal attribute ``_direction`` out of __init__().
    ``_direction`` must take one of following values:

     - "forward" to correlate bottom with weights and store results in top.
     - "backprop weights" to do a valid convolution of bottom with top
       (swapping the first two dimensions) and store results in weights.
     - "backprop inputs" to do a full convolution of top with weights
       (swapping the first two dimensions) and store results in bottom.

    Parameters
    ----------
    border_mode : {'valid', 'full', 'half'}
        Additionally, the padding size could be directly specified by an integer,
        a pair of integers, or two pairs of integers.
    subsample
        Perform subsampling of the output (default: (1, 1)).
    filter_dilation
        Perform dilated correlation (default: (1,1))
    num_groups
        Perform grouped convolutions (default: 1)
    unshared
        Perform unshared correlation (default: False)
    """

    check_broadcast = False
    __props__ = (
        "border_mode",
        "subsample",
        "filter_dilation",
        "num_groups",
        "unshared",
    )

    # Must be overridden by sub-classes; validated in __init__().
    _direction = None

    # Each field below is backed by an attribute or property of the same
    # name on the op instance.
    params_type = ParamsType(
        direction=EnumList(
            ("DIRECTION_FORWARD", "forward"),  # 0
            ("DIRECTION_BACKPROP_WEIGHTS", "backprop weights"),  # 1
            ("DIRECTION_BACKPROP_INPUTS", "backprop inputs"),  # 2
        ),
        dH=int64,
        dW=int64,
        dilH=int64,
        dilW=int64,
        padH_l=int64,
        padH_r=int64,
        padW_l=int64,
        padW_r=int64,
        num_groups=int64,
        unshared=int8,
    )

    def __init__(
        self,
        border_mode="valid",
        subsample=(1, 1),
        filter_dilation=(1, 1),
        num_groups=1,
        unshared=False,
        openmp=None,
    ):
        """
        Validate and normalize the op configuration.

        An integer ``border_mode`` ``b`` is stored as ``((b, b), (b, b))``.
        A 2-tuple is stored as a pair of ``(left, right)`` integer pairs, each
        entry being either a non-negative number or a pair of non-negative
        numbers. The strings "valid", "full" and "half" are stored unchanged.

        Raises
        ------
        ValueError
            If ``border_mode``, ``subsample``, ``filter_dilation``,
            ``num_groups`` or the sub-class ``_direction`` is invalid.
        """
        super(BaseCorrMM, self).__init__(openmp=openmp)
        if isinstance(border_mode, integer_types):
            if border_mode < 0:
                raise ValueError(
                    "invalid border_mode {}, which must be a "
                    "non-negative integer".format(border_mode)
                )
            border_mode = ((border_mode, border_mode),) * 2
        elif isinstance(border_mode, tuple):
            if len(border_mode) != 2:
                raise ValueError(
                    "invalid border_mode {} which must be a "
                    "tuple of length 2".format(border_mode)
                )
            border = ()
            for mode in border_mode:
                if isinstance(mode, tuple) and len(mode) == 2 and min(mode) >= 0:
                    border += ((int(mode[0]), int(mode[1])),)
                # A tuple reaching this point already failed the pair check
                # (wrong length or negative entry). Comparing it with 0 would
                # raise TypeError on Python 3 instead of the ValueError below,
                # so only non-tuple entries are compared.
                elif not isinstance(mode, tuple) and mode >= 0:
                    border += ((int(mode), int(mode)),)
                else:
                    raise ValueError(
                        "invalid border mode {}. The tuple can only contain "
                        "integers or tuples of length 2".format(border_mode)
                    )
            border_mode = border
        elif border_mode not in ("valid", "full", "half"):
            raise ValueError(
                "invalid border_mode {}, which must be either "
                '"valid", "full", "half", an integer or a tuple '
                "of two integers or a pair of integers".format(border_mode)
            )
        self.border_mode = border_mode
        if len(subsample) != 2:
            raise ValueError("subsample must have two elements")
        if len(filter_dilation) != 2:
            raise ValueError("filter_dilation must have two elements")
        self.subsample = tuple(subsample)
        self.filter_dilation = tuple(filter_dilation)
        self.unshared = unshared

        # Remember which BLAS flavour is linked so that the matching
        # thread-control support code can be emitted later.
        if not theano.config.blas.ldflags:
            # Theano will use a NumPy C implementation of [sd]gemm_ instead.
            self.blas_type = ""
        else:
            if "openblas" in theano.config.blas.ldflags:
                self.blas_type = "openblas"
            elif "mkl" in theano.config.blas.ldflags:
                self.blas_type = "mkl"
            else:
                self.blas_type = ""

        if self._direction not in ["forward", "backprop weights", "backprop inputs"]:
            raise ValueError(
                "_direction must be one of 'forward', "
                "'backprop weights', 'backprop inputs'"
            )
        if num_groups < 1:
            raise ValueError("Number of groups should be greater than 0")
        self.num_groups = num_groups

    @property
    def pad(self):
        """
        Padding as ``((padH_l, padH_r), (padW_l, padW_r))``.

        The string modes are encoded with negative sentinels that the
        generated C code resolves once the kernel size is known:
        -1 for "half" and -2 for "full". "valid" means no padding.
        """
        if self.border_mode == "half":
            return ((-1, -1),) * 2
        elif self.border_mode == "full":
            return ((-2, -2),) * 2
        elif isinstance(self.border_mode, tuple):
            return self.border_mode
        else:
            assert self.border_mode == "valid"
            return ((0, 0),) * 2

    # Direction should be converted to real enum value,
    # as it is compared to integer later in c_code_helper().
    direction = property(lambda self: self.params_type.enum_from_alias(self._direction))

    # Subsampling factors (height, width).
    dH = property(lambda self: self.subsample[0])
    dW = property(lambda self: self.subsample[1])

    # Filter dilation factors (height, width).
    dilH = property(lambda self: self.filter_dilation[0])
    dilW = property(lambda self: self.filter_dilation[1])

    # Individual padding amounts, read from ``pad``.
    padH_l = property(lambda self: self.pad[0][0])
    padH_r = property(lambda self: self.pad[0][1])
    padW_l = property(lambda self: self.pad[1][0])
    padW_r = property(lambda self: self.pad[1][1])

    def __str__(self):
        """Return the class name followed by the op configuration."""
        return "%s{%s, %s, %s, %s %s}" % (
            self.__class__.__name__,
            self.border_mode,
            str(self.subsample),
            str(self.filter_dilation),
            str(self.num_groups),
            str(self.unshared),
        )

    @staticmethod
    def as_common_dtype(in1, in2):
        """
        Upcast input variables if necessary.
        """
        dtype = theano.scalar.upcast(in1.dtype, in2.dtype)
        return in1.astype(dtype), in2.astype(dtype)

    def __setstate__(self, d):
        """
        Restore the instance state, filling in attributes that may be
        missing from the restored state with the defaults of ``__init__``.
        """
        self.__dict__.update(d)
        if not hasattr(self, "num_groups"):
            self.num_groups = 1
        # ``unshared`` is read by __str__ and by the op params, so an
        # instance restored without it would be unusable.
        if not hasattr(self, "unshared"):
            self.unshared = False

    def c_support_code(self):
        """Return the BLAS header text plus thread helpers for the linked BLAS."""
        ccodes = blas_headers.blas_header_text()
        if self.blas_type == "openblas":
            ccodes += blas_headers.openblas_threads_text()
        elif self.blas_type == "mkl":
            ccodes += blas_headers.mkl_threads_text()
        return ccodes

    def c_libraries(self):
        return ldflags()

    def c_compile_args(self):
        compile_args = ldflags(libs=False, flags=True)
        compile_args += super(BaseCorrMM, self).c_compile_args()
        return compile_args

    def c_lib_dirs(self):
        return ldflags(libs=False, libs_dir=True)

    def c_header_dirs(self):
        return ldflags(libs=False, include_dir=True)

    def c_headers(self):
        headers = ["<stdio.h>"]
        headers += super(BaseCorrMM, self).c_headers()
        return headers

    def c_code_cache_version(self):
        # raise this whenever modifying any of the support_code_files
        return (10, self.openmp, blas_header_version())

    def c_support_code_apply(self, node, nodename):
        """
        Return the contents of ``c_code/corr_gemm.c`` with its placeholders
        filled in for the dtype of the node's first input and for the
        OpenMP / BLAS threading configuration of this op.
        """
        # REMEMBER TO RAISE c_code_cache_version when changing any of
        # these files
        sub = {}
        dtype = str(node.__dict__["inputs"][0].dtype)
        assert dtype in ("float32", "float64")
        if dtype == "float32":
            sub["gemm"] = "sgemm_"
            sub["gemv"] = "sgemv_"
            sub["float_type"] = "npy_float"
            sub["float_typenum"] = "NPY_FLOAT"
            sub["n_bytes"] = 4
            sub["c_float_type"] = "float"
        else:
            sub["gemm"] = "dgemm_"
            sub["gemv"] = "dgemv_"
            sub["float_type"] = "npy_double"
            sub["float_typenum"] = "NPY_DOUBLE"
            sub["n_bytes"] = 8
            sub["c_float_type"] = "double"

        if self.openmp:
            sub["omp_flags"] = "#pragma omp parallel for schedule(static)"
            sub["omp_get_max_threads"] = "omp_get_max_threads()"
            sub["omp_get_thread_num"] = "omp_get_thread_num()"

            if self.blas_type == "openblas":
                sub["blas_set_num_threads"] = "openblas_set_num_threads"
                sub["blas_get_num_threads"] = "openblas_get_num_threads()"
            elif self.blas_type == "mkl":
                sub["blas_set_num_threads"] = "mkl_set_num_threads"
                sub["blas_get_num_threads"] = "mkl_get_max_threads()"
            else:
                sub["blas_set_num_threads"] = ""
                sub["blas_get_num_threads"] = "0"
        else:
            sub["omp_flags"] = ""
            sub["omp_get_max_threads"] = "1"
            sub["omp_get_thread_num"] = "0"
            sub["blas_set_num_threads"] = ""
            sub["blas_get_num_threads"] = "0"

        files = [os.path.join("c_code", "corr_gemm.c")]
        base_dir = os.path.split(__file__)[0]
        codes = []
        for f in files:
            # Use a context manager so the file handle is closed right after
            # reading instead of being left to the garbage collector.
            with open(os.path.join(base_dir, f)) as code_file:
                codes.append(code_file.read())
        final_code = "".join(codes)
        return final_code % sub

    def c_code_helper(self, bottom, weights, top, sub, height=None, width=None):
        """
        This generates the C code for CorrMM (direction="forward"),
        CorrMM_gradWeights (direction="backprop weights"), and
        CorrMM_gradInputs (direction="backprop inputs").
        Depending on the direction, one of bottom, weights, top will
        receive the output, while the other two serve as inputs.

        :param bottom: Variable name of the input images in the forward pass,
            or the gradient of the input images in backprop wrt. inputs
        :param weights: Variable name of the filters in the forward pass,
            or the gradient of the filters in backprop wrt. weights
        :param top: Variable name of the output images / feature maps in the
            forward pass, or the gradient of the outputs in the backprop passes
        :param sub: Dictionary of substitutions useable to help generating the
            C code.
        :param height: If self.subsample[0] != 1, a variable giving the height
            of the filters for direction="backprop weights" or the height of
            the input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the height of the
            filters for direction="backprop weights".  Ignored otherwise.
        :param width: If self.subsample[1] != 1, a variable giving the width
            of the filters for direction="backprop weights" or the width of the
            input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the width of the
            filters for direction="backprop weights".  Ignored otherwise.
        :raises ValueError: If height (resp. width) is not given although it
            is required for the configured direction, subsampling or padding.
        """

        # When subsampling, we cannot unambiguously infer the height and width
        # of bottom and weights from top, so we require them to be given.
        # Similarly, when border_mode="half", we cannot infer the weight size.
        if height:
            height = "(*(npy_int64 *)(PyArray_DATA(%s)))" % height
        else:
            # direction 0 is forward, 1 is backprop weights; a pad of -1
            # encodes border_mode="half" (see ``pad``).
            if ((self.direction != 0) and (self.dH != 1)) or (
                (self.direction == 1) and (self.padH_l == -1 or self.padH_r == -1)
            ):
                raise ValueError(
                    "height must be given for backprop with vertical sampling or border_mode='half'"
                )
            # "-1" tells the C code that no explicit height was supplied.
            height = "-1"
        if width:
            width = "(*(npy_int64 *)(PyArray_DATA(%s)))" % width
        else:
            if ((self.direction != 0) and (self.dW != 1)) or (
                (self.direction == 1) and (self.padW_l == -1 or self.padW_r == -1)
            ):
                raise ValueError(
                    "width must be given for backprop with horizontal sampling or border_mode='half'"
                )
            # "-1" tells the C code that no explicit width was supplied.
            width = "-1"

        return """
    // Mandatory args
    int direction = %(params)s->direction;  // forward, bprop weights, bprop inputs

    // Optional args
    int dH = %(params)s->dH;
    int dW = %(params)s->dW;
    int dilH = %(params)s->dilH;
    int dilW = %(params)s->dilW;
    int padH_l = %(params)s->padH_l;
    int padH_r = %(params)s->padH_r;
    int padW_l = %(params)s->padW_l;
    int padW_r = %(params)s->padW_r;
    int numgroups = %(params)s->num_groups;
    int unshared = %(params)s->unshared;

    PyArrayObject * bottom = %(bottom)s;
    PyArrayObject * weights = %(weights)s;
    PyArrayObject * top = %(top)s;
    PyArrayObject * out2 = NULL;
    PyArrayObject **out = NULL;

    switch(%(params)s->direction) {
        case DIRECTION_FORWARD:
            out = &%(top)s;
            break;
        case DIRECTION_BACKPROP_WEIGHTS:
            out = &%(weights)s;
            break;
        case DIRECTION_BACKPROP_INPUTS:
            out = &%(bottom)s;
            break;
        default:
            PyErr_SetString(PyExc_ValueError, "CPU CorrMM: Invalid direction.");
            {%(fail)s}
            break;
    }

    int wdim, odim;
    wdim = unshared ? 6 : 4;
    odim = 4; //Can be set to 6 later for unshared backprop wrt weights

    // Obtain or infer kernel width and height
    // (we need to know it early to be able to handle auto-padding)
    int kH, kW, dil_kH, dil_kW;
    if (direction != 1) {
        // weight is an input variable, we can just read its shape
        kH = PyArray_DIMS(weights)[wdim-2];
        kW = PyArray_DIMS(weights)[wdim-1];
    }
    else {
        if (%(height)s != -1) {
            // kernel height is specified (perhaps vertical subsampling or half padding)
            kH = %(height)s;
        }
        else if (padH_l == -2 || padH_r == -2) {
            // vertical full padding, we can infer the kernel height
            kH = (2 - PyArray_DIMS(bottom)[2] + (PyArray_DIMS(top)[2] - 1) * dH - 1)/ dilH + 1;
        }
        else {
            // explicit padding, we can infer the kernel height
            kH = (PyArray_DIMS(bottom)[2] + padH_l + padH_r - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
        }
        if (%(width)s != -1) {
            // kernel width is specified (perhaps horizontal subsampling or half padding)
            kW = %(width)s;
        }
        else if (padW_l == -2 || padW_r == -2) {
            kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
        else {
            kW = (PyArray_DIMS(bottom)[3] + padW_l + padW_r - (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
    }

    // Implicit dilated kernel size
    dil_kH = (kH - 1) * dilH + 1;
    dil_kW = (kW - 1) * dilW + 1;

    // Auto-padding if requested
    if (padH_l == -1 || padH_r == -1) {  // vertical half padding
        padH_l = padH_r = dil_kH / 2;
    }
    else if (padH_l == -2 || padH_r == -2) {  // vertical full padding
        padH_l = padH_r = dil_kH - 1;
    }
    else if (padH_l < -2 || padH_r < -2) {
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padH_l and padH_r must be >= -2");
        %(fail)s
    }
    if (padW_l == -1 || padW_r == -1) {  // horizontal half padding
        padW_l = padW_r = dil_kW / 2;
    }
    else if (padW_l == -2 || padW_r == -2) {  // horizontal full padding
        padW_l = padW_r = dil_kW - 1;
    }
    else if (padW_l < -2 || padW_r < -2) {
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padW_l and padW_r must be >= -2");
        %(fail)s
    }

    // Infer output shape
    npy_intp out_dim[6];
    out_dim[4] = out_dim[5] = 0; //Only used for unshared backprop wrt weights
    switch(direction) {
    case 0:  // forward pass
        // output is top: (batchsize, num_filters, height, width)
        // height and width: top = (bottom + pad_l + pad_r - ((weight-1)*dil + 1)) / sample + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
        out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + padH_l + padH_r - ((PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1)) / dH + 1);
        out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + padW_l + padW_r - ((PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1)) / dW + 1);
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
        {
            if (unshared) {
                PyErr_Format(PyExc_ValueError,
                             "CorrMM: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                             (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                             (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                             (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                             (long int)PyArray_DIMS(weights)[4], (long int)PyArray_DIMS(weights)[5],
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3]);
            }
            else {
                PyErr_Format(PyExc_ValueError,
                             "CorrMM: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                             (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                             (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                             (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3]);
            }
            %(fail)s
        }
        break;
    case 1:  // backprop wrt. weights
        // output is weights: (num_filters, num_channels, height, width)
        // height and width: weights = (bottom + pad_l + pad_r - (top - 1) * sample - 1) / dil + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
        if (unshared){
            odim = 6;
            out_dim[1] = (npy_intp)PyArray_DIMS(top)[2];
            out_dim[2] = (npy_intp)PyArray_DIMS(top)[3];
        }
        out_dim[wdim-3] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
        out_dim[wdim-2] = (npy_intp)kH;  // already inferred further above
        out_dim[wdim-1] = (npy_intp)kW;  // how convenient
        if (unshared) {
            if (out_dim[0] < 0 || out_dim[1] <= 0 || out_dim[2] <= 0 || out_dim[3] < 0
                    || out_dim[4] <= 0 || out_dim[5] <= 0){
                PyErr_Format(PyExc_ValueError,
                             "CorrMM backprop wrt. weights: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                             (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3], (long int)out_dim[4], (long int)out_dim[5],
                             (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                             (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
                %(fail)s
            }
        }
        else {
            if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
            {
                PyErr_Format(PyExc_ValueError,
                             "CorrMM backprop wrt. weights: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                             (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3],
                             (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                             (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
                %(fail)s
            }
        }
        break;
    case 2:  // backprop wrt. inputs
        // output is bottom: (batchsize, num_channels, height, width)
        // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[wdim-3] * numgroups;
        out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1 - padH_l - padH_r);
        out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1 - padW_l - padW_r);
        if (unshared) {
            if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
            {
                PyErr_Format(PyExc_ValueError,
                             "CorrMM backprop wrt. inputs: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3],
                             (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                             (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                             (long int)PyArray_DIMS(weights)[4], (long int)PyArray_DIMS(weights)[5],
                             (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                             (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
                %(fail)s
            }
        }
        else {
            if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
            {
                PyErr_Format(PyExc_ValueError,
                             "CorrMM backprop wrt. inputs: impossible output shape\\n"
                             "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                             "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                             (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                             (long int)out_dim[3],
                             (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                             (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                             (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                             (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
                %(fail)s
            }
        }
        break;
    default:
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: direction must be 0, 1, or 2\\n");
        %(fail)s
    }

    // Prepare output array
    int typenum;
    int failure;
    failure = !(*out
           && PyArray_NDIM(*out)==odim
           && PyArray_IS_C_CONTIGUOUS(*out)
           && PyArray_DIMS(*out)[0]==out_dim[0]
           && PyArray_DIMS(*out)[1]==out_dim[1]
           && PyArray_DIMS(*out)[2]==out_dim[2]
           && PyArray_DIMS(*out)[3]==out_dim[3]);
    if (odim == 6){
        failure = failure || !(PyArray_DIMS(*out)[4]==out_dim[4]
                && PyArray_DIMS(*out)[5]==out_dim[5]);
    }
    if ( failure )
    {
        Py_XDECREF(*out);
        if (direction != 1) {
          typenum = PyArray_TYPE(weights);
        }
        else {
          typenum = PyArray_TYPE(bottom);
        }
        //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
        *out = (PyArrayObject*)PyArray_ZEROS(odim,
                                          out_dim,
                                          typenum,
                                          0);
        if (NULL == *out)
        {
            if (odim == 4) {
                PyErr_Format(PyExc_RuntimeError,
                        "BaseCorrMM: Failed to allocate output of %%lld x %%lld x %%lld x %%lld",
                        (long long)out_dim[0], (long long)out_dim[1], (long long)out_dim[2], (long long)out_dim[3]);
            }
            if (odim == 6) {
                PyErr_Format(PyExc_RuntimeError,
                        "BaseCorrMM: Failed to allocate output of %%lld x %%lld x %%lld x %%lld %%lld %%lld",
                        (long long)out_dim[0], (long long)out_dim[1], (long long)out_dim[2], (long long)out_dim[3],
                        (long long)out_dim[4], (long long)out_dim[5]);
            }
            %(fail)s
        }
    }

    // Call corrMM code
    out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW,
                padH_l, padH_r, padW_l, padW_r, numgroups, unshared);
    if (out2==NULL){
       %(fail)s
    }
    assert (out2 == *out);

""" % dict(
            bottom=bottom,
            weights=weights,
            top=top,
            height=height,
            width=width,
            fail=sub["fail"],
            params=sub["params"],
        )
示例#4
0
class CumOp(theano.Op):
    """
    Cumulative sum or cumulative product of a tensor.

    See the functions cumsum/cumprod for the user-facing docstring.

    Parameters
    ----------
    axis : int or None
        Axis along which values are accumulated. With ``None`` the
        accumulation runs over the flattened input and the output is a
        vector.
    mode : {'add', 'mul'}
        ``'add'`` accumulates with ``np.cumsum``, ``'mul'`` with
        ``np.cumprod``.

    """

    __props__ = ("axis", "mode")
    check_input = False
    # Values handed to the C implementation: the axis as an integer and the
    # mode as an enum constant (MODE_ADD / MODE_MUL).
    params_type = ParamsType(c_axis=int_t,
                             mode=EnumList(('MODE_ADD', 'add'),
                                           ('MODE_MUL', 'mul')))

    def __init__(self, axis=None, mode='add'):
        """Store ``axis`` and ``mode``; raise ValueError on an unknown mode."""
        supported = ('add', 'mul')
        if mode not in supported:
            raise ValueError('%s: Unknown mode "%s"' %
                             (type(self).__name__, mode))
        self.axis = axis
        self.mode = mode

    @property
    def c_axis(self):
        """Axis value for the C code; ``np.MAXDIMS`` stands in for ``None``."""
        if self.axis is None:
            return np.MAXDIMS
        return self.axis

    def make_node(self, x):
        """
        Build the Apply node for input ``x``.

        The output has the type of ``x``, except when ``axis`` is None where
        it is a vector of the same dtype. Raises ValueError when ``axis``
        lies outside ``[-x.ndim, x.ndim)``.
        """
        x = basic.as_tensor_variable(x)
        same_type_out = x.type()
        axis = self.axis

        if axis is None:
            # Flatten
            flat_out = theano.tensor.vector(dtype=x.dtype)
            return theano.Apply(self, [x], [flat_out])

        ndim = x.ndim
        if axis >= ndim or axis < -ndim:
            raise ValueError('axis(={0}) out of bounds'.format(self.axis))

        return theano.Apply(self, [x], [same_type_out])

    def perform(self, node, inputs, output_storage, params):
        """Compute the result with NumPy and store it in ``output_storage[0][0]``."""
        accumulate = {
            'add': np.cumsum,
            'mul': np.cumprod
        }[self.mode]
        storage = output_storage[0]
        storage[0] = accumulate(inputs[0], axis=self.axis)

    def grad(self, inputs, output_gradients):
        """
        Return the gradient with respect to the single input.

        The incoming gradient is reversed along the accumulation axis,
        accumulated with cumsum and reversed back; for ``'mul'`` it is first
        multiplied by the forward cumprod and the result is divided by the
        input. Raises NotImplementedError for any other mode.
        """
        (x, ) = inputs
        (gi, ) = output_gradients
        axis = self.axis
        mode = self.mode

        def unknown_mode():
            return NotImplementedError(
                '%s: unknown gradient for mode "%s"' %
                (type(self).__name__, self.mode))

        if axis is None:
            # Flattened case: work on the 1-d gradient, then restore the
            # shape of the input.
            if mode == 'add':
                flat = cumsum(gi[::-1])[::-1]
                return [flat.reshape(x.shape)]
            if mode == 'mul':
                fx = cumprod(x, axis=axis)
                flat = cumsum((fx * gi)[::-1])[::-1]
                return [flat.reshape(x.shape) / x]
            raise unknown_mode()

        # Index that flips the tensor along ``axis`` only; applied before and
        # after cumsum so the accumulation runs from the end of the axis.
        flip = [slice(None, None, None)] * gi.ndim
        flip[axis] = slice(None, None, -1)
        flip = tuple(flip)

        if mode == 'add':
            return [cumsum(gi[flip], axis)[flip]]
        if mode == 'mul':
            fx = cumprod(x, axis=axis)
            weighted = (fx * gi)[flip]
            return [cumsum(weighted, axis)[flip] / x]
        raise unknown_mode()

    def infer_shape(self, node, shapes):
        """Output shape: the input shape, or its total size when flattened."""
        if self.axis is not None:
            return shapes

        # Flatten
        flat_length = tensor.prod(shapes[0])
        return [(flat_length, )]

    def c_code(self, node, name, inames, onames, sub):
        """
        Return the C source of the op.

        The generated code maps axis 0 of a 1-d input to NPY_MAXDIMS,
        (re)allocates the output when it is missing or has the wrong shape,
        and then calls PyArray_CumSum or PyArray_CumProd depending on the
        mode parameter.
        """
        (x, ) = inames
        (z, ) = onames

        template = """
                int axis = %(params)s->c_axis;
                if (axis == 0 && PyArray_NDIM(%(x)s) == 1)
                    axis = NPY_MAXDIMS;
                npy_intp shape[1] = { PyArray_SIZE(%(x)s) };
                if(axis == NPY_MAXDIMS && !(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
                {
                    Py_XDECREF(%(z)s);
                    %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
                }

                else if(axis != NPY_MAXDIMS && !(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
                {
                    Py_XDECREF(%(z)s);
                    %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE(%(x)s));
                }

                if (!%(z)s)
                    %(fail)s;
                {

                    PyObject * t = NULL;
                    if(%(params)s->mode == MODE_ADD)
                        t = PyArray_CumSum(
                            %(x)s, axis,
                            PyArray_TYPE(%(x)s), %(z)s);
                    else if(%(params)s->mode == MODE_MUL)
                        t = PyArray_CumProd(
                            %(x)s, axis,
                            PyArray_TYPE(%(x)s), %(z)s);

                    if (!t){
                       %(fail)s;
                    }
                    // Because PyArray_CumSum/CumProd returns a newly created reference on t.
                    Py_XDECREF(t);
                }
            """
        # Only these four names appear in the template.
        substitutions = dict(
            x=x,
            z=z,
            fail=sub['fail'],
            params=sub['params'],
        )
        return template % substitutions

    def c_code_cache_version(self):
        """Version tag of the generated C code."""
        return (8, )

    def __str__(self):
        fields = (self.__class__.__name__, self.axis, self.mode)
        return "%s{%s, %s}" % fields
示例#5
0
class Images2Neibs(Op):
    """
    Reshapes the input as a 2D tensor where each row is a pooling
    example.

    Parameters
    ----------
    mode : {'valid', 'half', 'full', 'ignore_borders', 'wrap_centered'}
        - 'valid' :
            Requires an input that is a multiple of the pooling factor
            (in each direction).
        - 'half' :
            Equivalent to 'valid' if we pre-pad with zeros the input on
            each side by (neib_shape[0]//2, neib_shape[1]//2)
        - 'full' :
            Equivalent to 'valid' if we pre-pad with zeros the input on
            each side by (neib_shape[0] - 1, neib_shape[1] - 1)
        - 'ignore_borders' :
            Same as valid, but will ignore the borders if the shape(s)
            of the input is not a multiple of the pooling factor(s).
        - 'wrap_centered' :
            Every patch is shifted back by
            (neib_shape[0]//2, neib_shape[1]//2) and indices that fall
            outside the image wrap around to the opposite border.
            Requires odd patch shapes and an image that is at least as
            large as the patch.

    """

    __props__ = ("mode", )
    # Supported border modes as (C constant name, Python alias) pairs; the
    # aliases are the values accepted by ``__init__`` and the constant names
    # are the ones compared against in ``c_code``.
    BORDER_MODE = EnumList(
        ("MODE_VALID", "valid"),
        ("MODE_HALF", "half"),
        ("MODE_FULL", "full"),
        ("MODE_WRAP_CENTERED", "wrap_centered"),
        ("MODE_IGNORE_BORDERS", "ignore_borders"),
    )
    # The border mode is the only parameter of the op.
    params_type = BORDER_MODE

    def get_params(self, node):
        """Return the parameter value of this op: its border mode alias."""
        return self.mode

    def __init__(self, mode="valid"):
        """
        Store the border mode.

        Raises NotImplementedError when ``mode`` is not one of the aliases
        declared in ``BORDER_MODE``.
        """
        known_modes = self.BORDER_MODE.get_aliases()
        if mode in known_modes:
            self.mode = mode
            return
        raise NotImplementedError(
            f"Only modes {', '.join(known_modes)} have been implemented for {type(self).__name__}"
        )

    def __str__(self):
        """Render as ``ClassName{mode}``."""
        mode_part = "{%s}" % self.mode
        return self.__class__.__name__ + mode_part

    def __setstate__(self, d):
        """Restore the instance dict; fall back to mode "valid" when the
        restored state carries no ``mode`` attribute."""
        self.__dict__.update(d)
        if hasattr(self, "mode"):
            return
        self.mode = "valid"

    def make_node(self, ten4, neib_shape, neib_step=None):
        """
        Build the Apply node that turns a 4D stack of images into a matrix
        of flattened neighbourhoods.

        Parameters
        ----------
        ten4 : a list of lists of images
            ten4 is of shape (list 1 dim, list 2 dim, row, col).
        neib_shape
            (r,c) where r is the height of the neighborhood in rows and c is
            the width of the neighborhood in columns.
        neib_step
            (dr,dc) where dr is the number of rows to skip between patches
            and dc is the number of columns. When None, neib_shape is used
            (patches are disjoint).

        Returns
        -------
        Apply
            A node with inputs (ten4, neib_shape, neib_step) and one 2D
            matrix output of the same dtype as ten4, written using the
            following pattern::

                idx = 0
                for i in range(list 1 dim)
                    for j in range(list 2 dim)
                        for l in <image row coordinates>
                            for k in <image column coordinates>
                                output[idx,:]
                                     = flattened version of ten4[i,j,l:l+r,k:k+c]
                                idx += 1

            .. note:: The op isn't necessarily implemented internally with these
                for loops, they're just the easiest way to describe the output
                pattern.

        """
        ten4 = tt.as_tensor_variable(ten4)
        neib_shape = tt.as_tensor_variable(neib_shape)
        # Without an explicit step the patches tile the image.
        neib_step = (neib_shape if neib_step is None else
                     tt.as_tensor_variable(neib_step))

        assert ten4.ndim == 4
        assert neib_shape.ndim == 1
        assert neib_step.ndim == 1

        node_inputs = [ten4, neib_shape, neib_step]
        node_outputs = [tt.matrix(dtype=ten4.type.dtype)]
        return Apply(self, node_inputs, node_outputs)

    def grad(self, inp, grads):
        """
        Return the gradients for (ten4, neib_shape, neib_step).

        The gradients of neib_shape and neib_step are always undefined. For
        the image input:

        - modes 'valid' and 'ignore_borders' with a step equal to the patch
          shape use ``neibs2images``;
        - mode 'valid' with any other step sums the contribution of every
          position inside the neighbourhood with a scan;
        - everything else is reported as not implemented.
        """
        x, neib_shape, neib_step = inp
        (gz, ) = grads
        mode = self.mode

        def with_undefined(grad_x):
            # Complete the gradient list with the two undefined entries.
            return [
                grad_x,
                grad_undefined(self, 1, neib_shape),
                grad_undefined(self, 2, neib_step),
            ]

        # Theano Constant == do not compare the data
        # the equals function do that.
        if mode in ["valid", "ignore_borders"] and (
                neib_shape is neib_step or neib_shape == neib_step or
            (hasattr(neib_shape, "equals") and neib_shape.equals(neib_step))):
            return with_undefined(
                neibs2images(gz, neib_shape, x.shape, mode=self.mode))

        if mode == "valid":
            # Iterate over neighborhood positions, summing contributions.
            def pos2map(pidx, pgz, prior_result, neib_shape, neib_step):
                """
                Add the gradient contribution of one position i,j inside the
                neighborhood.
                pidx = Index of position within neighborhood.
                pgz  = Gradient of shape (batch_size*num_channels*neibs)
                prior_result  = Shape (batch_size, num_channnels, rows, cols)
                neib_shape = Number of rows, cols in a neighborhood.
                neib_step  = Step sizes from image2neibs.
                """
                nrows, ncols = neib_shape
                rstep, cstep = neib_step
                batch_size, num_channels, rows, cols = prior_result.shape
                i = pidx // ncols
                j = pidx - (i * ncols)
                # This position does not touch some img pixels in valid mode.
                row_sel = slice(i, rows - nrows + i + 1, rstep)
                col_sel = slice(j, cols - ncols + j + 1, cstep)
                touched = prior_result[:, :, row_sel, col_sel]
                patch_grid = (
                    batch_size,
                    num_channels,
                    (rows - nrows) // rstep + 1,
                    (cols - ncols) // cstep + 1,
                )
                return tt.inc_subtensor(touched, pgz.reshape(patch_grid))

            positions = tt.arange(neib_shape[0] * neib_shape[1])
            per_position_gz = gz.dimshuffle((1, 0))
            accumulated, _ = theano.scan(
                fn=pos2map,
                sequences=[positions, per_position_gz],
                outputs_info=tt.zeros(x.shape),
                non_sequences=[neib_shape, neib_step],
            )
            # The last scan step holds the sum over all positions.
            return with_undefined(accumulated[-1])

        return with_undefined(grad_not_implemented(self, 0, x))

    def c_code_cache_version(self):
        """Version tag of the generated C code."""
        # NOTE(review): presumably has to be bumped whenever ``c_code``
        # changes -- confirm against the Op documentation.
        return (10, )

    def perform(self, node, inp, out_, params):
        """
        Python implementation: copy every neighbourhood of ``ten4`` into one
        row of the output matrix.

        Parameters
        ----------
        node
            Apply node being executed; only the dtype of its first output
            is read.
        inp
            ``(ten4, neib_shape, neib_step)``: a 4D array and two vectors
            of length 2.
        out_
            One output storage cell; ``out_[0][0]`` receives a matrix of
            shape ``(grid_c * grid_d * ten4.shape[1] * ten4.shape[0],
            neib_shape[0] * neib_shape[1])``.
        params
            Not used here, the mode is read from ``self.mode``.

        Raises
        ------
        theano.gof.utils.MethodNotDefined
            When called on anything but an exact ``Images2Neibs`` instance.
        ValueError
            When a step or a patch dimension is <= 0.
        TypeError
            When the image, patch and step sizes are inconsistent for the
            mode, or when the mode is unknown.
        """
        ten4, neib_shape, neib_step = inp
        (z, ) = out_
        # GpuImages2Neibs should not run this perform in DebugMode, hence the
        # exact type comparison instead of isinstance.
        if type(self) != Images2Neibs:
            raise theano.gof.utils.MethodNotDefined()

        def ceil_intdiv(a, b):
            return (a // b) + 1 if a % b else a // b

        assert ten4.ndim == 4
        assert neib_shape.ndim == 1
        assert neib_shape.shape[0] == 2
        assert neib_step.ndim == 1
        assert neib_step.shape[0] == 2
        c, d = neib_shape
        step_x, step_y = neib_step
        mode = self.mode
        if step_x <= 0 or step_y <= 0:
            raise ValueError("neib_step wrong step ; values <= 0. Got " +
                             str(neib_step))
        if c <= 0 or d <= 0:
            raise ValueError("neib_shape values <=0. Got " + str(neib_shape))

        nb_batch = ten4.shape[0]
        nb_stack = ten4.shape[1]
        height = ten4.shape[2]
        width = ten4.shape[3]

        # grid_c / grid_d: number of patches in height / width.
        if mode == "wrap_centered":
            if (c % 2 != 1) or (d % 2 != 1):
                raise TypeError(
                    "Images2Neibs:"
                    " in mode wrap_centered need patch with odd shapes")

            if (height < c) or (width < d):
                raise TypeError(
                    "Images2Neibs: in wrap_centered mode, don't support"
                    " image shapes smaller then the patch shapes:"
                    f" neib_shape=({int(c)},{int(d)}), ten4[2:]=[{int(height)},{int(width)}]"
                )
            grid_c = ceil_intdiv(height, step_x)
            grid_d = ceil_intdiv(width, step_y)
        elif mode in ("valid", "ignore_borders", "half", "full"):
            # ``span`` is (padded image size - patch size), i.e. the distance
            # available for stepping along each axis.
            if mode == "half":
                # Equivalent to 'valid' with padding (c // 2, d // 2) on both
                # sides, so the expanded image has size
                # (h + 2 * (c // 2), w + 2 * (d // 2)) and
                # h + 2 * (c // 2) - c  = h - (c % 2)
                # w + 2 * (d // 2) - d  = w - (d % 2)
                span_h = height - (c % 2)
                span_w = width - (d % 2)
            elif mode == "full":
                # Equivalent to 'valid' with padding (c - 1, d - 1) on both
                # sides, so the expanded image has size
                # (h + 2 * (c - 1), w + 2 * (d - 1)) and
                # h + 2 * (c - 1) - c  = h + c - 2
                # w + 2 * (d - 1) - d  = w + d - 2
                span_h = height + c - 2
                span_w = width + d - 2
            else:
                span_h = height - c
                span_w = width - d

            if mode != "ignore_borders":
                # Every mode but 'ignore_borders' requires the patches to
                # cover the (padded) image exactly.
                per_axis = (
                    (0, height, c, step_x, span_h),
                    (1, width, d, step_y, span_w),
                )
                for axis, extent, patch, step, span in per_axis:
                    if (extent < patch) or ((span % step) != 0):
                        raise TypeError(
                            f"neib_shape[{axis}]={int(patch)}, neib_step[{axis}]={int(step)} and"
                            f" ten4.shape[{axis + 2}]={int(extent)} not consistent")

            grid_c = 1 + (span_h // step_x)
            grid_d = 1 + (span_w // step_y)
        else:
            raise TypeError(f"Images2Neibs: unknow mode '{mode}'")

        z_dim0 = grid_c * grid_d * nb_stack * nb_batch
        z_dim1 = c * d
        z[0] = np.empty((z_dim0, z_dim1), dtype=node.outputs[0].dtype)
        out = z[0]

        # Offset between a patch coordinate and the image coordinate it reads;
        # it only depends on the mode, so compute it once.
        if mode in ("wrap_centered", "half"):
            shift_x, shift_y = c // 2, d // 2
        elif mode == "full":
            shift_x, shift_y = c - 1, d - 1
        else:
            shift_x, shift_y = 0, 0
        wrap = mode == "wrap_centered"

        for n in range(nb_batch):
            for s in range(nb_stack):
                # loop over the number of patch in height
                for a in range(grid_c):
                    # loop over the number of patch in width
                    for b in range(grid_d):
                        z_row = b + grid_d * (a + grid_c * (s + nb_stack * n))
                        for i in range(c):
                            ten4_2 = i + a * step_x - shift_x
                            if wrap:
                                if ten4_2 < 0:
                                    ten4_2 += height
                                elif ten4_2 >= height:
                                    ten4_2 -= height
                            if ten4_2 < 0 or ten4_2 >= height:
                                # The whole patch row lies in the padding.
                                out[z_row, d * i:d * i + d] = 0
                                continue
                            for j in range(d):
                                ten4_3 = j + b * step_y - shift_y
                                if wrap:
                                    if ten4_3 < 0:
                                        ten4_3 += width
                                    elif ten4_3 >= width:
                                        ten4_3 -= width
                                z_col = j + d * i
                                if ten4_3 < 0 or ten4_3 >= width:
                                    out[z_row, z_col] = 0
                                else:
                                    out[z_row, z_col] = ten4[n, s, ten4_2,
                                                             ten4_3]

    def infer_shape(self, node, input_shape):
        """
        Return the shape of the output matrix.

        The result is ``[(grid_c * grid_d * in_shape[1] * in_shape[0],
        c * d)]`` where ``(c, d)`` is the patch shape and ``grid_c`` /
        ``grid_d`` are the number of patches along the image rows and columns
        for the current mode. Raises TypeError on an unknown mode.
        """
        in_shape = input_shape[0]
        c, d = node.inputs[1]
        step_x, step_y = node.inputs[2]
        mode = self.mode

        if mode == "wrap_centered":
            grid_c = tt.ceil_intdiv(in_shape[2], step_x)
            grid_d = tt.ceil_intdiv(in_shape[3], step_y)
        else:
            # Distance available for stepping along each image axis.
            if mode in ("valid", "ignore_borders"):
                span_c = in_shape[2] - c
                span_d = in_shape[3] - d
            elif mode == "half":
                span_c = in_shape[2] - (c % 2)
                span_d = in_shape[3] - (d % 2)
            elif mode == "full":
                span_c = in_shape[2] + c - 2
                span_d = in_shape[3] + d - 2
            else:
                raise TypeError(f"Images2Neibs: unknow mode '{mode}'")
            grid_c = 1 + (span_c // step_x)
            grid_d = 1 + (span_d // step_y)

        n_rows = grid_c * grid_d * in_shape[1] * in_shape[0]
        n_cols = c * d
        return [(n_rows, n_cols)]

    def c_code(self, node, name, inp, out, sub):
        """
        Return the C source implementing the op.

        The template is filled with the C names of the three inputs
        (``ten4``, ``neib_shape``, ``neib_step``), of the output ``z``, with
        the failure snippet ``sub["fail"]`` and with ``sub["params"]`` as the
        mode expression.

        The generated code works in two scopes:

        1. Check the ranks and lengths of the inputs, reject steps or patch
           sizes <= 0, compute ``grid_c`` / ``grid_d`` for the mode (raising
           TypeError on inconsistent sizes or an unknown mode) and allocate
           ``z`` when it is missing or has the wrong shape.
        2. Loop over batches, stacks and patch positions and copy each
           patch element into ``z``, writing zeros for coordinates that fall
           outside the image.
        """
        # Literal percent signs of the C code are written ``%%`` because the
        # template goes through Python %-formatting.
        return """
#ifndef CEIL_INTDIV
#define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0))
#endif

        int grid_c = -1; //number of patch in height
        int grid_d = -1; //number of patch in width
        {
        if (PyArray_NDIM(%(ten4)s) != 4)
        {
            PyErr_Format(PyExc_TypeError, "ten4 wrong rank");
            %(fail)s;
        }
        if (PyArray_NDIM(%(neib_shape)s) != 1)
        {
            PyErr_Format(PyExc_TypeError, "neib_shape wrong rank");
            %(fail)s;
        }
        if ( (PyArray_DIMS(%(neib_shape)s))[0] != 2)
        {
            PyErr_Format(PyExc_TypeError, "neib_shape wrong shape ; has to"
                                          " contain 2 elements");
            %(fail)s;
        }
        if (PyArray_NDIM(%(neib_step)s) != 1)
        {
            PyErr_Format(PyExc_TypeError, "neib_step wrong rank");
            %(fail)s;
        }
        if ( (PyArray_DIMS(%(neib_step)s))[0] != 2)
        {
            PyErr_Format(PyExc_TypeError,
                         "neib_step wrong step ; has to contain 2 elements");
            %(fail)s;
        }

        // (c,d) = neib_shape
        const npy_intp c = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 0);
        const npy_intp d = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 1);
        // (step_x,step_y) = neib_step
        const dtype_%(neib_step)s step_x = *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 0);
        const dtype_%(neib_step)s step_y = *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 1);

        if (step_x <=0 || step_y <=0)
        {
            PyErr_Format(PyExc_ValueError,
                         "neib_step wrong step ; values <= 0. Got %%lld %%lld.",
                         (long long) step_x, (long long) step_y);
            %(fail)s;
        }

        if (c <=0 || d <=0)
        {
            PyErr_Format(PyExc_ValueError,
                         "neib_shape values <= 0. Got %%lld %%lld.",
                         (long long)c, (long long)d);
            %(fail)s;
        }

        if (%(mode)s == MODE_WRAP_CENTERED) {
            if (c%%2!=1 || d%%2!=1){
                PyErr_Format(PyExc_TypeError,
                             "Images2Neibs: in mode wrap_centered"
                             " need patch with odd shapes");
                %(fail)s;
            }
            if ( (PyArray_DIMS(%(ten4)s))[2] < c ||
                 (PyArray_DIMS(%(ten4)s))[3] < d)
            {
                PyErr_Format(PyExc_TypeError,
                    "Images2Neibs: in wrap_centered mode, don't support image"
                    " shapes smaller then the patch shapes:"
                    " neib_shape=(%%ld,%%ld), ten4[2:]=[%%ld,%%ld]",
                    (long int)c, (long int)d,
                    (long int)(PyArray_DIMS(%(ten4)s)[2]),
                    (long int)(PyArray_DIMS(%(ten4)s)[3]));
                %(fail)s;
            }
            grid_c = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[2]),step_x);
            grid_d = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[3]),step_y);

        } else if (%(mode)s == MODE_VALID) {
            if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
                 ( (((PyArray_DIMS(%(ten4)s))[2]-c) %% step_x)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[0]=%%ld, neib_step[0]=%%ld and"
                             " ten4.shape[2]=%%ld not consistent",
                             (long int)c, (long int)step_x,
                             (long int)(PyArray_DIMS(%(ten4)s)[2]));
                %(fail)s;
            }
            if ( ((PyArray_DIMS(%(ten4)s))[3] < d) ||
                 ( (((PyArray_DIMS(%(ten4)s))[3]-d) %% step_y)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[1]=%%ld, neib_step[1]=%%ld and"
                             " ten4.shape[3]=%%ld not consistent",
                             (long int)d, (long int)step_y,
                             (long int)(PyArray_DIMS(%(ten4)s)[3]));
                %(fail)s;
            }
            //number of patch in height
            grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
            //number of patch in width
            grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
        } else if (%(mode)s == MODE_IGNORE_BORDERS) {
            //number of patch in height
            grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
            //number of patch in width
            grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
        } else if (%(mode)s == MODE_HALF) {
            if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
                 ( (((PyArray_DIMS(%(ten4)s))[2]-(c%%2)) %% step_x)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[0]=%%ld, neib_step[0]=%%ld and"
                             " ten4.shape[2]=%%ld not consistent",
                             (long int)c, (long int)step_x,
                             (long int)(PyArray_DIMS(%(ten4)s)[2]));
                %(fail)s;
            }
            if ( ((PyArray_DIMS(%(ten4)s))[3] < d) ||
                 ( (((PyArray_DIMS(%(ten4)s))[3]-(d%%2)) %% step_y)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[1]=%%ld, neib_step[1]=%%ld and"
                             " ten4.shape[3]=%%ld not consistent",
                             (long int)d, (long int)step_y,
                             (long int)(PyArray_DIMS(%(ten4)s)[3]));
                %(fail)s;
            }
            //number of patch in height
            grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-(c%%2))/step_x);
            //number of patch in width
            grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-(d%%2))/step_y);
        } else if (%(mode)s == MODE_FULL) {
            if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
                 ( (((PyArray_DIMS(%(ten4)s))[2]+c-2) %% step_x)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[0]=%%ld, neib_step[0]=%%ld and"
                             " ten4.shape[2]=%%ld not consistent",
                             (long int)c, (long int)step_x,
                             (long int)(PyArray_DIMS(%(ten4)s)[2]));
                %(fail)s;
            }
            if ( ((PyArray_DIMS(%(ten4)s))[3] < d) ||
                 ( (((PyArray_DIMS(%(ten4)s))[3]+d-2) %% step_y)!=0))
            {
                PyErr_Format(PyExc_TypeError,
                             "neib_shape[1]=%%ld, neib_step[1]=%%ld and"
                             " ten4.shape[3]=%%ld not consistent",
                             (long int)d, (long int)step_y,
                             (long int)(PyArray_DIMS(%(ten4)s)[3]));
                %(fail)s;
            }
            //number of patch in height
            grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]+c-2)/step_x);
            //number of patch in width
            grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]+d-2)/step_y);
        } else {
            PyErr_Format(PyExc_TypeError,
                         "Images2Neibs: unknow mode %%d", %(mode)s);
            %(fail)s;
        }

        // new dimensions for z
        const npy_intp z_dim1 = c * d;
        const npy_intp z_dim0 =  grid_c
                            * grid_d
                            * (PyArray_DIMS(%(ten4)s))[1]
                            * (PyArray_DIMS(%(ten4)s))[0];

        if ((NULL == %(z)s)
            || ((PyArray_DIMS(%(z)s))[0] != z_dim0 )
            || ((PyArray_DIMS(%(z)s))[1] != z_dim1 )
        )
        {
            Py_XDECREF(%(z)s);
            npy_intp dims[2];
            dims[0] = z_dim0;
            dims[1] = z_dim1;

            %(z)s = (PyArrayObject*) PyArray_EMPTY(2,
                dims,
                PyArray_TYPE((PyArrayObject*) py_%(ten4)s),
                0);

            if (!%(z)s)
            {
                PyErr_SetString(PyExc_MemoryError, "failed to alloc z output");
                %(fail)s;
            }
        }
        }

        { // NESTED SCOPE

        const int nb_batch = (PyArray_DIMS(%(ten4)s))[0];
        const int nb_stack = (PyArray_DIMS(%(ten4)s))[1];
        const int height = (PyArray_DIMS(%(ten4)s))[2];
        const int width = (PyArray_DIMS(%(ten4)s))[3];

        // (c,d) = neib_shape
        const npy_intp c = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 0);
        const npy_intp d = (npy_intp) *(dtype_%(neib_shape)s*) PyArray_GETPTR1(%(neib_shape)s, 1);
        // (step_x,step_y) = neib_step
        const npy_intp step_x = (npy_intp) *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 0);
        const npy_intp step_y = (npy_intp) *(dtype_%(neib_step)s*) PyArray_GETPTR1(%(neib_step)s, 1);

        const int wrap_centered_half_idx_shift_x = c/2;
        const int wrap_centered_half_idx_shift_y = d/2;
        // Oh this is messed up...
        for (int n = 0; n < nb_batch; n++)              // loop over batches
            for (int s = 0; s < nb_stack; s++)          // loop over stacks
                for (int a = 0; a < grid_c; a++)        // loop over the number of patch in height
                    for (int b = 0; b < grid_d; b++)    // loop over the number of patch in width
                    {
                        int z_row = b + grid_d*(a + grid_c*(s + nb_stack*n));
                        for (int i = 0; i < c; i++)     // loop over c
                        {
                            int ten4_2 = i + a * step_x;
                            if (%(mode)s == MODE_WRAP_CENTERED) {
                                ten4_2 -= wrap_centered_half_idx_shift_x;
                                if ( ten4_2 < 0 ) ten4_2 += height;
                                else if (ten4_2 >= height) ten4_2 -= height;
                            } else if (%(mode)s == MODE_HALF) {
                                ten4_2 -= wrap_centered_half_idx_shift_x;
                            } else if (%(mode)s == MODE_FULL) {
                                ten4_2 -= c - 1;
                            }
                            if (ten4_2 < 0 | ten4_2 >= height) {
                                dtype_%(z)s* curr_z = (dtype_%(z)s*) PyArray_GETPTR2(%(z)s, z_row, d * i);
                                memset(curr_z, 0, d*sizeof(*curr_z));
                            } else {
                                for (int j = 0; j < d; j++)  // loop over d
                                {
                                    int ten4_3 = j + b * step_y;
                                    if (%(mode)s == MODE_WRAP_CENTERED) {
                                        ten4_3 -= wrap_centered_half_idx_shift_y;
                                        if ( ten4_3 < 0 ) ten4_3 += width;
                                        else if (ten4_3 >= width) ten4_3 -= width;
                                    } else if (%(mode)s == MODE_HALF) {
                                        ten4_3 -= wrap_centered_half_idx_shift_y;
                                    } else if (%(mode)s == MODE_FULL) {
                                        ten4_3 -= d - 1;
                                    }
                                    int z_col = j + d * i;
                                    dtype_%(z)s* curr_z = (dtype_%(z)s*) PyArray_GETPTR2(%(z)s, z_row, z_col);
                                    if (ten4_3 < 0 | ten4_3 >= width) {
                                        *curr_z = 0;
                                    } else {
                                        *curr_z = *( (dtype_%(ten4)s*) PyArray_GETPTR4(%(ten4)s, n, s, ten4_2, ten4_3));
                                    }
                                }
                            }
                        }
                    }
        } // END NESTED SCOPE
        """ % dict(
            ten4=inp[0],
            neib_shape=inp[1],
            neib_step=inp[2],
            z=out[0],
            fail=sub["fail"],
            # The op parameter is the border mode; the C code compares it
            # with the MODE_* constants.
            mode=sub["params"],
        )
示例#6
0
class BaseCorrMM(gof.OpenMPOp):
    """
    Base class for `CorrMM`, `CorrMM_gradWeights` and
    `CorrMM_gradInputs`. Cannot be used directly.

    Every subclass must define the internal attribute ``_direction`` outside of
    ``__init__()`` (e.g. as a class attribute), because ``__init__()`` validates it.
    ``_direction`` must take one of the following values:

     - "forward" to correlate bottom with weights and store results in top.
     - "backprop weights" to do a valid convolution of bottom with top
       (swapping the first two dimensions) and store results in weights.
     - "backprop inputs" to do a full convolution of top with weights
       (swapping the first two dimensions) and store results in bottom.

    Parameters
    ----------
    border_mode : {'valid', 'full', 'half'}
        Additionally, the padding size could be directly specified by an integer
        or a pair of integers
    subsample
        Perform subsampling of the output (default: (1, 1)).
    filter_dilation
        Perform dilated correlation (default: (1,1))
    num_groups
        Perform grouped convolutions (default: 1)
    """
    check_broadcast = False
    # Constructor arguments that are stored as same-named attributes in __init__.
    __props__ = ('border_mode', 'subsample', 'filter_dilation', 'num_groups')

    # Must be overridden by every subclass with "forward", "backprop weights"
    # or "backprop inputs"; __init__ raises ValueError for any other value.
    _direction = None

    # Values handed to the generated C code, which reads them as
    # ``params->direction``, ``params->dH``, ... (see c_code_helper()).
    # NOTE(review): presumably each field is filled from the same-named
    # attribute/property of the op instance -- confirm against ParamsType.
    params_type = ParamsType(
        # Each entry is (C constant name, Python alias); the trailing comment
        # gives the integer value the generated C code compares against.
        direction=EnumList(
            ('DIRECTION_FORWARD', 'forward'),  # 0
            ('DIRECTION_BACKPROP_WEIGHTS', 'backprop weights'),  # 1
            ('DIRECTION_BACKPROP_INPUTS', 'backprop inputs')),  # 2
        dH=int64,
        dW=int64,
        dilH=int64,
        dilW=int64,
        padH=int64,
        padW=int64,
        num_groups=int64)

    def __init__(self,
                 border_mode="valid",
                 subsample=(1, 1),
                 filter_dilation=(1, 1),
                 num_groups=1,
                 openmp=None):
        """
        Validate and store the correlation settings.

        Parameters
        ----------
        border_mode : {'valid', 'full', 'half'}, int or pair of ints
            A non-negative integer is expanded to a pair; a pair is
            converted to a tuple of two Python ints.
        subsample
            Two-element sequence, stored as a tuple.
        filter_dilation
            Two-element sequence, stored as a tuple.
        num_groups
            Number of groups, must be at least 1.
        openmp
            Forwarded unchanged to the parent constructor.

        Raises
        ------
        ValueError
            If ``border_mode`` is not one of the accepted forms, if
            ``subsample`` or ``filter_dilation`` does not have two elements,
            if ``_direction`` is not a recognised direction, or if
            ``num_groups`` is smaller than 1.
        """
        super(BaseCorrMM, self).__init__(openmp=openmp)

        # A bare integer is shorthand for the same padding on both axes.
        if isinstance(border_mode, integer_types):
            if border_mode < 0:
                raise ValueError('invalid border_mode {}, which must be a '
                                 'non-negative integer'.format(border_mode))
            border_mode = (border_mode,) * 2

        # Explicit padding: exactly two non-negative values, coerced to int.
        if isinstance(border_mode, tuple):
            if len(border_mode) != 2 or any(p < 0 for p in border_mode):
                raise ValueError(
                    'invalid border_mode {}, which must be a '
                    'pair of non-negative integers'.format(border_mode))
            border_mode = tuple(int(p) for p in border_mode)

        is_padding_pair = (isinstance(border_mode, tuple) and
                           min(border_mode) >= 0)
        if not (is_padding_pair or border_mode in ('valid', 'full', 'half')):
            raise ValueError('invalid border_mode {}, which must be either '
                             '"valid", "full", "half", an integer or a pair of'
                             ' integers'.format(border_mode))
        self.border_mode = border_mode

        if len(subsample) != 2:
            raise ValueError("subsample must have two elements")
        if len(filter_dilation) != 2:
            raise ValueError("filter_dilation must have two elements")
        self.subsample = tuple(subsample)
        self.filter_dilation = tuple(filter_dilation)

        # Detect which BLAS vendor the configured link flags refer to.
        # With empty flags Theano will use a NumPy C implementation of
        # [sd]gemm_ instead, so no vendor is recorded.
        blas_flags = theano.config.blas.ldflags
        detected_blas = ''
        if blas_flags:
            for vendor in ('openblas', 'mkl'):
                if vendor in blas_flags:
                    detected_blas = vendor
                    break
        self.blas_type = detected_blas

        known_directions = ("forward", "backprop weights", "backprop inputs")
        if self._direction not in known_directions:
            raise ValueError("_direction must be one of 'forward', "
                             "'backprop weights', 'backprop inputs'")

        if num_groups < 1:
            raise ValueError("Number of groups should be greater than 0")
        self.num_groups = num_groups

    @property
    def pad(self):
        """
        Padding as a ``(vertical, horizontal)`` pair.

        Explicit padding tuples are returned unchanged.  The symbolic modes
        map to sentinel pairs: ``"half"`` -> ``(-1, -1)``,
        ``"full"`` -> ``(-2, -2)`` and ``"valid"`` -> ``(0, 0)``.  The
        negative sentinels are resolved to real padding sizes by the C code
        generated in ``c_code_helper()``.
        """
        mode = self.border_mode
        if isinstance(mode, tuple):
            return mode
        if mode == "half":
            return (-1, -1)
        if mode == "full":
            return (-2, -2)
        # Only "valid" is left once __init__ has validated border_mode.
        assert mode == "valid"
        return (0, 0)

    # The read-only properties below carry the same names as the fields of
    # ``params_type``.

    @property
    def direction(self):
        """Enum value for ``_direction``.

        Direction should be converted to real enum value,
        as it is compared to integer later in c_code_helper().
        """
        return self.params_type.enum_from_alias(self._direction)

    @property
    def dH(self):
        """First element of ``subsample``."""
        return self.subsample[0]

    @property
    def dW(self):
        """Second element of ``subsample``."""
        return self.subsample[1]

    @property
    def dilH(self):
        """First element of ``filter_dilation``."""
        return self.filter_dilation[0]

    @property
    def dilW(self):
        """Second element of ``filter_dilation``."""
        return self.filter_dilation[1]

    @property
    def padH(self):
        """First element of ``pad``."""
        return self.pad[0]

    @property
    def padW(self):
        """Second element of ``pad``."""
        return self.pad[1]

    def __str__(self):
        """Return ``ClassName{border_mode, subsample, filter_dilation, num_groups}``."""
        shown = (self.border_mode, self.subsample, self.filter_dilation,
                 self.num_groups)
        return '%s{%s}' % (self.__class__.__name__,
                           ', '.join(str(value) for value in shown))

    @staticmethod
    def as_common_dtype(in1, in2):
        """
        Upcast input variables if necessary.

        Both inputs are cast to the dtype that ``theano.scalar.upcast``
        reports for their two dtypes, and returned as a pair in the same
        order.
        """
        common = theano.scalar.upcast(in1.dtype, in2.dtype)
        return tuple(var.astype(common) for var in (in1, in2))

    def __setstate__(self, d):
        """
        Restore the instance state from the dictionary ``d``.

        A state that does not provide ``num_groups`` gets the default of 1.
        """
        vars(self).update(d)
        if hasattr(self, 'num_groups'):
            return
        self.num_groups = 1

    def c_support_code(self):
        """
        Return the BLAS header text, followed by the thread-control helper
        text for OpenBLAS or MKL when ``blas_type`` names one of them.
        """
        header = blas_headers.blas_header_text()
        if self.blas_type == 'openblas':
            return header + blas_headers.openblas_threads_text()
        if self.blas_type == 'mkl':
            return header + blas_headers.mkl_threads_text()
        return header

    def c_libraries(self):
        """
        Return what ``ldflags()`` reports with its default arguments
        (presumably the BLAS libraries to link against -- confirm against
        ``ldflags``).
        """
        libraries = ldflags()
        return libraries

    def c_compile_args(self):
        """
        Return the compiler arguments for the generated C code.

        The result starts with what ``ldflags(libs=False, flags=True)``
        reports, followed by the parent class's compile arguments.

        A fresh list is built on every call so that the object returned by
        ``ldflags`` is never extended in place.  If that helper hands out a
        shared (e.g. cached) list, an in-place ``+=`` would append the
        parent's arguments to it again on each call.
        """
        compile_args = list(ldflags(libs=False, flags=True))
        compile_args.extend(super(BaseCorrMM, self).c_compile_args())
        return compile_args

    def c_lib_dirs(self):
        """
        Return what ``ldflags(libs=False, libs_dir=True)`` reports
        (presumably the BLAS library directories -- confirm against
        ``ldflags``).
        """
        library_dirs = ldflags(libs=False, libs_dir=True)
        return library_dirs

    def c_header_dirs(self):
        """
        Return what ``ldflags(libs=False, include_dir=True)`` reports
        (presumably the BLAS include directories -- confirm against
        ``ldflags``).
        """
        include_dirs = ldflags(libs=False, include_dir=True)
        return include_dirs

    def c_headers(self):
        """Return ``<stdio.h>`` followed by the parent class's headers."""
        inherited = super(BaseCorrMM, self).c_headers()
        return ['<stdio.h>'] + list(inherited)

    def c_code_cache_version(self):
        """
        Return the cache key for the compiled C code: a local revision
        number, the ``openmp`` setting and the BLAS header version.
        """
        # raise this whenever modifying any of the support_code_files
        support_code_revision = 7
        return (support_code_revision, self.openmp, blas_header_version())

    def c_support_code_apply(self, node, nodename):
        """
        Return the C support code for ``node``.

        The template file ``c_code/corr_gemm.c`` (located next to this
        module) is read and filled in with substitutions that depend on the
        dtype of the node's first input, on whether OpenMP is enabled and on
        the detected BLAS vendor.

        :param node: Apply node; only the dtype of its first input is used,
            which must be ``float32`` or ``float64``.
        :param nodename: Unused.
        :returns: The template contents with all substitutions applied.
        """
        # REMEMBER TO RAISE c_code_cache_version when changing any of
        # these files
        dtype = str(node.__dict__['inputs'][0].dtype)
        assert dtype in ('float32', 'float64')
        if dtype == 'float32':
            sub = {
                'gemm': 'sgemm_',
                'float_type': 'npy_float',
                'float_typenum': 'NPY_FLOAT',
                'n_bytes': 4,
                'c_float_type': 'float',
            }
        else:
            sub = {
                'gemm': 'dgemm_',
                'float_type': 'npy_double',
                'float_typenum': 'NPY_DOUBLE',
                'n_bytes': 8,
                'c_float_type': 'double',
            }

        if self.openmp:
            sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
            sub['omp_get_max_threads'] = 'omp_get_max_threads()'
            sub['omp_get_thread_num'] = 'omp_get_thread_num()'

            if self.blas_type == 'openblas':
                sub['blas_set_num_threads'] = 'openblas_set_num_threads'
                sub['blas_get_num_threads'] = 'openblas_get_num_threads()'
            elif self.blas_type == 'mkl':
                sub['blas_set_num_threads'] = 'mkl_set_num_threads'
                sub['blas_get_num_threads'] = 'mkl_get_max_threads()'
            else:
                sub['blas_set_num_threads'] = ''
                sub['blas_get_num_threads'] = '0'
        else:
            # Without OpenMP the template gets single-thread placeholders.
            sub['omp_flags'] = ''
            sub['omp_get_max_threads'] = '1'
            sub['omp_get_thread_num'] = '0'
            sub['blas_set_num_threads'] = ''
            sub['blas_get_num_threads'] = '0'

        module_dir = os.path.split(__file__)[0]
        files = [os.path.join('c_code', 'corr_gemm.c')]
        codes = []
        for f in files:
            # Use a context manager so the file handle is closed promptly
            # instead of being left to the garbage collector.
            with open(os.path.join(module_dir, f)) as template:
                codes.append(template.read())
        return ''.join(codes) % sub

    def c_code_helper(self,
                      bottom,
                      weights,
                      top,
                      sub,
                      height=None,
                      width=None):
        """
        Build the C code that performs the correlation for this op.

        This generates the C code for CorrMM (direction="forward"),
        CorrMM_gradWeights (direction="backprop weights"), and
        CorrMM_gradInputs (direction="backprop inputs").
        Depending on the direction, one of bottom, weights, top will
        receive the output, while the other two serve as inputs.

        The generated code reads the op settings from ``params``, infers the
        kernel size and the output shape, allocates the output array when
        the existing one cannot be reused, and finally calls ``corrMM``.

        :param bottom: Variable name of the input images in the forward pass,
            or the gradient of the input images in backprop wrt. inputs
        :param weights: Variable name of the filters in the forward pass,
            or the gradient of the filters in backprop wrt. weights
        :param top: Variable name of the output images / feature maps in the
            forward pass, or the gradient of the outputs in the backprop passes
        :param sub: Dictionary of substitutions usable to help generating the
            C code. Its ``'fail'`` and ``'params'`` entries are used here.
        :param height: If self.subsample[0] != 1, a variable giving the height
            of the filters for direction="backprop weights" or the height of
            the input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the height of the
            filters for direction="backprop weights".  Ignored otherwise.
        :param width: If self.subsample[1] != 1, a variable giving the width
            of the filters for direction="backprop weights" or the width of the
            input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the width of the
            filters for direction="backprop weights".  Ignored otherwise.
        :returns: The generated C code as a string.
        :raises ValueError: If ``height`` (resp. ``width``) is not given for a
            backprop direction with vertical (resp. horizontal) subsampling,
            or for backprop wrt. weights with border_mode='half'.
        """

        # When subsampling, we cannot unambiguously infer the height and width
        # of bottom and weights from top, so we require them to be given.
        # Similarly, when border_mode="half", we cannot infer the weight size.
        if height:
            # Turn the variable name into a C expression that reads the first
            # element of that array as an npy_int64.
            height = '(*(npy_int64 *)(PyArray_DATA(%s)))' % height
        else:
            # direction: 0 = forward, 1 = backprop weights, 2 = backprop
            # inputs; padH == -1 is the sentinel for border_mode="half".
            if ((self.direction != 0) and
                (self.dH != 1)) or ((self.direction == 1) and
                                    (self.padH == -1)):
                raise ValueError(
                    "height must be given for backprop with vertical sampling or border_mode='half'"
                )
            # '-1' tells the generated C code that no height was supplied.
            height = '-1'
        if width:
            # Same conversion as for height.
            width = '(*(npy_int64 *)(PyArray_DATA(%s)))' % width
        else:
            if ((self.direction != 0) and
                (self.dW != 1)) or ((self.direction == 1) and
                                    (self.padW == -1)):
                raise ValueError(
                    "width must be given for backprop with horizontal sampling or border_mode='half'"
                )
            # '-1' tells the generated C code that no width was supplied.
            width = '-1'

        # The template below is filled with %-formatting, hence the doubled
        # '%%' wherever a literal percent sign must reach the C code.
        return """
    // Mandatory args
    int direction = %(params)s->direction;  // forward, bprop weights, bprop inputs

    // Optional args
    int dH = %(params)s->dH;
    int dW = %(params)s->dW;
    int dilH = %(params)s->dilH;
    int dilW = %(params)s->dilW;
    int padH = %(params)s->padH;
    int padW = %(params)s->padW;
    int numgroups = %(params)s->num_groups;

    PyArrayObject * bottom = %(bottom)s;
    PyArrayObject * weights = %(weights)s;
    PyArrayObject * top = %(top)s;
    PyArrayObject * out2 = NULL;
    PyArrayObject **out = NULL;

    switch(%(params)s->direction) {
        case DIRECTION_FORWARD:
            out = &%(top)s;
            break;
        case DIRECTION_BACKPROP_WEIGHTS:
            out = &%(weights)s;
            break;
        case DIRECTION_BACKPROP_INPUTS:
            out = &%(bottom)s;
            break;
        default:
            PyErr_SetString(PyExc_ValueError, "CPU CorrMM: Invalid direction.");
            {%(fail)s}
            break;
    }

    // Obtain or infer kernel width and height
    // (we need to know it early to be able to handle auto-padding)
    int kH, kW, dil_kH, dil_kW;
    if (direction != 1) {
        // weight is an input variable, we can just read its shape
        kH = PyArray_DIMS(weights)[2];
        kW = PyArray_DIMS(weights)[3];
    }
    else {
        if (%(height)s != -1) {
            // kernel height is specified (perhaps vertical subsampling or half padding)
            kH = %(height)s;
        }
        else if (padH == -2) {
            // vertical full padding, we can infer the kernel height
            kH = (2 - PyArray_DIMS(bottom)[2] + (PyArray_DIMS(top)[2] - 1) * dH - 1)/ dilH + 1;
        }
        else {
            // explicit padding, we can infer the kernel height
            kH = (PyArray_DIMS(bottom)[2] + 2*padH - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
        }
        if (%(width)s != -1) {
            // kernel width is specified (perhaps horizontal subsampling or half padding)
            kW = %(width)s;
        }
        else if (padW == -2) {
            kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
        else {
            kW = (PyArray_DIMS(bottom)[3] + 2*padW - (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
    }

    // Implicit dilated kernel size
    dil_kH = (kH - 1) * dilH + 1;
    dil_kW = (kW - 1) * dilW + 1;

    // Auto-padding if requested
    if (padH == -1) {  // vertical half padding
        padH = dil_kH / 2;
    }
    else if (padH == -2) {  // vertical full padding
        padH = dil_kH - 1;
    }
    else if (padH < 0) {
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padH must be >= -2");
        %(fail)s
    }
    if (padW == -1) {  // horizontal half padding
        padW = dil_kW / 2;
    }
    else if (padW == -2) {  // horizontal full padding
        padW = dil_kW - 1;
    }
    else if (padW < 0) {
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: padW must be >= -2");
        %(fail)s
    }

    // Infer output shape
    npy_intp out_dim[4];
    switch(direction) {
    case 0:  // forward pass
        // output is top: (batchsize, num_filters, height, width)
        // height and width: top = (bottom + 2*pad - ((weight-1)*dil + 1)) / sample + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
        out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + 2*padH - ((PyArray_DIMS(weights)[2]-1)*dilH + 1)) / dH + 1);
        out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + 2*padW - ((PyArray_DIMS(weights)[3]-1)*dilW + 1)) / dW + 1);
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "CorrMM: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                         (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                         (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                         (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3]);
            %(fail)s
        }
        break;
    case 1:  // backprop wrt. weights
        // output is weights: (num_filters, num_channels, height, width)
        // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
        out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
        out_dim[2] = (npy_intp)kH;  // already inferred further above
        out_dim[3] = (npy_intp)kW;  // how convenient
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "CorrMM backprop wrt. weights: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                         (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3],
                         (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                         (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
            %(fail)s
        }
        break;
    case 2:  // backprop wrt. inputs
        // output is bottom: (batchsize, num_channels, height, width)
        // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
        out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
        out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "CorrMM backprop wrt. inputs: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3],
                         (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                         (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                         (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                         (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
            %(fail)s
        }
        break;
    default:
        PyErr_SetString(PyExc_ValueError, "BaseCorrMM: direction must be 0, 1, or 2\\n");
        %(fail)s
    }

    // Prepare output array
    int typenum;
    if ( !(*out
           && PyArray_NDIM(*out)==4
           && PyArray_IS_C_CONTIGUOUS(*out)
           && PyArray_DIMS(*out)[0]==out_dim[0]
           && PyArray_DIMS(*out)[1]==out_dim[1]
           && PyArray_DIMS(*out)[2]==out_dim[2]
           && PyArray_DIMS(*out)[3]==out_dim[3]))
    {
        Py_XDECREF(*out);
        if (direction != 1) {
          typenum = PyArray_TYPE(weights);
        }
        else {
          typenum = PyArray_TYPE(bottom);
        }
        //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
        *out = (PyArrayObject*)PyArray_ZEROS(4,
                                          out_dim,
                                          typenum,
                                          0);
        if (NULL == *out)
        {
            PyErr_Format(PyExc_RuntimeError,
                    "BaseCorrMM: Failed to allocate output of %%lld x %%lld x %%lld x %%lld",
                    (long long)out_dim[0], (long long)out_dim[1], (long long)out_dim[2], (long long)out_dim[3]);
            %(fail)s
        }
    }

    // Call corrMM code
    out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, dilH, dilW, padH, padW, numgroups );
    if (out2==NULL){
       %(fail)s
    }
    assert (out2 == *out);

""" % dict(bottom=bottom,
           weights=weights,
           top=top,
           height=height,
           width=width,
           fail=sub['fail'],
           params=sub['params'])