def add_function_args(cls, function: mast.Function, dtypes: List[mast.Type],
                      names: Optional[List[str]] = None,
                      positions: Optional[List[int]] = None):
    """
    Insert new arguments into the argument list of *function*.

    :arg function: Function node whose ``args`` list is modified in place;
        a ``None`` list is replaced by an empty one first.
    :arg dtypes: Types of the arguments to add, one entry per argument.
    :arg names: Names of the arguments to add. When omitted, one fresh name
        per type is drawn from ``cls.name_gen("fnarg")``.
    :arg positions: Insertion indices into ``function.args``, one per
        argument. Insertions are applied one after another, so each index
        refers to the list as already modified by the earlier insertions.
        When omitted, the arguments are appended after the existing ones.
    :returns: The ``mast.SsaId`` created for each added argument, in the
        order the arguments were processed.

    .. note:: *names*, *dtypes* and *positions* are consumed in lockstep
        with ``zip``, so the shortest of them bounds how many arguments
        are added.
    """
    arg_names = (names if names is not None
                 else [cls.name_gen("fnarg") for _ in dtypes])

    if function.args is None:
        function.args = []

    if positions is None:
        # Default: append the new arguments after the existing ones.
        first_free = len(function.args)
        positions = list(range(first_free, first_free + len(dtypes)))

    added = []
    for arg_name, arg_type, index in zip(arg_names, dtypes, positions):
        ssa_id = mast.SsaId(arg_name)
        function.args.insert(index, mast.NamedArgument(ssa_id, arg_type))
        added.append(ssa_id)
    return added
def _insert_op_in_block(self,
                        op_results: List[Optional[Union[mast.SsaId, str]]],
                        op):
    """
    Insert *op* into the current block at the builder's current position.

    :arg op_results: One entry per result of the operation. ``None`` asks for
        a freshly generated name (``self.name_gen("ssa")``), a ``str`` is
        wrapped in a ``mast.SsaId``, and an existing ``mast.SsaId`` is used
        as-is.
    :arg op: The operation node, wrapped into a ``mast.Operation`` together
        with the result ids.
    :returns: The single result id when the operation has exactly one
        result, the list of result ids when it has several, and ``None``
        when it has none.
    :raises ValueError: If the builder is not positioned within a block.
    """
    # Fail before generating any names so an invalid call has no side
    # effects on the name generator.
    if self.block is None:
        raise ValueError("Not within any block to append")

    new_op_results = []
    for op_result in op_results:
        if op_result is None:
            op_result = self.name_gen("ssa")
        if isinstance(op_result, str):
            op_result = mast.SsaId(op_result)
        # Anything that is not a str (i.e. an existing SsaId) passes through.
        new_op_results.append(op_result)

    self.block.body.insert(
        self.position, mast.Operation(result_list=new_op_results, op=op))
    # Advance so the next inserted op lands after this one.
    self.position += 1

    if len(new_op_results) == 1:
        return new_op_results[0]
    if new_op_results:
        return new_op_results
    return None
def add_block_args(self, dtypes: List[ast.Type],
                   names: Optional[List[str]] = None,
                   positions: Optional[List[int]] = None):
    """
    Insert new arguments into the label of the current block.

    If the block has no label yet, one is created with a fresh
    ``self.name_gen("bb")`` id and empty argument lists.

    :arg dtypes: Types of the arguments to add, one entry per argument.
    :arg names: Names of the arguments to add. When omitted, one fresh name
        per type is drawn from ``self.name_gen("bbarg")``.
    :arg positions: Insertion indices into the label's argument lists, one
        per argument. Insertions are applied one after another, so each
        index refers to the lists as already modified by earlier
        insertions. When omitted, the arguments are appended after the
        existing ones.
    :returns: The ``ast.SsaId`` created for each added argument, in the
        order the arguments were processed.
    """
    if names is None:
        names = [self.name_gen("bbarg") for _ in dtypes]

    label = self.block.label
    if label is None:
        label = ast.BlockLabel(ast.BlockId(self.name_gen("bb")), [], [])
        self.block.label = label

    if label.arg_ids is None:
        # ids and types are kept as parallel lists, so they must be
        # initialized together.
        label.arg_ids = []
        assert label.arg_types is None
        label.arg_types = []

    assert (label.arg_types is not None
            and label.arg_ids is not None)

    if positions is None:
        first_free = len(label.arg_types)
        positions = list(range(first_free, first_free + len(dtypes)))

    created = []
    for arg_name, arg_type, index in zip(names, dtypes, positions):
        ssa_id = ast.SsaId(arg_name)
        label.arg_ids.insert(index, ssa_id)
        label.arg_types.insert(index, arg_type)
        created.append(ssa_id)
    return created