Exemple #1
0
    def test_enum_class(self):
        """Check ``EnumList`` / ``EnumType`` validation, equality and aliases.

        Covers, in order: rejection of the constant names this test treats
        as invalid, rejection of a string constant value, equality and
        hashing of two ``EnumType`` objects built from equal-comparing
        values, attribute access, alias lookup through ``filter`` and
        ``fromalias``, and rejection of an alias equal to a constant name.

        Raises
        ------
        Exception
            If a construction that is expected to fail succeeds.
        AssertionError
            If an equality, hash or lookup check does not hold.
        """

        def expect_failure(exc_type, message, build):
            # Only `exc_type` counts as the expected rejection; any other
            # exception propagates and fails the test.
            try:
                build()
            except exc_type:
                return
            raise Exception(message)

        # Check that invalid enum name raises exception.
        for invalid_name in ("a", "_A", "0"):
            expect_failure(
                AttributeError,
                "EnumList with invalid name should fail.",
                lambda: EnumList(invalid_name),
            )
            expect_failure(
                AttributeError,
                "EnumType with invalid name should fail.",
                lambda: EnumType(**{invalid_name: 0}),
            )

        # Check that invalid enum value raises exception.
        expect_failure(
            TypeError,
            "EnumType with invalid value should fail.",
            lambda: EnumType(INVALID_VALUE="string is not allowed."),
        )

        # Check EnumType: the two objects use values that compare equal in
        # Python (True == 1, 0.0 == 0, ...), so they must compare and hash
        # the same.
        e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
        e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
        assert e1 == e2
        assert not (e1 != e2)
        assert hash(e1) == hash(e2)
        # Check access to attributes (building the tuple is the actual check).
        assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7

        # Check enum with aliases: a different alias makes the enums differ.
        e1 = EnumType(A=("alpha", 0), B=("beta", 1), C=2)
        e2 = EnumType(A=("alpha", 0), B=("beta", 1), C=2)
        e3 = EnumType(A=("a", 0), B=("beta", 1), C=2)
        assert e1 == e2
        assert e1 != e3
        assert e1.filter("beta") == e1.fromalias("beta") == e1.B == 1
        # "C" has no alias, so the constant name itself is looked up.
        assert e1.filter("C") == e1.fromalias("C") == e1.C == 2

        # Check that invalid alias (same as a constant) raises exception...
        expect_failure(
            TypeError,
            "Enum with an alias name equal to a constant name should fail.",
            lambda: EnumList(("A", "a"), ("B", "B")),
        )
        # ... while the same list with a non-clashing alias is accepted.
        EnumList(("A", "a"), ("B", "b"))
    def test_enum_class(self):
        """Check ``EnumList`` / ``EnumType`` validation, equality and hashing.

        Covers, in order: rejection of the constant names this test treats
        as invalid, rejection of a string constant value, equality and
        hashing of two ``EnumType`` objects built from equal-comparing
        values, and plain attribute access to the constants.

        Raises
        ------
        Exception
            If a construction that is expected to fail succeeds.
        AssertionError
            If an equality or hash check does not hold.
        """

        def expect_failure(exc_type, message, build):
            # Only `exc_type` counts as the expected rejection; any other
            # exception propagates and fails the test.
            try:
                build()
            except exc_type:
                return
            raise Exception(message)

        # Check that invalid enum name raises exception.
        for invalid_name in ('a', '_A', '0'):
            expect_failure(AttributeError,
                           'EnumList with invalid name should fail.',
                           lambda: EnumList(invalid_name))
            expect_failure(AttributeError,
                           'EnumType with invalid name should fail.',
                           lambda: EnumType(**{invalid_name: 0}))

        # Check that invalid enum value raises exception.
        expect_failure(
            ValueError,
            'EnumType with invalid value should fail.',
            lambda: EnumType(INVALID_VALUE='string is not allowed.'))

        # Check EnumType: the two objects use values that compare equal in
        # Python (True == 1, 0.0 == 0, ...), so they must compare and hash
        # the same.
        e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
        e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
        assert e1 == e2
        assert not (e1 != e2)
        assert hash(e1) == hash(e2)
        # Check access to attributes (building the tuple is the actual check).
        assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7
    def test_params_type_with_enums(self):
        """Check how ``ParamsType`` combines several ``EnumList`` fields.

        Two constructions must be rejected with ``AttributeError``: enum
        fields sharing a constant name, and enum fields sharing an alias.
        A valid wrapper must then expose every constant as an attribute,
        resolve aliases through ``enum_from_alias`` and keep its regular
        attributes (``fields``, ``types``, ``length``, ``name``).

        Raises
        ------
        Exception
            If a construction that is expected to fail succeeds.
        AssertionError
            If one of the attribute or alias checks does not hold.
        """

        def expect_attribute_error(message, build, on_rejection=None):
            # `on_rejection` runs inside the handler, i.e. only once the
            # expected AttributeError has actually been raised.
            try:
                build()
            except AttributeError:
                if on_rejection is not None:
                    on_rejection()
            else:
                raise Exception(message)

        # Common enum names inside different enum types must be rejected.
        expect_attribute_error(
            "ParamsType should fail with common enum names inside different enum types.",
            lambda: ParamsType(
                enum1=EnumList("A", "B", "C"),
                enum2=EnumList("A", "B", "F"),
            ),
        )

        # The alias "a" shared by both enums must be rejected; the same
        # layout with distinct aliases must then be accepted.
        expect_attribute_error(
            "ParamsType should fail when there are aliases with same names as some constants.",
            lambda: ParamsType(
                enum1=EnumList(("A", "a"), ("B", "b")),
                enum2=EnumList(("ONE", "a"), ("TWO", "two")),
            ),
            on_rejection=lambda: ParamsType(
                enum1=EnumList(("A", "a"), ("B", "b")),
                enum2=EnumList(("ONE", "one"), ("TWO", "two")),
            ),
        )

        # Enum values are reachable directly through the wrapper.
        wrapper = ParamsType(
            enum1=EnumList("A", ("B", "beta"), "C"),
            enum2=EnumList(("D", "delta"), "E", "F"),
        )
        assert wrapper.A == 0
        assert wrapper.B == 1
        assert wrapper.C == 2
        assert wrapper.D == 0
        assert wrapper.E == 1
        assert wrapper.F == 2

        # Constants are reachable through their aliases.
        assert wrapper.enum_from_alias("beta") == wrapper.B
        assert wrapper.enum_from_alias("delta") == wrapper.D
        # C is not an alias, so the constant named C is returned.
        assert wrapper.enum_from_alias("C") == wrapper.C

        # Regular wrapper attributes are still available.
        assert len(wrapper.fields) == len(wrapper.types) == wrapper.length
        assert wrapper.name
Exemple #4
0
class MyOpEnumList(Op):
    """Op applying one of four arithmetic operations to two scalars.

    The operation is selected by an alias ("+", "-", "*" or "/") of the
    ``EnumList`` used as ``params_type``. ``get_params`` hands the stored
    alias over as the Op parameter, which ``perform`` and the generated C
    code compare against the enum constants.
    """

    __props__ = ("op_chosen",)
    params_type = EnumList(
        ("ADD", "+"),
        ("SUB", "-"),
        ("MULTIPLY", "*"),
        ("DIVIDE", "/"),
        ctype="unsigned long long",
    )

    def __init__(self, choose_op):
        """Store ``choose_op`` after sanity-checking the enum layout.

        ``choose_op`` must be one of the aliases declared in
        ``params_type``; this is enforced with an ``assert``.
        """
        enum = self.params_type
        # Constants are expected to be numbered in declaration order.
        expected_values = (("ADD", 0), ("SUB", 1), ("MULTIPLY", 2), ("DIVIDE", 3))
        for constant_name, value in expected_values:
            assert getattr(enum, constant_name) == value
        # Every alias must resolve to the constant it was declared with.
        expected_aliases = (
            ("+", "ADD"),
            ("-", "SUB"),
            ("*", "MULTIPLY"),
            ("/", "DIVIDE"),
        )
        for alias, constant_name in expected_aliases:
            assert enum.fromalias(alias) == getattr(enum, constant_name)
        assert enum.has_alias(choose_op)
        self.op_chosen = choose_op

    def get_params(self, node):
        """Return the stored operation alias as the Op parameter."""
        return self.op_chosen

    def make_node(self, a, b):
        """Build the Apply node: two scalar inputs, one float64 scalar output."""
        scalar_inputs = [scalar.as_scalar(a), scalar.as_scalar(b)]
        return Apply(self, scalar_inputs, [scalar.float64()])

    def perform(self, node, inputs, outputs, op):
        """Python implementation: write the result into ``outputs[0][0]``.

        Raises ``NotImplementedError`` when ``op`` matches none of the four
        enum constants.
        """
        lhs, rhs = inputs
        (cell,) = outputs
        enum = self.params_type
        if op == enum.ADD:
            cell[0] = lhs + rhs
        elif op == enum.SUB:
            cell[0] = lhs - rhs
        elif op == enum.MULTIPLY:
            cell[0] = lhs * rhs
        elif op == enum.DIVIDE:
            # True division as soon as one operand has a continuous dtype,
            # floor division otherwise.
            operand_dtypes = (lhs.dtype, rhs.dtype)
            has_continuous = False
            for operand_dtype in operand_dtypes:
                if operand_dtype in theano.tensor.continuous_dtypes:
                    has_continuous = True
                    break
            if has_continuous:
                cell[0] = lhs / rhs
            else:
                cell[0] = lhs // rhs
        else:
            raise NotImplementedError("Unknown op id " + str(op))
        # The declared output is float64, so convert whatever was computed.
        cell[0] = np.float64(cell[0])

    def c_code_cache_version(self):
        """Return the version tuple of the generated C code."""
        return (1,)

    def c_code(self, node, name, inputs, outputs, sub):
        """Return the C ``switch`` dispatching on the enum parameter."""
        template = """
        switch(%(op)s) {
            case ADD:
                %(o)s = %(a)s + %(b)s;
                break;
            case SUB:
                %(o)s = %(a)s - %(b)s;
                break;
            case MULTIPLY:
                %(o)s = %(a)s * %(b)s;
                break;
            case DIVIDE:
                %(o)s = %(a)s / %(b)s;
                break;
            default:
                {%(fail)s}
                break;
        }
        """
        substitutions = {
            "op": sub["params"],
            "o": outputs[0],
            "a": inputs[0],
            "b": inputs[1],
            "fail": sub["fail"],
        }
        return template % substitutions
Exemple #5
0
class BaseCorr3dMM(gof.OpenMPOp):
    """
    Base class for `Corr3dMM`, `Corr3dMM_gradWeights` and
    `Corr3dMM_gradInputs`. Cannot be used directly.

    Every sub-class must define internal attribute ``_direction`` out of __init__().
    ``_direction`` must take one of following values:

     - "forward" to correlate bottom with weights and store results in top.
     - "backprop weights" to do a valid convolution of bottom with top
       (swapping the first two dimensions) and store results in weights.
     - "backprop inputs" to do a full convolution of top with weights
       (swapping the first two dimensions) and store results in bottom.

    Parameters
    ----------
    border_mode : {'valid', 'full', 'half'}
        Additionally, the padding size could be directly specified by an integer
        or a tuple of three of integers
    subsample
        Perform subsampling of the output (default: (1, 1, 1)).
    filter_dilation
        Perform dilated correlation (default: (1, 1, 1))
    openmp
        Forwarded unchanged to the base-class constructor (default: None).
    num_groups
        Perform grouped convolutions (default: 1)
    """

    check_broadcast = False
    __props__ = ("border_mode", "subsample", "filter_dilation", "num_groups")

    # Must be overridden by sub-classes; validated at the end of __init__().
    _direction = None

    # Every field declared here is read from the attribute/property of the
    # same name defined further down in this class.
    params_type = ParamsType(
        direction=EnumList(
            ("DIRECTION_FORWARD", "forward"),  # 0
            ("DIRECTION_BACKPROP_WEIGHTS", "backprop weights"),  # 1
            ("DIRECTION_BACKPROP_INPUTS", "backprop inputs"),  # 2
        ),
        dH=int64,
        dW=int64,
        dD=int64,
        dilH=int64,
        dilW=int64,
        dilD=int64,
        padH=int64,
        padW=int64,
        padD=int64,
        num_groups=int64,
    )

    def __init__(
            self,
            border_mode="valid",
            subsample=(1, 1, 1),
            filter_dilation=(1, 1, 1),
            openmp=None,
            num_groups=1,
    ):
        """
        Validate and store the correlation settings.

        An integer ``border_mode`` is expanded to a tuple of three identical
        paddings; a tuple is converted element-wise to ``int``.

        Raises
        ------
        ValueError
            If ``border_mode`` is negative, is a tuple that is not made of
            three non-negative values, or is not one of the accepted
            strings; if ``subsample`` or ``filter_dilation`` does not have
            three elements; if ``num_groups`` is smaller than 1; or if the
            sub-class did not set ``_direction`` to a supported value.
        """
        super().__init__(openmp=openmp)
        if isinstance(border_mode, int):
            if border_mode < 0:
                raise ValueError("invalid border_mode {}, which must be a "
                                 "non-negative integer".format(border_mode))
            border_mode = (border_mode, border_mode, border_mode)
        if isinstance(border_mode, tuple):
            if len(border_mode) != 3 or min(border_mode) < 0:
                raise ValueError(
                    "invalid border_mode {}, which must be a tuple of "
                    "three non-negative integers".format(border_mode))
            pad_h, pad_w, pad_d = map(int, border_mode)
            border_mode = (pad_h, pad_w, pad_d)
        if not ((isinstance(border_mode, tuple) and min(border_mode) >= 0)
                or border_mode in ("valid", "full", "half")):
            raise ValueError(
                "invalid border_mode {}, which must be either "
                '"valid", "full", "half", an integer or a tuple of three'
                " integers".format(border_mode))
        self.border_mode = border_mode
        if len(subsample) != 3:
            raise ValueError("subsample must have three elements")
        if len(filter_dilation) != 3:
            raise ValueError("filter_dilation must have three elements")
        self.subsample = tuple(subsample)
        self.filter_dilation = tuple(filter_dilation)
        if num_groups < 1:
            raise ValueError("Number of groups should be greater than 0")
        self.num_groups = num_groups

        # Remember which BLAS flavour the link flags mention; it selects the
        # thread-control helpers emitted by c_support_code() and
        # c_support_code_apply().
        if not theano.config.blas.ldflags:
            # Theano will use a NumPy C implementation of [sd]gemm_ instead.
            self.blas_type = ""
        else:
            if "openblas" in theano.config.blas.ldflags:
                self.blas_type = "openblas"
            elif "mkl" in theano.config.blas.ldflags:
                self.blas_type = "mkl"
            else:
                self.blas_type = ""

        if self._direction not in [
                "forward", "backprop weights", "backprop inputs"
        ]:
            raise ValueError("_direction must be one of 'forward', "
                             "'backprop weights', 'backprop inputs'")

    @property
    def pad(self):
        """
        Padding triple passed to the C code.

        "half" and "full" are encoded as the sentinels -1 and -2; the
        generated C code replaces them with the real padding once the
        kernel size is known. A tuple ``border_mode`` is returned as is and
        "valid" maps to no padding.
        """
        if self.border_mode == "half":
            return (-1, -1, -1)
        elif self.border_mode == "full":
            return (-2, -2, -2)
        elif isinstance(self.border_mode, tuple):
            return self.border_mode
        else:
            assert self.border_mode == "valid"
            return (0, 0, 0)

    # Direction should be converted to real enum value,
    # as it is compared to integer later in c_code_helper().
    direction = property(
        lambda self: self.params_type.enum_from_alias(self._direction))

    # Per-axis views of `subsample`, `filter_dilation` and `pad`, named after
    # the matching `params_type` fields.
    dH = property(lambda self: self.subsample[0])
    dW = property(lambda self: self.subsample[1])
    dD = property(lambda self: self.subsample[2])

    dilH = property(lambda self: self.filter_dilation[0])
    dilW = property(lambda self: self.filter_dilation[1])
    dilD = property(lambda self: self.filter_dilation[2])

    padH = property(lambda self: self.pad[0])
    padW = property(lambda self: self.pad[1])
    padD = property(lambda self: self.pad[2])

    def __str__(self):
        """Return ``ClassName{border_mode, subsample, filter_dilation, num_groups}``."""
        return "{}{{{}, {}, {}, {}}}".format(
            self.__class__.__name__,
            self.border_mode,
            str(self.subsample),
            str(self.filter_dilation),
            str(self.num_groups),
        )

    @staticmethod
    def as_common_dtype(in1, in2):
        """
        Upcast input variables if necessary.
        """
        dtype = theano.scalar.upcast(in1.dtype, in2.dtype)
        return in1.astype(dtype), in2.astype(dtype)

    def __setstate__(self, d):
        """Restore the state ``d``, defaulting a missing ``num_groups`` to 1."""
        self.__dict__.update(d)
        if not hasattr(self, "num_groups"):
            self.num_groups = 1

    def c_support_code(self):
        """Return the BLAS header text plus thread helpers for the detected BLAS."""
        ccodes = blas_headers.blas_header_text()
        if self.blas_type == "openblas":
            ccodes += blas_headers.openblas_threads_text()
        elif self.blas_type == "mkl":
            ccodes += blas_headers.mkl_threads_text()
        return ccodes

    def c_libraries(self):
        """Return the libraries reported by ``ldflags()``."""
        return ldflags()

    def c_compile_args(self):
        """Return the ``ldflags`` compile flags followed by the base-class ones."""
        compile_args = ldflags(libs=False, flags=True)
        compile_args += super().c_compile_args()
        return compile_args

    def c_lib_dirs(self):
        """Return the library directories reported by ``ldflags``."""
        return ldflags(libs=False, libs_dir=True)

    def c_header_dirs(self):
        """Return the include directories reported by ``ldflags``."""
        return ldflags(libs=False, include_dir=True)

    def c_headers(self):
        """Return ``<stdio.h>`` followed by the base-class headers."""
        headers = ["<stdio.h>"]
        headers += super().c_headers()
        return headers

    def c_code_cache_version(self):
        """Return the cache key of the generated C code."""
        # raise this whenever modifying any of the support_code_files
        return (8, self.openmp, blas_header_version())

    def c_support_code_apply(self, node, nodename):
        """
        Return the contents of ``c_code/corr3d_gemm.c`` specialised for ``node``.

        The file is used as a ``%``-format template. The substitutions
        depend on the dtype of the first input of ``node`` (float32 or
        float64), on ``self.openmp`` and on the detected BLAS flavour.
        """
        # REMEMBER TO RAISE c_code_cache_version when changing any of
        # these files
        sub = {}
        dtype = str(node.__dict__["inputs"][0].dtype)
        assert dtype in ("float32", "float64")
        if dtype == "float32":
            sub["gemm"] = "sgemm_"
            sub["float_type"] = "npy_float"
            sub["float_typenum"] = "NPY_FLOAT"
            sub["n_bytes"] = 4
            sub["c_float_type"] = "float"
        else:
            sub["gemm"] = "dgemm_"
            sub["float_type"] = "npy_double"
            sub["float_typenum"] = "NPY_DOUBLE"
            sub["n_bytes"] = 8
            sub["c_float_type"] = "double"

        if self.openmp:
            sub["omp_flags"] = "#pragma omp parallel for schedule(static)"
            sub["omp_get_max_threads"] = "omp_get_max_threads()"
            sub["omp_get_thread_num"] = "omp_get_thread_num()"

            if self.blas_type == "openblas":
                sub["blas_set_num_threads"] = "openblas_set_num_threads"
                sub["blas_get_num_threads"] = "openblas_get_num_threads()"
            elif self.blas_type == "mkl":
                sub["blas_set_num_threads"] = "mkl_set_num_threads"
                sub["blas_get_num_threads"] = "mkl_get_max_threads()"
            else:
                sub["blas_set_num_threads"] = ""
                sub["blas_get_num_threads"] = "0"
        else:
            sub["omp_flags"] = ""
            sub["omp_get_max_threads"] = "1"
            sub["omp_get_thread_num"] = "0"
            sub["blas_set_num_threads"] = ""
            sub["blas_get_num_threads"] = "0"

        files = [os.path.join("c_code", "corr3d_gemm.c")]
        base_dir = os.path.split(__file__)[0]
        codes = []
        for f in files:
            # Use a context manager so the file handle is closed
            # deterministically instead of being left to the garbage collector.
            with open(os.path.join(base_dir, f)) as source_file:
                codes.append(source_file.read())
        final_code = "".join(codes)
        return final_code % sub

    def c_code_helper(self,
                      bottom,
                      weights,
                      top,
                      sub,
                      height=None,
                      width=None,
                      depth=None):
        """
        This generates the C code for Corr3dMM (direction="forward"),
        Corr3dMM_gradWeights (direction="backprop weights"), and
        Corr3dMM_gradInputs (direction="backprop inputs").
        Depending on the direction, one of bottom, weights, top will
        receive the output, while the other two serve as inputs.

        :param bottom: Variable name of the input images in the forward pass,
            or the gradient of the input images in backprop wrt. inputs
        :param weights: Variable name of the filters in the forward pass,
            or the gradient of the filters in backprop wrt. weights
        :param top: Variable name of the output images / feature maps in the
            forward pass, or the gradient of the outputs in the backprop passes
        :param sub: Dictionary of substitutions useable to help generating the
            C code.
        :param height: If self.subsample[0] != 1, a variable giving the height
            of the filters for direction="backprop weights" or the height of
            the input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the height of the
            filters for direction="backprop weights".  Ignored otherwise.
        :param width: If self.subsample[1] != 1, a variable giving the width
            of the filters for direction="backprop weights" or the width of the
            input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the width of the
            filters for direction="backprop weights".  Ignored otherwise.
        :param depth: If self.subsample[2] != 1, a variable giving the depth
            of the filters for direction="backprop weights" or the depth of the
            input images for direction="backprop inputs".

            If self.border_mode == 'half', a variable giving the depth of the
            filters for direction="backprop weights".  Ignored otherwise.
        :returns: The C source as a string, with all names substituted.
        :raises ValueError: If height, width or depth is required (see above)
            but was not given.
        """

        # When subsampling, we cannot unambiguously infer the height, width
        # and depth of bottom and weights from top, so we require them to be
        # given. Similarly, when border_mode="half", we cannot infer the
        # weight size.
        # A given size is read in C from the data of the named array; a
        # missing one is passed down as the literal "-1".
        if height:
            height = f"(*(npy_int64 *)(PyArray_DATA({height})))"
        else:
            if ((self.direction != 0) and
                (self.dH != 1)) or ((self.direction == 1) and
                                    (self.padH == -1)):
                raise ValueError(
                    "height must be given for backprop with vertical sampling or border_mode='half'"
                )
            height = "-1"
        if width:
            width = f"(*(npy_int64 *)(PyArray_DATA({width})))"
        else:
            if ((self.direction != 0) and
                (self.dW != 1)) or ((self.direction == 1) and
                                    (self.padW == -1)):
                raise ValueError(
                    "width must be given for backprop with horizontal sampling or border_mode='half'"
                )
            width = "-1"
        if depth:
            depth = f"(*(npy_int64 *)(PyArray_DATA({depth})))"
        else:
            if ((self.direction != 0) and
                (self.dD != 1)) or ((self.direction == 1) and
                                    (self.padD == -1)):
                raise ValueError(
                    "depth must be given for backprop with depth sampling or border_mode='half'"
                )
            depth = "-1"

        # NOTE(review): the "Prepare output array" test below compares
        # PyArray_NDIM(*out) with 4 although five dimensions are compared and
        # allocated, so it looks like an existing 5-d output is never reused.
        # Left untouched here; confirm against corr3d_gemm.c and raise
        # c_code_cache_version() if it gets changed.
        return """
    // Mandatory args
    int direction = %(params)s->direction;  // forward, bprop weights, bprop inputs

    // Optional args
    int dH = %(params)s->dH;
    int dW = %(params)s->dW;
    int dD = %(params)s->dD;
    int dilH = %(params)s->dilH;
    int dilW = %(params)s->dilW;
    int dilD = %(params)s->dilD;
    int padH = %(params)s->padH;
    int padW = %(params)s->padW;
    int padD = %(params)s->padD;
    int numgroups = %(params)s->num_groups;

    PyArrayObject * bottom = %(bottom)s;
    PyArrayObject * weights = %(weights)s;
    PyArrayObject * top = %(top)s;
    PyArrayObject * out2 = NULL;
    PyArrayObject **out = NULL;

    switch(%(params)s->direction) {
        case DIRECTION_FORWARD:
            out = &%(top)s;
            break;
        case DIRECTION_BACKPROP_WEIGHTS:
            out = &%(weights)s;
            break;
        case DIRECTION_BACKPROP_INPUTS:
            out = &%(bottom)s;
            break;
        default:
            PyErr_SetString(PyExc_ValueError, "CPU Corr3dMM: Invalid direction.");
            {%(fail)s}
            break;
    }

    // Obtain or infer kernel width, height and depth
    // (we need to know it early to be able to handle auto-padding)
    int kH, kW, kD, dil_kH, dil_kW, dil_kD;
    if (direction != 1) {
        // weight is an input variable, we can just read its shape
        kH = PyArray_DIMS(weights)[2];
        kW = PyArray_DIMS(weights)[3];
        kD = PyArray_DIMS(weights)[4];
    }
    else {
        if (%(height)s != -1) {
            // kernel height is specified (perhaps vertical subsampling or half padding)
            kH = %(height)s;
        }
        else if (padH == -2) {
            // vertical full padding, we can infer the kernel height
            kH = (2 - PyArray_DIMS(bottom)[2] + (PyArray_DIMS(top)[2] - 1) * dH - 1)/ dilH + 1;
        }
        else {
            // explicit padding, we can infer the kernel height
            kH = (PyArray_DIMS(bottom)[2] + 2*padH - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
        }
        if (%(width)s != -1) {
            kW = %(width)s;
        }
        else if (padW == -2) {
            kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
        else {
            kW = (PyArray_DIMS(bottom)[3] + 2*padW - (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
        }
        if (%(depth)s != -1) {
            kD = %(depth)s;
        }
        else if (padD == -2) {
            kD = (2 - PyArray_DIMS(bottom)[4] + (PyArray_DIMS(top)[4] - 1) * dD - 1) / dilD + 1;
        }
        else {
            kD = (PyArray_DIMS(bottom)[4] + 2*padD - (PyArray_DIMS(top)[4] - 1) * dD - 1) / dilD + 1;
        }
    }

    // Implicit dilated kernel size
    dil_kH = (kH - 1) * dilH + 1;
    dil_kW = (kW - 1) * dilW + 1;
    dil_kD = (kD - 1) * dilD + 1;

    // Auto-padding if requested
    if (padH == -1) {  // vertical half padding
        padH = dil_kH / 2;
    }
    else if (padH == -2) {  // vertical full padding
        padH = dil_kH - 1;
    }
    else if (padH < 0) {
        PyErr_SetString(PyExc_ValueError, "BaseCorr3dMM: padH must be >= -2");
        %(fail)s
    }
    if (padW == -1) {  // horizontal half padding
        padW = dil_kW / 2;
    }
    else if (padW == -2) {  // horizontal full padding
        padW = dil_kW - 1;
    }
    else if (padW < 0) {
        PyErr_SetString(PyExc_ValueError, "BaseCorr3dMM: padW must be >= -2");
        %(fail)s
    }
    if (padD == -1) {  // depth half padding
        padD = dil_kD / 2;
    }
    else if (padD == -2) {  // depth full padding
        padD = dil_kD - 1;
    }
    else if (padD < 0) {
        PyErr_SetString(PyExc_ValueError, "BaseCorr3dMM: padD must be >= -2");
        %(fail)s
    }

    // Infer output shape
    npy_intp out_dim[5];
    switch(direction) {
    case 0:  // forward pass
        // output is top: (batchsize, num_filters, height, width, depth)
        // height and width: top = (bottom + 2*pad - ((weight-1)*dil + 1)) / sample + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
        out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + 2*padH - ((PyArray_DIMS(weights)[2]-1)*dilH + 1)) / dH + 1);
        out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + 2*padW - ((PyArray_DIMS(weights)[3]-1)*dilW + 1)) / dW + 1);
        out_dim[4] = (npy_intp)((PyArray_DIMS(bottom)[4] + 2*padD - ((PyArray_DIMS(weights)[4]-1)*dilD + 1)) / dD + 1);
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0 || out_dim[4] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "Corr3dMM: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                         (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                         (long int)PyArray_DIMS(bottom)[4],
                         (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                         (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                         (long int)PyArray_DIMS(weights)[4],
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3], (long int)out_dim[4]);
            %(fail)s
        }
        break;
    case 1:  // backprop wrt. weights
        // output is weights: (num_filters, num_channels, height, width, depth)
        // height and width: weights = (bottom + 2*pad - (top - 1) * sample - 1) / dil + 1
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[1];
        out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1] / numgroups;
        out_dim[2] = (npy_intp)kH;  // already inferred further above
        out_dim[3] = (npy_intp)kW;  // how convenient
        out_dim[4] = (npy_intp)kD;
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0 || out_dim[4] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "Corr3dMM backprop wrt. weights: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
                         (long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
                         (long int)PyArray_DIMS(bottom)[4],
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3], (long int)out_dim[4],
                         (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                         (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3],
                         (long int)PyArray_DIMS(top)[4]);
            %(fail)s
        }
        break;
    case 2:  // backprop wrt. inputs
        // output is bottom: (batchsize, num_channels, height, width, depth)
        // height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
        out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
        out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1] * numgroups;
        out_dim[2] = (npy_intp)((%(height)s != -1) ? %(height)s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
        out_dim[3] = (npy_intp)((%(width)s != -1) ? %(width)s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
        out_dim[4] = (npy_intp)((%(depth)s != -1) ? %(depth)s : (PyArray_DIMS(top)[4] - 1) * dD + (PyArray_DIMS(weights)[4]-1)*dilD + 1 - 2*padD);
        if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0 || out_dim[4] <= 0)
        {
            PyErr_Format(PyExc_ValueError,
                         "Corr3dMM backprop wrt. inputs: impossible output shape\\n"
                         "  bottom shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  weights shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n"
                         "  top shape: %%ld x %%ld x %%ld x %%ld x %%ld\\n",
                         (long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
                         (long int)out_dim[3], (long int)out_dim[4],
                         (long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
                         (long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
                         (long int)PyArray_DIMS(weights)[4],
                         (long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
                         (long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3],
                         (long int)PyArray_DIMS(top)[4]);
            %(fail)s
        }
        break;
    default:
        PyErr_SetString(PyExc_ValueError, "BaseCorr3dMM: direction must be 0, 1, or 2\\n");
        %(fail)s
    }

    // Prepare output array
    int typenum;
    if ( !(*out
           && PyArray_NDIM(*out)==4
           && PyArray_IS_C_CONTIGUOUS(*out)
           && PyArray_DIMS(*out)[0]==out_dim[0]
           && PyArray_DIMS(*out)[1]==out_dim[1]
           && PyArray_DIMS(*out)[2]==out_dim[2]
           && PyArray_DIMS(*out)[3]==out_dim[3]
           && PyArray_DIMS(*out)[4]==out_dim[4]))
    {
        Py_XDECREF(*out);
        if (direction != 1) {
          typenum = PyArray_TYPE(weights);
        }
        else {
          typenum = PyArray_TYPE(bottom);
        }
        //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
        *out = (PyArrayObject*)PyArray_ZEROS(5,
                                          out_dim,
                                          typenum,
                                          0);
        if (NULL == *out)
        {
            PyErr_Format(PyExc_RuntimeError,
                    "BaseCorr3dMM: Failed to allocate output of %%lld x %%lld x %%lld x %%lld x %%lld",
                    (long long)out_dim[0], (long long)out_dim[1],
                    (long long)out_dim[2], (long long)out_dim[3], (long long)out_dim[4]);
            %(fail)s
        }
    }

    // Call corr3dMM code
    out2 = corr3dMM(%(bottom)s, %(weights)s, %(top)s, direction,
                    dH, dW, dD, dilH, dilW, dilD, padH, padW, padD,
                    numgroups);
    if (out2==NULL){
       %(fail)s
    }
    assert (out2 == *out);

""" % dict(
            bottom=bottom,
            weights=weights,
            top=top,
            height=height,
            width=width,
            depth=depth,
            fail=sub["fail"],
            params=sub["params"],
        )
class MyOpEnumList(Op):
    """Op applying one of four arithmetic operations to two scalars.

    The operation is chosen at construction time with '+', '-', '*' or '/'
    and stored as the matching constant of the ``EnumList`` used as
    ``params_type``; ``get_params`` hands that constant over as the Op
    parameter.
    """

    __props__ = ('op_chosen', )
    params_type = EnumList('ADD',
                           'SUB',
                           'MULTIPLY',
                           'DIVIDE',
                           ctype='unsigned long long')

    def __init__(self, choose_op):
        """Translate the operator symbol ``choose_op`` into its enum constant.

        An unknown symbol raises ``KeyError`` from the lookup table.
        """
        enum = self.params_type
        # Constants are expected to be numbered in declaration order.
        expected_values = (('ADD', 0), ('SUB', 1), ('MULTIPLY', 2),
                           ('DIVIDE', 3))
        for constant_name, value in expected_values:
            assert getattr(enum, constant_name) == value
        symbol_to_constant = {
            '+': enum.ADD,
            '-': enum.SUB,
            '*': enum.MULTIPLY,
            '/': enum.DIVIDE
        }
        self.op_chosen = symbol_to_constant[choose_op]

    def get_params(self, node):
        """Return the stored enum constant as the Op parameter."""
        return self.op_chosen

    def make_node(self, a, b):
        """Build the Apply node: two scalar inputs, one float64 scalar output."""
        scalar_inputs = [scalar.as_scalar(a), scalar.as_scalar(b)]
        return Apply(self, scalar_inputs, [scalar.float64()])

    def perform(self, node, inputs, outputs, op):
        """Python implementation: write the result into ``outputs[0][0]``.

        Raises ``NotImplementedError`` when ``op`` matches none of the four
        enum constants.
        """
        lhs, rhs = inputs
        cell, = outputs
        enum = self.params_type
        if op == enum.ADD:
            cell[0] = lhs + rhs
        elif op == enum.SUB:
            cell[0] = lhs - rhs
        elif op == enum.MULTIPLY:
            cell[0] = lhs * rhs
        elif op == enum.DIVIDE:
            # True division as soon as one operand has a continuous dtype,
            # floor division otherwise.
            operand_dtypes = (lhs.dtype, rhs.dtype)
            has_continuous = False
            for operand_dtype in operand_dtypes:
                if operand_dtype in theano.tensor.continuous_dtypes:
                    has_continuous = True
                    break
            if has_continuous:
                cell[0] = lhs / rhs
            else:
                cell[0] = lhs // rhs
        else:
            raise NotImplementedError('Unknown op id ' + str(op))

    def c_code_cache_version(self):
        """Return the version tuple of the generated C code."""
        return (1, )

    def c_code(self, node, name, inputs, outputs, sub):
        """Return the C ``switch`` dispatching on the enum parameter."""
        template = """
        switch(%(op)s) {
            case ADD:
                %(o)s = %(a)s + %(b)s;
                break;
            case SUB:
                %(o)s = %(a)s - %(b)s;
                break;
            case MULTIPLY:
                %(o)s = %(a)s * %(b)s;
                break;
            case DIVIDE:
                %(o)s = %(a)s / %(b)s;
                break;
            default:
                {%(fail)s}
                break;
        }
        """
        substitutions = {
            'op': sub['params'],
            'o': outputs[0],
            'a': inputs[0],
            'b': inputs[1],
            'fail': sub['fail'],
        }
        return template % substitutions