コード例 #1
0
 def __call__(self, expr):
     self.is_constant = {}
     for variable in self.free_variables:
         self.is_constant[variable] = False
     self.node_stack.append(expr)
     CombineMapper.__call__(self, expr)
     return self.is_constant
コード例 #2
0
    def map_call(self, expr):
        name = expr.function.name
        if name == "fromreal":
            arg, = expr.parameters
            base_dtype = self.rec(arg)
            tgt_real_dtype = (np.float32(0) + base_dtype.type(0)).dtype
            assert tgt_real_dtype.kind == "f"
            if tgt_real_dtype == np.float32:
                return np.dtype(np.complex64)
            elif tgt_real_dtype == np.float64:
                return np.dtype(np.complex128)
            else:
                raise RuntimeError("unexpected complex type")

        elif name in ["imag", "real", "abs", "dble"]:
            arg, = expr.parameters
            base_dtype = self.rec(arg)

            if base_dtype == np.complex128:
                return np.dtype(np.float64)
            elif base_dtype == np.complex64:
                return np.dtype(np.float32)
            else:
                return base_dtype

        else:
            return CombineMapper.map_call(self, expr)
コード例 #3
0
ファイル: translate.py プロジェクト: simudream/pyopencl
    def map_call(self, expr):
        name = expr.function.name
        if name == "fromreal":
            arg, = expr.parameters
            base_dtype = self.rec(arg)
            tgt_real_dtype = (np.float32(0) + base_dtype.type(0)).dtype
            assert tgt_real_dtype.kind == "f"
            if tgt_real_dtype == np.float32:
                return np.dtype(np.complex64)
            elif tgt_real_dtype == np.float64:
                return np.dtype(np.complex128)
            else:
                raise RuntimeError("unexpected complex type")

        elif name in ["imag", "real", "abs", "dble"]:
            arg, = expr.parameters
            base_dtype = self.rec(arg)

            if base_dtype == np.complex128:
                return np.dtype(np.float64)
            elif base_dtype == np.complex64:
                return np.dtype(np.float32)
            else:
                return base_dtype

        else:
            return CombineMapper.map_call(self, expr)
コード例 #4
0
ファイル: dependency.py プロジェクト: hpc12/lec11-demo
 def map_call(self, expr):
     if self.include_calls == "descend_args":
             return self.combine(
                     [self.rec(child) for child in expr.parameters])
     elif self.include_calls:
         return set([expr])
     else:
         return CombineMapper.map_call(self, expr)
コード例 #5
0
ファイル: dependency.py プロジェクト: hpc12/lec11-demo
 def map_common_subexpression_uncached(self, expr):
     if self.include_cses:
         return set([expr])
     else:
         return CombineMapper.map_common_subexpression(self, expr)
コード例 #6
0
ファイル: dependency.py プロジェクト: hpc12/lec11-demo
 def map_subscript(self, expr):
     if self.include_subscripts:
         return set([expr])
     else:
         return CombineMapper.map_subscript(self, expr)
コード例 #7
0
ファイル: dependency.py プロジェクト: hpc12/lec11-demo
 def map_lookup(self, expr):
     if self.include_lookups:
         return set([expr])
     else:
         return CombineMapper.map_lookup(self, expr)
コード例 #8
0
 def rec(self, expr):
     self.node_stack.append(expr)
     return CombineMapper.rec(self, expr)