Exemplo n.º 1
0
 def make_node(self):
     """Build an Apply node for this op that takes no inputs.

     Outputs, in order:
       * a fresh variable of type ``Generic()``;
       * a tensor variable with dtype ``self.dtype`` and broadcast
         pattern ``self.broadcastable``.
     """
     generic_out = theano.Variable(Generic())
     tensor_out = tensor(self.dtype, broadcastable=self.broadcastable)
     return gof.Apply(self, [], [generic_out, tensor_out])
Exemplo n.º 2
0
class SearchsortedOp(theano.Op):
    """Wrapper of numpy.searchsorted.

    For full documentation, see :func:`searchsorted`.

    Parameters
    ----------
    side : {'left', 'right'}, optional
        Forwarded to ``numpy.searchsorted`` in ``perform`` and mapped to
        ``NPY_SEARCHLEFT`` / ``NPY_SEARCHRIGHT`` in the C implementation.

    Raises
    ------
    ValueError
        If ``side`` is neither ``'left'`` nor ``'right'``.

    See Also
    --------
    searchsorted : numpy-like function to use the SearchsortedOp

    """

    # ``side`` is this op's runtime parameter: ``get_params`` returns it, it
    # reaches ``perform`` as ``params`` and the C code as ``sub['params']``.
    params_type = Generic()
    __props__ = ("side", )

    def __init__(self, side='left'):
        if side == 'left' or side == 'right':
            self.side = side
        else:
            raise ValueError('\'%(side)s\' is an invalid value for keyword \'side\''
                             % locals())

    def get_params(self, node):
        """Return the ``side`` string used as this op's params."""
        return self.side

    def make_node(self, x, v, sorter=None):
        """Build the Apply node.

        Parameters
        ----------
        x
            Sorted values; converted to a 1-d tensor.
        v
            Values whose insertion indices are wanted; converted to a tensor.
        sorter : optional
            Integer vector of indices; when given it becomes a third input.

        Returns
        -------
        Apply
            Node whose single output has ``v``'s type with dtype ``int64``.

        Raises
        ------
        TypeError
            If ``sorter`` is not an integer vector, or is ``int64`` while
            ``python_int_bitwidth()`` is 32.
        """
        x = basic.as_tensor(x, ndim=1)
        v = basic.as_tensor(v)
        out_type = v.type.clone(dtype='int64')
        if sorter is None:
            return theano.Apply(self, [x, v], [out_type()])

        sorter = basic.as_tensor(sorter, ndim=1)
        # Per the error message below, numpy.searchsorted on 32-bit Python
        # does not accept an int64 sorter, so it is rejected here.
        if (theano.configdefaults.python_int_bitwidth() == 32 and
                sorter.dtype == 'int64'):
            raise TypeError(
                "numpy.searchsorted with Python 32bit do not support a"
                " sorter of int64.")
        if sorter.type not in basic.int_vector_types:
            raise TypeError('sorter must be an integer vector',
                            sorter.type)
        return theano.Apply(self, [x, v, sorter], [out_type()])

    def infer_shape(self, node, shapes):
        """The output has the same shape as ``v`` (the second input)."""
        return [shapes[1]]

    def perform(self, node, inputs, output_storage, params):
        """Python implementation via ``numpy.searchsorted``.

        ``params`` is the ``side`` string. The result is cast to the output
        variable's dtype (``int64`` as set in ``make_node``).
        """
        x, v = inputs[0], inputs[1]
        sorter = inputs[2] if len(node.inputs) == 3 else None
        z = output_storage[0]
        z[0] = np.searchsorted(x, v, side=params, sorter=sorter).astype(
            node.outputs[0].dtype)

    def c_support_code_struct(self, node, name):
        # Per-struct flag set once by c_init_code_struct. Despite its name it
        # is 0 when side == "right" and non-zero when side == "left".
        return """
            int right_%(name)s;
        """ % locals()

    def c_init_code_struct(self, node, name, sub):
        side = sub['params']
        fail = sub['fail']
        # right_<name> = PyUnicode_Compare(side, "right"): 0 means "right",
        # non-zero means "left" (side is validated in __init__). Compare
        # also returns -1 on error, so an error is only assumed when an
        # exception is actually set.
        return """
            PyObject* tmp_%(name)s = PyUnicode_FromString("right");
            if (tmp_%(name)s == NULL)
                %(fail)s;
            right_%(name)s = PyUnicode_Compare(%(side)s, tmp_%(name)s);
            Py_DECREF(tmp_%(name)s);
            if (right_%(name)s == -1 && PyErr_Occurred())
                %(fail)s;
        """ % locals()

    def c_code(self, node, name, inames, onames, sub):
        """C implementation via ``PyArray_SearchSorted``.

        A missing sorter is passed as ``NULL``. If NumPy returns a
        non-``int64`` array, it is cast to ``int64``; a failed cast now
        triggers ``fail`` instead of leaving the output NULL.
        """
        if len(node.inputs) == 3:
            x, v, sorter = inames
        else:
            x, v = inames
            sorter = "NULL"
        z, = onames
        fail = sub['fail']

        # ``right_<name>`` is non-zero for side == "left" (see
        # c_init_code_struct), hence the seemingly inverted ternary.
        return """
            Py_XDECREF(%(z)s);
            %(z)s = (PyArrayObject*) PyArray_SearchSorted(%(x)s, (PyObject*) %(v)s,
                                                          right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
            if (!%(z)s)
                %(fail)s;
            if (PyArray_TYPE(%(z)s) != NPY_INT64){
                PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64);
                Py_XDECREF(%(z)s);
                %(z)s = (PyArrayObject*) tmp;
                if (!%(z)s)
                    %(fail)s;
            }
        """ % locals()

    def c_code_cache_version(self):
        # Bumped from (2,) because the generated C code changed (added error
        # checks after PyArray_Cast and PyUnicode_Compare).
        return (3,)

    def grad(self, inputs, output_gradients):
        """Return zero gradients for ``x`` and ``v``.

        When a sorter input is present, its gradient is ``disconnected_type()``.
        """
        x, v = inputs[0], inputs[1]
        x_grad = gradient._float_zeros_like(x)
        v_grad = gradient._float_zeros_like(v)
        if len(inputs) == 3:
            return [x_grad, v_grad, disconnected_type()]
        return [x_grad, v_grad]
Exemplo n.º 3
0
 def make_node(self, path):
     """Build an Apply node whose single input is ``path``.

     A plain ``str`` is first wrapped in a ``Constant`` of type
     ``Generic()``; any other value is used as the input unchanged.
     The single output is a tensor with dtype ``self.dtype`` and
     broadcast pattern ``self.broadcastable``.
     """
     node_input = Constant(Generic(), path) if isinstance(path, str) else path
     output = tensor(self.dtype, broadcastable=self.broadcastable)
     return gof.Apply(self, [node_input], [output])
Exemplo n.º 4
0
 def make_node(self, request, data):
     """Build an Apply node over ``request`` and ``data``.

     The node has one output: a fresh variable of type ``Generic()``.
     """
     node_inputs = [request, data]
     node_outputs = [theano.Variable(Generic())]
     return gof.Apply(self, node_inputs, node_outputs)
Exemplo n.º 5
0
 def make_node(self, data):
     """Build an Apply node over ``data``.

     Outputs, in order: a fresh variable of type ``Generic()`` and the
     result of ``data.type()`` (presumably a new variable of the same
     type as ``data``).
     """
     node_outputs = [theano.Variable(Generic()), data.type()]
     return gof.Apply(self, [data], node_outputs)