Beispiel #1
0
 def c_code_contiguous_disabled(self, node, name, inp, out, sub):
     """Return C code computing ``z = 1 / (1 + exp(-x))`` via AMD LibM.

     The generated code walks the input/output data in blocks of at most
     2048 elements. For each block it:

     1. negates ``x`` into ``z``;
     2. applies the AMD LibM vector exp routine in place on ``z``
        (``amd_vrsa_expf`` for float32, ``amd_vrda_exp`` for float64);
     3. maps each element to ``1 / (1 + z)``.

     Parameters
     ----------
     node
         Node whose ``inputs[0].dtype`` and ``outputs[0].dtype`` are read.
     name
         Unused here.
     inp
         One-element sequence. Its item is substituted into the C code as
         the argument of ``PyArray_DATA`` (the input array).
     out
         One-element sequence. Its item is substituted into the C code as
         the argument of ``PyArray_SIZE`` / ``PyArray_DATA`` (the output
         array).
     sub
         Unused here.

     Returns
     -------
     str
         The C code fragment.

     Raises
     ------
     MethodNotDefined
         If ``config.lib__amblibm`` is falsy, if the input and output
         dtypes differ, or if the dtype is neither float32 (with
         ``self.amd_float32`` set) nor float64 (with ``self.amd_float64``
         set).
     """
     (x,) = inp
     (z,) = out
     if (not config.lib__amblibm
             or node.inputs[0].dtype != node.outputs[0].dtype):
         raise MethodNotDefined()
     dtype = node.inputs[0].dtype
     if dtype == "float32" and self.amd_float32 is not None:
         dtype = "float"
         fct = "amd_vrsa_expf"
     elif dtype == "float64" and self.amd_float64 is not None:
         dtype = "double"
         fct = "amd_vrda_exp"
     else:
         raise MethodNotDefined()
     # The loop counters are npy_intp (same type as n) so arrays with more
     # than INT_MAX elements do not overflow an `int` counter.
     return """
     npy_intp n = PyArray_SIZE(%(z)s);
     %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
     %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
     // We block to keep the data in L1:
     // normal L1 size = 32k: 32k/2 (input + output)/8 (bytes per double) = 2k,
     // so blocks of 2048 elements.
     // This is faster than the non-blocking version.
     for(npy_intp i=0;i<n;i+=2048){
         npy_intp nb = (n-i<2048)?n-i:2048;
         for(npy_intp j=0;j<nb;j++){
             z[i+j] = -x[i+j];
         }
         %(fct)s(nb, z+i, z+i);
         for(npy_intp j=0;j<nb;j++){
             z[i+j] = 1.0 /(1.0+z[i+j]);
         }
     }
     """ % dict(x=x, z=z, dtype=dtype, fct=fct)
Beispiel #2
0
 def get_params(self, node: Apply) -> Params:
     """Return the op's params as built by its ``ParamsType``, if it has one.

     Every field declared by ``self.params_type`` must exist as an attribute
     of the op. The op is then passed to ``ParamsType.get_params``, which
     applies the attribute filtering.

     Raises
     ------
     MethodNotDefined
         If the op has no ``params_type`` attribute, or it is not a
         ``ParamsType``.
     AttributeError
         If some declared fields are not attributes of the op. The message
         lists every missing field name.
     """
     wrapper = getattr(self, "params_type", None)
     if not isinstance(wrapper, ParamsType):
         raise MethodNotDefined("get_params")
     # Collect all absent attributes at once so the error reports them together.
     not_found = tuple(
         field for field in wrapper.fields if not hasattr(self, field)
     )
     if not_found:
         raise AttributeError(
             f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
         )
     return wrapper.get_params(self)
Beispiel #3
0
 def __hide(*args):
     """Always raise ``MethodNotDefined``.

     Any positional arguments are accepted and ignored.
     """
     raise MethodNotDefined