def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
    """Return the symbols introduced by this map entry, with their types.

    The result contains one entry per map parameter and one entry per
    dynamic input connector (any input connector whose name does not start
    with ``'IN_'``) that has an incoming edge in ``state``.

    :param sdfg: SDFG whose ``arrays`` are consulted as a fallback for the
                 type of a dynamic input connector.
    :param state: State in which the incoming edges of this node are looked
                  up.
    :param symbols: Mapping of already-known symbol names to types, passed
                    to ``infer_expr_type`` when typing range expressions.
    :return: Dictionary mapping each new symbol name to its type.
    """
    from dace.codegen.tools.type_inference import infer_expr_type

    new_syms = {}

    # Map parameters: each is typed as the common result type of the first
    # two components of its range entry (presumably begin and end).
    for param, subset in zip(self._map.params, self._map.range):
        first_type = infer_expr_type(subset[0], symbols)
        second_type = infer_expr_type(subset[1], symbols)
        new_syms[param] = dtypes.result_type_of(first_type, second_type)

    # Dynamic inputs are the connectors without the 'IN_' prefix.
    dynamic_conns = {conn for conn in self.in_connectors if not conn.startswith('IN_')}

    for edge in state.in_edges(self):
        conn = edge.dst_conn
        if conn not in dynamic_conns:
            continue
        # Prefer the type stored on the connector itself; if it is falsy,
        # fall back to the dtype of the data container the memlet names.
        new_syms[conn] = self.in_connectors[conn] or sdfg.arrays[edge.data.data].dtype

    return new_syms
def testSymbolic(self):
    """Check ``infer_expr_type`` on SymPy expressions over typed symbols."""

    def check(expr, symbol_types, expected):
        # Infer the type of ``expr`` given ``symbol_types`` and compare it
        # against ``expected``.
        inferred = type_inference.infer_expr_type(expr, symbol_types)
        self.assertEqual(inferred, expected)

    # SymPy symbols used by the expressions below.
    n = sp.Symbol('n')
    m = sp.Symbol('m')

    # A float64 symbol plus an int literal stays float64.
    check(n + 5, {'n': dtypes.typeclass(np.float64)}, dtypes.typeclass(np.float64))

    # int8 symbol combined with Python int / float literals.
    check(n * 5, {'n': dtypes.typeclass(np.int8)}, dtypes.typeclass(int))
    check(n * 5.0, {'n': dtypes.typeclass(np.int8)}, dtypes.typeclass(float))
    check(n * 5.01, {'n': dtypes.typeclass(np.int8)}, dtypes.typeclass(float))

    # Mixing an int8 symbol with a float32 symbol.
    check(n * m + n,
          {'n': dtypes.typeclass(np.int8), 'm': dtypes.typeclass(np.float32)},
          dtypes.typeclass(np.float32))
def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
    """Return the symbols introduced by this consume entry, with their types.

    The result contains the processing-element index symbol and one entry
    per dynamic input connector (any input connector whose name does not
    start with ``'IN_'``) that has an incoming edge in ``state``.

    :param sdfg: SDFG whose ``arrays`` are consulted as a fallback for the
                 type of a dynamic input connector.
    :param state: State in which the incoming edges of this node are looked
                  up.
    :param symbols: Mapping of already-known symbol names to types, passed
                    to ``infer_expr_type`` when typing the PE count.
    :return: Dictionary mapping each new symbol name to its type.
    """
    from dace.codegen.tools.type_inference import infer_expr_type

    result = {}
    # PE index: typed like the expression for the number of PEs.
    result[self._consume.pe_index] = infer_expr_type(self._consume.num_pes, symbols)

    # Dynamic inputs are the connectors without the 'IN_' prefix.
    dyn_inputs = {c for c in self.in_connectors if not c.startswith('IN_')}

    for e in state.in_edges(self):
        if e.dst_conn in dyn_inputs:
            # Prefer the type declared on the connector; if it is unset
            # (falsy), fall back to the dtype of the data container the
            # incoming memlet refers to.
            result[e.dst_conn] = (self.in_connectors[e.dst_conn] or sdfg.arrays[e.data.data].dtype)

    return result