def visit_UnaryOpExpr(self, node: gt_ir.Node, **kwargs):
    """Infer and store the data type of a unary-operator expression.

    Children are visited first so that ``node.arg.data_type`` is already
    resolved when this node's type is computed. Logical negation
    (``UnaryOperator.NOT``) yields a boolean; every other unary operator
    keeps the data type of its operand.

    Args:
        node: The unary expression; its ``data_type`` is set in place.
        **kwargs: Forwarded unchanged to the child visit.
    """
    self.generic_visit(node, **kwargs)
    # The operand must have been typed by the child visit above.
    assert node.arg.data_type is not gt_ir.DataType.AUTO
    # Compare the operator member itself, not ``node.op.value``: for a plain
    # ``enum.Enum`` the underlying value never equals a member, so a
    # ``.value`` test cannot match. (Assumes ``node.op`` is a UnaryOperator
    # member — confirm against the IR node definition.)
    if node.op == gt_ir.UnaryOperator.NOT:
        node.data_type = gt_ir.DataType.from_dtype(bool)
    else:
        node.data_type = node.arg.data_type
def visit_BinOpExpr(self, node: gt_ir.Node, **kwargs):
    """Infer and store the data type of a binary-operator expression.

    Children are visited first so both operand types are resolved. Logical
    operators (AND, OR) and comparison operators (EQ, NE, LT, LE, GT, GE)
    yield a boolean. Every other operator takes the type produced by
    ``DataType.merge`` of the two operand types.

    Args:
        node: The binary expression; its ``data_type`` is set in place.
        **kwargs: Forwarded unchanged to the child visit.
    """
    self.generic_visit(node, **kwargs)
    # Both operands must have been typed by the child visit above.
    assert node.lhs.data_type is not gt_ir.DataType.AUTO
    assert node.rhs.data_type is not gt_ir.DataType.AUTO

    # AND is listed alongside OR: both logical operators produce a boolean.
    boolean_result_ops = (
        gt_ir.BinaryOperator.AND,
        gt_ir.BinaryOperator.OR,
        gt_ir.BinaryOperator.EQ,
        gt_ir.BinaryOperator.NE,
        gt_ir.BinaryOperator.LT,
        gt_ir.BinaryOperator.LE,
        gt_ir.BinaryOperator.GT,
        gt_ir.BinaryOperator.GE,
    )
    # Test the operator member itself rather than ``node.op.value``: a plain
    # ``enum.Enum`` value never compares equal to a member. (Assumes
    # ``node.op`` is a BinaryOperator member — confirm.)
    if node.op in boolean_result_ops:
        node.data_type = gt_ir.DataType.from_dtype(bool)
    else:
        node.data_type = gt_ir.DataType.merge(node.lhs.data_type, node.rhs.data_type)
def visit_VarRef(self, node: gt_ir.Node, apply_block_symbols=None, **kwargs):
    """Unify the data type of a variable reference with its declaration.

    The declaration is looked up in ``apply_block_symbols`` when the name is
    present there, otherwise in ``self.vars``. If the declaration's type is
    still ``AUTO``, it adopts the reference's type. Otherwise the reference
    takes the declared type.

    Args:
        node: The variable reference; its ``data_type`` may be set in place.
        apply_block_symbols: Optional mapping of name -> declaration that is
            consulted before ``self.vars``. ``None`` (the default) behaves
            like an empty mapping.
        **kwargs: Forwarded unchanged to the child visit.

    Raises:
        KeyError: If the name is in neither ``apply_block_symbols`` nor
            ``self.vars``.
    """
    # A None sentinel replaces the former ``{}`` default, which would have
    # been a single dict shared by every call.
    if apply_block_symbols is None:
        apply_block_symbols = {}
    # NOTE(review): apply_block_symbols is consumed here and not forwarded
    # to the child visit — confirm this is intended.
    self.generic_visit(node, **kwargs)
    if node.name in apply_block_symbols:
        var_decl = apply_block_symbols[node.name]
    else:
        var_decl = self.vars[node.name]
    if var_decl.data_type == gt_ir.DataType.AUTO:
        # Declaration not yet typed: propagate the reference's type to it.
        var_decl.data_type = node.data_type
    else:
        node.data_type = var_decl.data_type
def visit_FieldRef(self, node: gt_ir.Node, **kwargs):
    """Unify the data type of a field reference with the field declaration.

    After visiting children, the declaration is looked up in
    ``self.fields`` by ``node.name``. An untyped (``AUTO``) declaration
    adopts the reference's type; otherwise the reference takes the
    declared type.

    Args:
        node: The field reference; its ``data_type`` may be set in place.
        **kwargs: Forwarded unchanged to the child visit.

    Raises:
        KeyError: If ``node.name`` is not a key of ``self.fields``.
    """
    self.generic_visit(node, **kwargs)
    field_decl = self.fields[node.name]
    if field_decl.data_type != gt_ir.DataType.AUTO:
        node.data_type = field_decl.data_type
        return
    # Declaration not yet typed: propagate the reference's type to it.
    field_decl.data_type = node.data_type