def __enter__(self):
    """Enter the context: build a ``te.KernelScope`` and keep it on ``self.scope``."""
    kernel_scope = te.KernelScope()
    self.scope = kernel_scope
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._C._te as te
import torch.fx as fx
import torch.utils._pytree as pytree
from torch.fx import map_arg
from torch.fx.passes.shape_prop import ShapeProp
import operator
import functools

# Module-level TensorExpr kernel scope, created once at import time.
# NOTE(review): ``te.KernelScope`` is not present in every PyTorch build (it
# appears to have been removed in later releases -- confirm); the lookup is
# guarded so that importing this module does not fail there.  Where the class
# exists, behavior is unchanged.
scope = te.KernelScope() if hasattr(te, "KernelScope") else None


def truncate(model, k):
    """Return a ``GraphModule`` that computes only the first ``k`` nodes of ``model``.

    ``model`` is symbolically traced with ``torch.fx``.  Its nodes are copied,
    in graph order, into a fresh ``fx.Graph``.  The ``k``-th node (1-based)
    becomes the new graph's output, and everything after it is dropped.

    If the traced graph's own ``output`` node is reached before a ``k``-th
    non-output node, the graph is copied in full with its original output.
    This happens when ``k <= 0`` or ``k >= number of nodes`` (output
    included).  In that case no second output node is added.

    Args:
        model: module or callable accepted by ``fx.symbolic_trace``.
        k: 1-based position of the node whose value the truncated module returns.

    Returns:
        An ``fx.GraphModule`` rooted at the traced model, using the new graph.
    """
    traced = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    env = {}  # original node -> its copy in new_graph

    for idx, node in enumerate(traced.graph.nodes, start=1):
        copied = new_graph.node_copy(node, lambda n: env[n])
        if node.op == "output":
            # The copy already terminates the graph. Calling
            # ``new_graph.output`` on it would append a second, bogus output
            # node that returns the first one.
            break
        env[node] = copied
        if idx == k:
            new_graph.output(copied)
            break
    return fx.GraphModule(traced, new_graph)