Exemple #1
0
 def __call__(self, expr):
     """Traverse *expr* and return the per-variable constness table.

     ``self.is_constant`` is rebuilt so that every name in
     ``self.free_variables`` starts out mapped to ``False``. Then *expr* is
     pushed onto ``self.node_stack`` and the ``CombineMapper`` traversal runs.

     The traversal's own return value is discarded. What the caller gets back
     is ``self.is_constant``, presumably updated by other mapper methods
     during the traversal (not visible here; verify against the class).
     """
     self.is_constant = {name: False for name in self.free_variables}
     self.node_stack.append(expr)
     CombineMapper.__call__(self, expr)  # result intentionally unused
     return self.is_constant
Exemple #2
0
    def map_call(self, expr):
        """Infer the result dtype of a call to one of a few intrinsics.

        :arg expr: a call expression. ``expr.function.name`` selects the
            intrinsic and ``expr.parameters`` holds its arguments.
        :returns: a :class:`numpy.dtype` for the handled intrinsics.
            Otherwise, whatever ``CombineMapper.map_call`` returns.

        Handled intrinsics (each must receive exactly one argument):

        * ``fromreal``: the complex dtype (``complex64`` or ``complex128``)
          matching the floating type that the argument dtype promotes to
          when combined with ``float32``.
        * ``real``, ``imag``, ``abs``: the real counterpart of a complex
          argument dtype. Any other argument dtype is returned unchanged.
        * ``dble``: always ``float64``, matching Fortran's ``DBLE``
          intrinsic (double-precision real; the real part for complex
          arguments).

        :raises ValueError: if a handled intrinsic does not get exactly one
            argument (raised by the tuple unpacking).
        :raises TypeError: if ``fromreal`` receives a non-real argument.
        :raises RuntimeError: if the promoted real dtype for ``fromreal`` is
            neither ``float32`` nor ``float64`` (e.g. ``longdouble``).
        """
        name = expr.function.name

        if name == "fromreal":
            arg, = expr.parameters
            base_dtype = self.rec(arg)
            # Promote against float32 so that integer, bool and half-precision
            # arguments map to the narrowest adequate floating type.
            tgt_real_dtype = (np.float32(0) + base_dtype.type(0)).dtype
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if tgt_real_dtype.kind != "f":
                raise TypeError(
                    "fromreal expects a real argument, got dtype '%s'"
                    % base_dtype)
            if tgt_real_dtype == np.float32:
                return np.dtype(np.complex64)
            elif tgt_real_dtype == np.float64:
                return np.dtype(np.complex128)
            else:
                raise RuntimeError("unexpected complex type")

        elif name == "dble":
            arg, = expr.parameters
            # The argument dtype does not affect the result, but the argument
            # is still traversed exactly as for the other intrinsics.
            self.rec(arg)
            return np.dtype(np.float64)

        elif name in ["imag", "real", "abs"]:
            arg, = expr.parameters
            base_dtype = self.rec(arg)

            if base_dtype == np.complex128:
                return np.dtype(np.float64)
            elif base_dtype == np.complex64:
                return np.dtype(np.float32)
            else:
                return base_dtype

        else:
            return CombineMapper.map_call(self, expr)
Exemple #3
0
    def map_call(self, expr):
        """Infer the result dtype of a call to one of a few intrinsics.

        :arg expr: a call expression. ``expr.function.name`` selects the
            intrinsic and ``expr.parameters`` holds its arguments.
        :returns: a :class:`numpy.dtype` for the handled intrinsics.
            Otherwise, whatever ``CombineMapper.map_call`` returns.

        Handled intrinsics (each must receive exactly one argument):

        * ``fromreal``: the complex dtype (``complex64`` or ``complex128``)
          matching the floating type that the argument dtype promotes to
          when combined with ``float32``.
        * ``real``, ``imag``, ``abs``: the real counterpart of a complex
          argument dtype. Any other argument dtype is returned unchanged.
        * ``dble``: always ``float64``, matching Fortran's ``DBLE``
          intrinsic (double-precision real; the real part for complex
          arguments).

        :raises ValueError: if a handled intrinsic does not get exactly one
            argument (raised by the tuple unpacking).
        :raises TypeError: if ``fromreal`` receives a non-real argument.
        :raises RuntimeError: if the promoted real dtype for ``fromreal`` is
            neither ``float32`` nor ``float64`` (e.g. ``longdouble``).
        """
        name = expr.function.name

        if name == "fromreal":
            arg, = expr.parameters
            base_dtype = self.rec(arg)
            # Promote against float32 so that integer, bool and half-precision
            # arguments map to the narrowest adequate floating type.
            tgt_real_dtype = (np.float32(0) + base_dtype.type(0)).dtype
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if tgt_real_dtype.kind != "f":
                raise TypeError(
                    "fromreal expects a real argument, got dtype '%s'"
                    % base_dtype)
            if tgt_real_dtype == np.float32:
                return np.dtype(np.complex64)
            elif tgt_real_dtype == np.float64:
                return np.dtype(np.complex128)
            else:
                raise RuntimeError("unexpected complex type")

        elif name == "dble":
            arg, = expr.parameters
            # The argument dtype does not affect the result, but the argument
            # is still traversed exactly as for the other intrinsics.
            self.rec(arg)
            return np.dtype(np.float64)

        elif name in ["imag", "real", "abs"]:
            arg, = expr.parameters
            base_dtype = self.rec(arg)

            if base_dtype == np.complex128:
                return np.dtype(np.float64)
            elif base_dtype == np.complex64:
                return np.dtype(np.float32)
            else:
                return base_dtype

        else:
            return CombineMapper.map_call(self, expr)
Exemple #4
0
 def map_call(self, expr):
     """Map a call node according to ``self.include_calls``.

     * ``"descend_args"``: recurse into each entry of ``expr.parameters``
       and return ``self.combine`` of those results. The called function
       itself is not visited by this branch.
     * any other truthy value: return a one-element set holding the call
       node itself.
     * falsy: defer to ``CombineMapper.map_call``.
     """
     mode = self.include_calls
     if mode == "descend_args":
         arg_results = [self.rec(param) for param in expr.parameters]
         return self.combine(arg_results)
     if mode:
         return {expr}
     return CombineMapper.map_call(self, expr)
Exemple #5
0
 def map_common_subexpression_uncached(self, expr):
     """Map a common-subexpression node (uncached variant).

     When ``self.include_cses`` is truthy, the CSE node itself is returned
     as a one-element set. Otherwise the unbound base-class method
     ``CombineMapper.map_common_subexpression`` handles it. The explicit
     base-class call (rather than ``self.map_common_subexpression``)
     presumably avoids re-entering a caching override that dispatches back
     here; confirm against the enclosing class.
     """
     if not self.include_cses:
         return CombineMapper.map_common_subexpression(self, expr)
     return {expr}
Exemple #6
0
 def map_subscript(self, expr):
     """Map a subscript node according to ``self.include_subscripts``.

     A truthy flag makes the subscript node itself the (single-element
     set) result. Otherwise ``CombineMapper.map_subscript`` handles it.
     """
     if not self.include_subscripts:
         return CombineMapper.map_subscript(self, expr)
     return {expr}
Exemple #7
0
 def map_lookup(self, expr):
     """Map an attribute-lookup node according to ``self.include_lookups``.

     Returns the node itself wrapped in a one-element set when the flag is
     truthy. Otherwise returns the result of ``CombineMapper.map_lookup``.
     """
     return ({expr} if self.include_lookups
             else CombineMapper.map_lookup(self, expr))
Exemple #8
0
 def rec(self, expr):
     """Record *expr* on ``self.node_stack``, then recurse into it.

     The node is pushed before dispatch. ``CombineMapper.rec`` is then
     called explicitly on the base class, and its result is returned
     unchanged. Nothing in this method pops the stack.
     """
     push = self.node_stack.append
     push(expr)
     return CombineMapper.rec(self, expr)