def visit_call(self, call):
    """Annotate ``add``/``subtract``/``multiply`` calls for the "ccompiler" target.

    ``self.in_compiler`` carries state across the recursive traversal:

    * a ``multiply`` sets it to 1 before visiting its arguments and resets it
      to 0 after building its result;
    * an ``add`` or ``subtract`` is only rewritten while it is 1, and an
      ``add`` additionally moves it to 2;
    * the ``multiply`` result is wrapped in ``compiler_end`` only if the state
      is 2 once its arguments have been visited.

    Any call not handled above is delegated to the base ``visit_call``.
    """
    op_name = call.op.name

    if op_name == "multiply":
        # Enter state 1 before descending so nested add/subtract get rewritten.
        self.in_compiler = 1
        lhs = super().visit(call.args[0])
        rhs = super().visit(call.args[1])
        # Only bare variables are marked with compiler_begin here.
        if isinstance(lhs, relay.expr.Var):
            lhs = compiler_begin(lhs, "ccompiler")
        if isinstance(rhs, relay.expr.Var):
            rhs = compiler_begin(rhs, "ccompiler")
        product = relay.multiply(lhs, rhs)
        if self.in_compiler == 2:
            product = compiler_end(product, "ccompiler")
        self.in_compiler = 0
        return product

    if self.in_compiler == 1:
        if op_name == "add":
            # Both visited operands are marked with compiler_begin unconditionally.
            first = compiler_begin(super().visit(call.args[0]), "ccompiler")
            second = compiler_begin(super().visit(call.args[1]), "ccompiler")
            total = relay.add(first, second)
            self.in_compiler = 2
            return total
        if op_name == "subtract":
            first = super().visit(call.args[0])
            second = super().visit(call.args[1])
            if isinstance(first, relay.expr.Var):
                first = compiler_begin(first, "ccompiler")
            if isinstance(second, relay.expr.Var):
                second = compiler_begin(second, "ccompiler")
            return relay.subtract(first, second)

    return super().visit_call(call)
def visit_call(self, call):
    """Put every call into its own "ccompiler" annotated region.

    Each argument is visited recursively and wrapped in ``compiler_begin``;
    the rebuilt call is wrapped in ``compiler_end``.  The original call's
    ``attrs`` and ``type_args`` are carried over, as the base
    ``ExprMutator.visit_call`` does, so operators with attributes keep them.
    """
    new_args = [compiler_begin(self.visit(arg), "ccompiler") for arg in call.args]
    # Previously only (op, args) were passed, silently dropping attrs/type_args.
    new_call = relay.Call(call.op, new_args, call.attrs, call.type_args)
    return compiler_end(new_call, "ccompiler")
def visit_call(self, call):
    """Offload calls whose operator name is listed in ``annotator.op_list``.

    A matching call has each recursively visited argument wrapped in
    ``compiler_begin`` and the rebuilt call (same op, ``attrs`` and
    ``type_args``) wrapped in ``compiler_end``, both for
    ``annotator.compiler``.  Any other call goes to the base ``visit_call``.
    """
    if call.op.name not in annotator.op_list:
        return super().visit_call(call)

    target = annotator.compiler
    begun_args = []
    # Explicit loop: zero-argument super() is not usable inside a comprehension
    # scope on older Python versions.
    for arg in call.args:
        visited = super().visit(arg)
        begun_args.append(compiler_begin(visited, target))

    rebuilt = relay.Call(call.op, begun_args, call.attrs, call.type_args)
    return compiler_end(rebuilt, target)
def visit_call(self, call):
    """Annotate the expression so the first visited call closes one region.

    Only the call visited while ``self.last_call`` is truthy gets wrapped in
    ``compiler_end``; the flag is cleared before descending, so nested calls
    are left unwrapped.  Visited arguments that are ``relay.expr.Var`` nodes
    are wrapped in ``compiler_begin``.  The call is rebuilt with its ``attrs``.
    """
    is_root = self.last_call
    # Clear before recursing so nested calls are not treated as the root.
    self.last_call = False

    visited_args = []
    for arg in call.args:
        node = super().visit(arg)
        visited_args.append(
            compiler_begin(node, self.compiler)
            if isinstance(node, relay.expr.Var)
            else node
        )

    rebuilt = relay.Call(call.op, visited_args, call.attrs)
    return compiler_end(rebuilt, self.compiler) if is_root else rebuilt
def visit_call(self, call):
    """Annotate the sub-graph feeding ``nn.global_avg_pool2d``.

    Reaching an ``nn.global_avg_pool2d`` call sets ``self.compiler_open``;
    this method never clears it.  The pool's visited arguments are wrapped in
    ``compiler_end``.  While the flag is set, visited arguments that are
    ``relay.expr.Var`` nodes are wrapped in ``compiler_begin``.  The call is
    rebuilt with its original ``attrs``.
    """
    at_pool = call.op.name == 'nn.global_avg_pool2d'
    if at_pool:
        self.compiler_open = True
    region_open = self.compiler_open

    rebuilt_args = []
    for arg in call.args:
        node = super().visit(arg)
        if at_pool:
            node = compiler_end(node, self.compiler)
        # Checked after compiler_end, matching the original ordering.
        if region_open and isinstance(node, relay.expr.Var):
            node = compiler_begin(node, self.compiler)
        rebuilt_args.append(node)

    return relay.Call(call.op, rebuilt_args, call.attrs)