def extract(self, graph: Graph, *args):
    # Create solver forward nodes
    for i, n in enumerate(graph.nodes):
        if isinstance(n, ComputeNode) and n.has_backward:
            self.solver_to_graph.append(i)
            self.graph_to_solver[i] = len(self.solver_to_graph) - 1
            snodeid = len(self.nodes)
            self.nodes[snodeid] = FwdNode(n)
            inbound = []
            for (dep, grad) in n.dependencies:
                if isinstance(graph.nodes[dep], ComputeNode):
                    assert grad == True
                    inbound.append(self.graph_to_solver[dep])
            self.nodes[snodeid].inbound_nodes = inbound
            if n.op == "aten::max_pool2d":
                # Mark max_pool2d (and a preceding relu) as intermediate-producing
                op = lm_ops.list_ops(self.mode, n.op)[-1]
                storage = op.backward_storage
                if not isinstance(storage, list):
                    storage = [storage]
                for storetype in storage:
                    if isinstance(storetype, lm_ops.IntermediateStorage):
                        self.nodes[snodeid].has_intermediates = True
                        self.nodes[snodeid].intermediate_size = storetype.size(n.shape)
                ibnode = self.nodes[snodeid].inbound_nodes[0]
                if "relu" in self.nodes[ibnode].op:
                    op = lm_ops.list_ops(self.mode, self.nodes[ibnode].op)[-1]
                    storage = op.backward_storage
                    if not isinstance(storage, list):
                        storage = [storage]
                    for storetype in storage:
                        if isinstance(storetype, lm_ops.IntermediateStorage):
                            self.nodes[ibnode].has_intermediates = True
                            self.nodes[ibnode].intermediate_size = storetype.size(
                                self.nodes[ibnode].gnode.shape)

    # Append the loss node
    self.loss = len(self.nodes)
    nloss = ComputeNode(self.nodes[self.loss - 1].gnode.shape, -2, "loss::loss", [], False)
    self.nodes[self.loss] = FwdNode(nloss)
    self.nodes[self.loss].inbound_nodes = [self.loss - 1]

    self.extract_deps(graph)
    self.get_mem()
    self.get_local_memory()
    self.get_fixed_memory(graph)
    self.get_conv_info()
    self.get_inplace_info()
    self.get_workspace_cost(graph, *args)
    self.size = len(self.nodes)

    # Write solver_info
    import pickle
    with open("../data/si_" + self.data_path + "_pipeline.pkl", 'wb') as si_file:
        pickle.dump(self, si_file)
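# Hedged sketch (helper not present in the original code): the solver info pickled above can
# be reloaded for offline inspection, assuming this module is importable so that pickle can
# reconstruct the object. The file-name convention mirrors the dumps in the two extract()
# variants ("_pipeline.pkl" here, ".pkl" in the non-pipelined variant below).
def load_solver_info(data_path, pipelined=True):
    import pickle
    suffix = "_pipeline.pkl" if pipelined else ".pkl"
    with open("../data/si_" + data_path + suffix, 'rb') as si_file:
        return pickle.load(si_file)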
def get_conv_info(self):
    self.conv_list = defaultdict()
    convop = lm_ops.list_ops(self.mode, 'aten::_convolution')[0]
    self.num_conv_algos = 1
    if self.select_conv_algo:
        self.num_conv_algos = max(convop.n_fwd_algos(),
                                  convop.n_bwd_ip_algos(),
                                  convop.n_bwd_wt_algos())
    for i in self.nodes:
        if "convolution" in self.nodes[i].op:
            self.conv_list[i] = len(self.conv_list)
def extract(self, graph: Graph, *args):
    # Create solver graph of forward and backward pass annotated with memory and compute info
    # Create solver forward nodes
    for i, n in enumerate(graph.nodes):
        if isinstance(n, ComputeNode) and n.has_backward:
            self.solver_to_graph.append(i)
            self.graph_to_solver[i] = len(self.solver_to_graph) - 1
            snodeid = len(self.nodes)
            self.nodes[snodeid] = FwdNode(n)
            inbound = []
            for (dep, grad) in n.dependencies:
                if isinstance(graph.nodes[dep], ComputeNode):
                    assert grad == True
                    inbound.append(self.graph_to_solver[dep])
            self.nodes[snodeid].inbound_nodes = inbound
            for idx, op in enumerate(lm_ops.list_ops(self.mode, n.op)):
                storage = op.backward_storage
                if not isinstance(storage, list):
                    storage = [storage]
                for storetype in storage:
                    if isinstance(storetype, lm_ops.IntermediateStorage):
                        # Create intermediate node
                        newnode_op = "int::" + op.__name__
                        ni = ComputeNode([storetype.size(n.shape)], -1, newnode_op, [], False)
                        self.solver_to_graph.append(-1)
                        sintnodeid = len(self.nodes)
                        self.nodes[sintnodeid] = IntNode(ni, snodeid, idx)
                        self.nodes[sintnodeid].inbound_nodes = [snodeid]
                        self.newnodes.append(sintnodeid)
                        self.nodes[snodeid].has_intermediates = True
                        self.nodes[snodeid].intermediates.append((idx, sintnodeid))

    self.loss = len(self.nodes)
    nloss = ComputeNode(self.nodes[self.loss - 1].gnode.shape, -2, "loss::loss", [], False)
    self.nodes[self.loss] = FwdNode(nloss)
    self.nodes[self.loss].inbound_nodes = [self.loss - 1]

    self.extract_deps(graph)
    self.get_mem()
    self.get_local_memory()
    self.get_fixed_memory(graph)
    self.get_conv_info()
    self.get_inplace_info()
    self.get_workspace_cost(graph, *args)
    self.size = len(self.nodes)

    # Write solver_info
    import pickle
    with open("../data/si_" + self.data_path + ".pkl", 'wb') as si_file:
        pickle.dump(self, si_file)
def get_inplace_info(self):
    self.inplace_list = defaultdict()
    if self.do_inplace:
        for v in range(self.loss):
            if isinstance(self.nodes[v], IntNode):
                continue
            if hasattr(lm_ops.list_ops(self.mode, self.nodes[v].op)[0], 'inplace'):
                for u in self.nodes[v].args[0]:
                    if self.nodes[u].last_used == v:
                        self.inplace_list[v] = u
                        if self.nodes[v].has_intermediates:
                            for (_, intmd) in self.nodes[v].intermediates:
                                self.inplace_list[intmd] = u
def get_conv_info(self):
    self.conv_list = defaultdict()
    convop = lm_ops.list_ops(self.mode, 'aten::_convolution')[0]
    self.num_conv_algos = 1
    if self.select_conv_algo:
        self.num_conv_algos = max(convop.n_fwd_algos(),
                                  convop.n_bwd_ip_algos(),
                                  convop.n_bwd_wt_algos())
    for i in self.nodes:
        if isinstance(self.nodes[i], BwdNode):
            if "convolution" in self.nodes[i].op and not self.nodes[i].fwd_node.is_depthwise:
                self.conv_list[i] = len(self.conv_list)
        else:
            if "convolution" in self.nodes[i].op and not self.nodes[i].gnode.is_depthwise:
                self.conv_list[i] = len(self.conv_list)
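# Hedged sketch (helper not present in the original code): conv_list maps a solver node id to
# the compact index used by the per-convolution algorithm-selection variables, i.e.
# init_schedule below indexes solution.rf as rf[t, conv_list[node_id], algo]. A small lookup
# helper could read:
def conv_index_of(solver_info, node_id):
    # Returns the compact convolution index, or None if the node has no algorithm choice
    # (non-convolution nodes, and depthwise convolutions which are excluded above).
    return solver_info.conv_list.get(node_id)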
def extract_deps(self, graph):
    # Create solver backward nodes
    for i in sorted(self.nodes, reverse=True):
        if i > self.loss:
            continue
        if self.nodes[i].gnode.has_backward:
            num_bwd = len(lm_ops.list_ops(self.mode, self.nodes[i].op))
            nbwd = BwdNode(self.nodes[i].gnode, i, num_bwd)
            self.fwd_to_bwd[i] = len(self.nodes)
            self.bwd_to_fwd[len(self.nodes)] = i
            self.nodes[len(self.nodes)] = nbwd
    self.fwd_to_bwd[self.loss] = self.loss

    # get deps for backward
    for i in self.nodes:
        if isinstance(self.nodes[i], BwdNode):
            continue
        if self.nodes[i].gnode.has_backward:
            for ni in self.nodes[i].inbound_nodes:
                num_bwd_prev = len(lm_ops.list_ops(self.mode, self.nodes[ni].op))
                for p in range(num_bwd_prev):
                    self.nodes[self.fwd_to_bwd[ni]].dep_list_bwd[p].append(
                        self.fwd_to_bwd[i])
            ops = lm_ops.list_ops(self.mode, self.nodes[i].op)
            num_bwd = len(ops)
            for p in range(num_bwd):
                deps = ops[p]().backward_storage
                l = []
                if not isinstance(deps, list):
                    deps = [deps]
                for dep in deps:
                    if isinstance(dep, lm_ops.InputStorage):
                        for inids in dep.ids:
                            arg_in = self.nodes[i].gnode.args[inids]
                            if isinstance(arg_in, ComputeNode.D):
                                innode = graph.nodes[arg_in.index]
                                if isinstance(innode, ComputeNode):
                                    l.append(self.graph_to_solver[arg_in.index])
                    if isinstance(dep, lm_ops.OutputStorage):
                        l.append(i)
                    if isinstance(dep, lm_ops.IntermediateStorage):
                        added_int = False
                        assert self.nodes[i].has_intermediates
                        for p_option, nint_id in self.nodes[i].intermediates:
                            nint = self.nodes[nint_id]
                            assert isinstance(nint, IntNode)
                            if p_option == p:
                                l.append(nint_id)
                                added_int = True
                        assert added_int == True
                self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p] = l
            if i == self.loss - 1:
                # inject loss node assuming we are at output node
                for p in range(num_bwd):
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p].append(self.loss)

    for i in range(self.loss + 1):
        self.nodes[i].make_args()
    for i in range(self.loss + 1, len(self.nodes)):
        self.nodes[i].make_args()
        fwd_node = self.nodes[i].fwd_node
        stored = []
        output_shapes = []  # Need to have a list because of nodes like add and cat
        # NOTE for both aten::add and cat, we consider output of bwd as both inputs of fwd
        for (dep, rgrad) in fwd_node.dependencies:
            nin = graph.nodes[dep]
            if isinstance(nin, Param) and rgrad == True:
                stored.append(nin)  # Params have gradients which will be stored in backward
            elif isinstance(nin, Param):
                pass
            elif isinstance(nin, ComputeNode) and rgrad == True:
                output_shapes.append(list(nin.shape))
            elif isinstance(nin, Input):
                if rgrad:
                    output_shapes.append(list(nin.shape))
            else:
                sys.exit("Unknown node encountered ")
        self.nodes[i].output_shapes = output_shapes
        self.nodes[i].stored = stored

    # Create edge_list
    self.edge_list = []
    for v in self.nodes:
        for k, vdeps in enumerate(self.nodes[v].args):
            for u in vdeps:
                edge = (u, v, k)
                self.edge_list.append(edge)

    self.last_use_bwd = defaultdict(dict)
    for i in self.nodes:
        if isinstance(self.nodes[i], BwdNode):
            assert len(self.nodes[i].dep_list_fwd) == 1
            for j in self.nodes[i].dep_list_fwd[0]:
                if isinstance(self.nodes[j], IntNode):
                    pj = self.nodes[j].solver_parent_id
                    self.last_use_bwd[self.solver_to_graph[pj]]["int"] = \
                        self.solver_to_graph[self.bwd_to_fwd[i]]
                else:
                    if j == self.loss:
                        sj = graph._outputs[0]
                    else:
                        sj = self.solver_to_graph[j]
                    b_graphi = self.solver_to_graph[self.bwd_to_fwd[i]]
                    self.last_use_bwd[sj]["ip"] = b_graphi
def init_schedule(self, solution: Solution, mode):
    self.solution = solution
    T = len(self.si.nodes) - self.si.loss

    # Create main structures
    self._op = [[] for i in range(T)]  # List of fwd operators
    self._bwd_op = []  # List of bwd operators
    self._fwd_schedule = [[] for i in range(T)]  # Forward schedule

    # Initialize forward pass structures
    for t in range(T):
        for i, n in enumerate(self._nodes):
            if isinstance(n, ComputeNode) and n.op != "aten::t":
                j = self.si.graph_to_solver[i]
                p = 0
                if self.si.compute_newnode and self.si.nodes[j].has_intermediates:
                    for (p_option, intid) in self.si.nodes[j].intermediates:
                        if solution.r[t][intid]:
                            assert solution.r[t][j]
                            p = p_option
                op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                if self.si.select_conv_algo and j in self.si.conv_list:
                    num_fwd_algos = op.n_fwd_algos()
                    for c in range(num_fwd_algos):
                        if solution.rf[t, self.si.conv_list[j], c]:
                            op.algorithm = c
                            break
                if n.is_depthwise:
                    op.is_depthwise = True
                if self.si.do_inplace and j in self.si.inplace_list:
                    if solution.ip[t][j]:
                        op.inplace = True
                self._op[t].append(op)
                schedule_intermediate = False
                storage = op.backward_storage
                if not isinstance(storage, list):
                    storage = [storage]
                for store in storage:
                    if isinstance(store, lm_ops.IntermediateStorage):
                        schedule_intermediate = True
                if t < T - 1:
                    s = solution.s[t + 1][j]
                    r = solution.r[t][j]
                    self._fwd_schedule[t].append(ScheduleType(r, s, schedule_intermediate))
                else:
                    r = solution.r[t][j]
                    self._fwd_schedule[t].append(ScheduleType(r, False, schedule_intermediate))
            else:
                # Node represents a parameter
                self._op[t].append(None)
                self._fwd_schedule[t].append(None)

    # Initialize backward pass structures
    for t in range(T):
        bwd_t = self.si.loss + t
        if t != 0 and isinstance(self.si.nodes[bwd_t], BwdNode):
            n = self.nodes[self.si.solver_to_graph[self.si.bwd_to_fwd[bwd_t]]]
            # Backportability to when si didn't support depthwise
            options = len(solution.m[t])
            p = 0
            for o in range(options):
                if solution.m[t][o] == 1:
                    p = o
            op = lm_ops.list_ops(self.si.mode, n.op)[p]()
            if n.is_depthwise:
                op.is_depthwise = True
            if (n.op == "aten::_convolution" and not n.is_depthwise) or n.op == "aten::addmm":
                algo = 0
                if self.si.nodes[bwd_t].bwd_op == "param_grad":
                    algo_type = 0
                else:
                    algo_type = 1
                if self.si.select_conv_algo and bwd_t in self.si.conv_list:
                    algo = -1
                    num_algos = solution.rf.shape[2]
                    for c in range(num_algos):
                        if solution.rf[t, self.si.conv_list[bwd_t], c]:
                            algo = c
                            break
                    if algo == -1:
                        raise RuntimeError("Algorithm not decided", t, bwd_t)
                algo = algo_type * 10 + algo
                op.algorithm = algo
            self._bwd_op.append(op)
        else:
            self._bwd_op.append(None)
def init_schedule(self, mode):
    T = len(self.si.nodes) - self.si.loss

    # Create main structures
    self._op = [[] for i in range(T)]  # List of fwd operators
    self._bwd_op = []  # List of bwd operators
    self._fwd_schedule = [[] for i in range(T)]  # Forward schedule

    t = 0
    to_store_graph = [False for i in range(len(self._nodes))]
    for k, n in enumerate(self._nodes):
        if isinstance(n, ComputeNode) and n.op != "aten::t":
            p = 0
            op = lm_ops.list_ops(self.si.mode, n.op)[p]()
            storage_list = op.backward_storage
            if n.op == "aten::gist__convolution" or n.op == "aten::nosave_relu_":
                continue  # do not store the input of nosave_relu; its compressed version is already stored
            if not isinstance(storage_list, list):
                storage_list = [storage_list]
            for storage in storage_list:
                if isinstance(storage, lm_ops.InputStorage):
                    for i in storage.ids:
                        to_store_graph[n.args[i].index] = True
                elif isinstance(storage, lm_ops.OutputStorage):
                    to_store_graph[k] = True

    for i, n in enumerate(self._nodes):
        if isinstance(n, ComputeNode) and n.op != "aten::t":
            j = self.si.graph_to_solver[i]
            p = 0
            op = lm_ops.list_ops(self.si.mode, n.op)[p]()
            if "relu_" in n.op or "hardtanh_" in n.op or "add_" in n.op:
                op.inplace = True
            if n.is_depthwise:
                op.is_depthwise = True
            self._op[t].append(op)
            schedule_intermediate = False
            storage = op.backward_storage
            if not isinstance(storage, list):
                storage = [storage]
            for store in storage:
                if isinstance(store, lm_ops.IntermediateStorage):
                    schedule_intermediate = True
            self._fwd_schedule[0].append(
                ScheduleType(True, to_store_graph[i], schedule_intermediate))
        else:
            # Node represents a parameter
            self._op[t].append(None)
            self._fwd_schedule[t].append(None)

    # Initialize backward pass structures
    for t in range(T):
        bwd_t = self.si.loss + t
        if t != 0 and isinstance(self.si.nodes[bwd_t], BwdNode):
            n = self.nodes[self.si.solver_to_graph[self.si.bwd_to_fwd[bwd_t]]]
            # Backportability to when si didn't support depthwise
            p = 0
            op = lm_ops.list_ops(self.si.mode, n.op)[p]()
            if n.is_depthwise:
                op.is_depthwise = True
            self._bwd_op.append(op)
        else:
            self._bwd_op.append(None)
def init_schedule(self, solution: CheckmateSolution, mode):
    T = len(self.si.nodes)

    # Create main structures
    self._op = [None for i in range(self.lennodes)]  # List of operations
    self._fwd_schedule = [[] for i in range(T)]  # Forward schedule
    self._bwd_schedule = [[] for i in range(T)]  # Backward schedule
    self.fwdargs = [None for i in range(self.lennodes)]  # Index to forward node input tensors
    self.bwdargs = [None for i in range(self.lennodes)]  # Index to backward node input tensors

    # Initialize forward pass structures
    for t in range(T):
        for i, n in enumerate(self.nodes):
            if isinstance(n, ComputeNode) and n.op != "aten::t":
                j = self.si.graph_to_solver[i]
                ops_list = lm_ops.list_ops(self.si.mode, n.op)
                if isinstance(self.si, PipelinedSolverInfo) and self.si.nodes[j].has_intermediates:
                    # Select intermediate-computing and intermediate-activated operator implementation
                    op = ops_list[-1]()
                else:
                    # Select the default operator implementation
                    op = ops_list[0]()
                if n.is_depthwise:
                    op.is_depthwise = True
                s = solution.s[t + 1][j] if t < T - 1 else False
                r = solution.r[t][j]
                f = solution.f[t][j]
                schedule_intermediate = False
                storage = op.backward_storage
                if not isinstance(storage, list):
                    storage = [storage]
                for store in storage:
                    if isinstance(store, lm_ops.IntermediateStorage):
                        schedule_intermediate = True
                if r or len(f) or s:
                    self._fwd_schedule[t].append(
                        (i, ScheduleType(r, s, f, schedule_intermediate), n.op))
                self._op[i] = op
                self.fwdargs[i] = [(a.value, None) if isinstance(a, ComputeNode.V) else
                                   (a.index, a.requires_grad) for a in n.args]
            elif isinstance(n, ComputeNode) and n.op == "aten::t":
                pass
            else:
                # Node represents a parameter
                self._fwd_schedule[t].append((i, None, None))
                self._op[i] = None

    # Initialize backward pass structures
    for k, m in reversed(list(enumerate(self.nodes))):
        # Create backward dependencies
        if isinstance(m, ComputeNode) and m.op != "aten::t":
            j = self.si.fwd_to_bwd[self.si.graph_to_solver[k]]
            n = self.si.nodes[j]
            assert isinstance(n, BwdNode)
            self.bwdargs[k] = {'param': [], 'ip': []}
            storage_list = self._op[k].backward_storage
            if not isinstance(storage_list, list):
                storage_list = [storage_list]
            for storage in storage_list:
                if isinstance(storage, lm_ops.InputStorage):
                    for posi, i in enumerate(storage.ids):
                        idx = m.args[i].index
                        if (((m.op == "aten::_convolution" and not m.is_depthwise)
                             or m.op == "aten::addmm") and n.bwd_op == "ip_grad"):
                            self.bwdargs[k]['param'].append((idx, True, False))
                            if posi == 0:
                                # Input tensor for conv/addmm ip grad need not be stored
                                self.bwdargs[k]['ip'].append((idx, False, False))
                            else:
                                self.bwdargs[k]['ip'].append((idx, True, False))
                        else:
                            self.bwdargs[k]['ip'].append((idx, True, False))
                elif isinstance(storage, lm_ops.OutputStorage):
                    self.bwdargs[k]['ip'].append((k, True, False))
                elif isinstance(storage, lm_ops.IntermediateStorage):
                    self.bwdargs[k]['ip'].append((k, True, True))

        # Create backward schedule
        for t in range(T):
            if isinstance(m, ComputeNode) and m.op != "aten::t":
                j = self.si.fwd_to_bwd[self.si.graph_to_solver[k]]
                n = self.si.nodes[j]
                assert isinstance(n, BwdNode)
                s = solution.s[t + 1][j] if t < T - 1 else False
                r = solution.r[t][j]
                f = solution.f[t][j]
                if (((m.op == "aten::_convolution" and not m.is_depthwise)
                     or m.op == "aten::addmm") and n.bwd_op == "ip_grad"):
                    s1 = solution.s[t + 1][j - 1] if t < T - 1 else False
                    if solution.r[t][j - 1] or len(solution.f[t][j - 1]) or s1:
                        self._bwd_schedule[t].append(
                            (k, ScheduleType(solution.r[t][j - 1], s1,
                                             solution.f[t][j - 1], False), "param"))
                if r or len(f) or s:
                    self._bwd_schedule[t].append((k, ScheduleType(r, s, f, False), "ip"))
            elif isinstance(m, ComputeNode) and m.op == "aten::t":
                pass
            else:
                self._bwd_schedule[t].append((k, None, "grad"))

    self.opshapes = defaultdict()
    for k in self._outputs:
        self.opshapes[k] = [self.bs if dim == -1 else dim for dim in self._nodes[k].shape]
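# Hedged sketch (assumption, not from the original code): for the Checkmate-style schedule
# built above, each forward entry is (node_index, ScheduleType, op_name) and ScheduleType is
# assumed to be namedtuple-like, with its first two fields carrying the solver's "recompute at
# this step" (solution.r) and "keep stored for the next step" (solution.s) decisions. Under
# that assumption, the per-timestep recompute/store counts could be summarized as:
def count_fwd_decisions(fwd_schedule, t):
    recomputed, stored = 0, 0
    for node_idx, stype, op_name in fwd_schedule[t]:
        if stype is None:  # parameter placeholder entries carry no schedule
            continue
        recomputed += bool(stype[0])
        stored += bool(stype[1])
    return recomputed, stored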
def extract_deps(self, graph):
    # Create solver backward nodes
    for i in sorted(self.nodes, reverse=True):
        if i > self.loss:
            continue
        if self.nodes[i].gnode.has_backward:
            num_bwd = 1
            if self.nodes[i].op == "aten::addmm" or self.nodes[i].op == "aten::_convolution":
                # First param_grad, then ip_grad; fwd_to_bwd points to ip_grad
                nbwd1 = BwdNode(self.nodes[i].gnode, i, num_bwd, "param_grad")
                nbwd2 = BwdNode(self.nodes[i].gnode, i, num_bwd, "ip_grad")
                self.bwd_to_fwd[len(self.nodes)] = i
                self.nodes[len(self.nodes)] = nbwd1
                self.fwd_to_bwd[i] = len(self.nodes)
                self.bwd_to_fwd[len(self.nodes)] = i
                self.nodes[len(self.nodes)] = nbwd2
            else:
                nbwd = BwdNode(self.nodes[i].gnode, i, num_bwd)
                if self.nodes[i].has_intermediates:
                    nbwd.has_intermediates = True
                self.fwd_to_bwd[i] = len(self.nodes)
                self.bwd_to_fwd[len(self.nodes)] = i
                self.nodes[len(self.nodes)] = nbwd
    self.fwd_to_bwd[self.loss] = self.loss

    # get deps for backward
    for i in self.nodes:
        if isinstance(self.nodes[i], BwdNode):
            continue
        if self.nodes[i].gnode.has_backward:
            for ni in self.nodes[i].inbound_nodes:
                num_bwd_prev = 1
                for p in range(num_bwd_prev):
                    self.nodes[self.fwd_to_bwd[ni]].dep_list_bwd[p].append(
                        self.fwd_to_bwd[i])
                    if self.nodes[ni].op == "aten::addmm" or self.nodes[ni].op == "aten::_convolution":
                        # Add bwd deps to param_grad too
                        self.nodes[self.fwd_to_bwd[ni] - 1].dep_list_bwd[p].append(
                            self.fwd_to_bwd[i])
            ops_all = lm_ops.list_ops(self.mode, self.nodes[i].op)
            if self.nodes[i].has_intermediates:
                ops = [ops_all[-1]]
            else:
                ops = [ops_all[0]]
            num_bwd = len(ops)
            if self.nodes[i].op == "aten::addmm" or self.nodes[i].op == "aten::_convolution":
                assert len(self.nodes[i].inbound_nodes) <= 1
                if len(self.nodes[i].inbound_nodes) == 1:
                    for p in range(num_bwd):
                        self.nodes[self.fwd_to_bwd[i] - 1].dep_list_fwd[p].append(
                            self.nodes[i].inbound_nodes[0])
            else:
                for p in range(num_bwd):
                    deps = ops[p]().backward_storage
                    l = []
                    if not isinstance(deps, list):
                        deps = [deps]
                    for dep in deps:
                        if isinstance(dep, lm_ops.InputStorage):
                            for inids in dep.ids:
                                arg_in = self.nodes[i].gnode.args[inids]
                                if isinstance(arg_in, ComputeNode.D):
                                    innode = graph.nodes[arg_in.index]
                                    if isinstance(innode, ComputeNode):
                                        l.append(self.graph_to_solver[arg_in.index])
                        if isinstance(dep, lm_ops.OutputStorage):
                            l.append(i)
                        if isinstance(dep, lm_ops.IntermediateStorage):
                            assert self.nodes[i].has_intermediates
                            # No fwd dependency
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p] = l
            if i == self.loss - 1:
                # inject loss node assuming we are at output node
                for p in range(num_bwd):
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p].append(self.loss)
                    if self.nodes[i].op == "aten::addmm" or self.nodes[i].op == "aten::_convolution":
                        self.nodes[self.fwd_to_bwd[i] - 1].dep_list_fwd[p].append(self.loss)

    for i in range(self.loss + 1):
        self.nodes[i].make_args()
    for i in range(self.loss + 1, len(self.nodes)):
        self.nodes[i].make_args()
        fwd_node = self.nodes[i].fwd_node
        stored = []
        output_shapes = []  # Need to have a list because of nodes like add and cat
        # NOTE for both aten::add and cat, we consider output of bwd as both inputs of fwd
        for (dep, rgrad) in fwd_node.dependencies:
            nin = graph.nodes[dep]
            if isinstance(nin, Param) and rgrad == True:
                stored.append(nin)  # Params have gradients which will be stored in backward
            elif isinstance(nin, Param):
                pass
            elif isinstance(nin, ComputeNode) and rgrad == True:
                output_shapes.append(list(nin.shape))
            elif isinstance(nin, Input):
                if rgrad:
                    output_shapes.append(list(nin.shape))
            else:
                sys.exit("Unknown node encountered ")
        if (fwd_node.op == "aten::_convolution"
                or fwd_node.op == "aten::addmm") and self.nodes[i].bwd_op == "ip_grad":
            stored = []
        if (fwd_node.op == "aten::_convolution"
                or fwd_node.op == "aten::addmm") and self.nodes[i].bwd_op == "param_grad":
            output_shapes = []
        self.nodes[i].output_shapes = output_shapes
        self.nodes[i].stored = stored

    # Create edge_list
    self.edge_list = []
    for v in self.nodes:
        for k, vdeps in enumerate(self.nodes[v].args):
            for u in vdeps:
                edge = (u, v, k)
                self.edge_list.append(edge)
def get_workspace_cost(self, graph, *args):
    from pathlib import Path
    import pickle

    workspace_mem = defaultdict(list)
    workspace_compute = defaultdict(list)
    recompute_workspace_mem = defaultdict(list)
    recompute_workspace_compute = defaultdict(list)
    inplace_workspace_mem = defaultdict(list)
    inplace_workspace_compute = defaultdict(list)

    cost = Path("../data/cost_" + self.data_path + "_pipeline.pkl")
    if cost.is_file():
        with open("../data/cost_" + self.data_path + "_pipeline.pkl", 'rb') as pkl_cost:
            (workspace_mem, workspace_compute,
             recompute_workspace_mem, recompute_workspace_compute,
             inplace_workspace_mem, inplace_workspace_compute) = pickle.load(pkl_cost)
        assert len(workspace_compute) == len(self.nodes)
        assert len(workspace_mem) == len(self.nodes)
        for i in self.nodes:
            self.nodes[i].workspace_mem = workspace_mem[i]
            self.nodes[i].workspace_compute = workspace_compute[i]
            self.nodes[i].recompute_workspace_mem = recompute_workspace_mem[i]
            self.nodes[i].recompute_workspace_compute = recompute_workspace_compute[i]
            self.nodes[i].inplace_workspace_mem = inplace_workspace_mem[i]
            self.nodes[i].inplace_workspace_compute = inplace_workspace_compute[i]
    # Recomputation memory
    else:
        for i in sorted(self.nodes.keys(), reverse=True):
            n = self.nodes[i]
            if isinstance(n, BwdNode):
                b = n
                n = b.fwd_node
                op_impls_all = lm_ops.list_ops(self.mode, n.op)
                if self.nodes[i].has_intermediates:
                    op_impls = [op_impls_all[-1]]
                else:
                    op_impls = [op_impls_all[0]]
                for op_impl in op_impls:
                    fwd_working_memory, bwd_working_memory, fwd_working_memory_recompute = meminfo(
                        n, op_impl(), graph, self.bs, b.bwd_op,
                        self.select_conv_algo, self.do_inplace, *args)
                    runtime_fwd, runtime_bwd, runtime_fwd_recompute = computeinfo(
                        n, op_impl(), graph, self.bs, b.bwd_op,
                        self.select_conv_algo, self.do_inplace, *args)
                    print(b, fwd_working_memory, bwd_working_memory,
                          fwd_working_memory_recompute, runtime_fwd, runtime_bwd,
                          runtime_fwd_recompute)
                    if self.select_conv_algo and self.bwd_to_fwd[i] in self.conv_list:
                        for cbwd in range(len(bwd_working_memory)):
                            self.nodes[i].workspace_mem.append(bwd_working_memory[cbwd])
                            self.nodes[i].workspace_compute.append(runtime_bwd[cbwd])
                        if not b.bwd_op == "ip_grad":
                            for cfwd in range(len(fwd_working_memory)):
                                self.nodes[self.bwd_to_fwd[i]].workspace_mem.append(
                                    fwd_working_memory[cfwd])
                                self.nodes[self.bwd_to_fwd[i]].workspace_compute.append(
                                    runtime_fwd[cfwd])
                    elif self.do_inplace and self.bwd_to_fwd[i] in self.inplace_list:
                        self.nodes[i].workspace_mem.append(bwd_working_memory[0])
                        self.nodes[i].workspace_compute.append(runtime_bwd[0])
                        self.nodes[self.bwd_to_fwd[i]].workspace_mem.append(
                            fwd_working_memory[0])
                        self.nodes[self.bwd_to_fwd[i]].workspace_compute.append(
                            runtime_fwd[0])
                        self.nodes[self.bwd_to_fwd[i]].inplace_workspace_mem.append(
                            fwd_working_memory[1])
                        self.nodes[self.bwd_to_fwd[i]].inplace_workspace_compute.append(
                            runtime_fwd[1])
                    else:
                        self.nodes[i].workspace_mem.append(bwd_working_memory[0])
                        self.nodes[i].workspace_compute.append(runtime_bwd[0])
                        if not ((n.op == "aten::_convolution" or n.op == "aten::addmm")
                                and b.bwd_op == "ip_grad"):
                            if n.op == "aten::batch_norm":
                                self.nodes[self.bwd_to_fwd[i]].recompute_workspace_mem.append(
                                    fwd_working_memory_recompute[0])
                                self.nodes[self.bwd_to_fwd[i]].recompute_workspace_compute.append(
                                    runtime_fwd_recompute[0])
                            self.nodes[self.bwd_to_fwd[i]].workspace_mem.append(
                                fwd_working_memory[0])
                            self.nodes[self.bwd_to_fwd[i]].workspace_compute.append(
                                runtime_fwd[0])
            else:
                opname = n.op
                if "int::" in opname:
                    assert isinstance(n, IntNode)
                    parent_node = self.nodes[n.solver_parent_id]
                    p = n.op_idx
                    self.nodes[i].workspace_mem.append(parent_node.workspace_mem[p])
                    self.nodes[i].workspace_compute.append(
                        parent_node.workspace_compute[p] - parent_node.workspace_compute[0])
                    if self.do_inplace and n.solver_parent_id in self.inplace_list:
                        self.nodes[i].inplace_workspace_mem.append(
                            parent_node.inplace_workspace_mem[p])
                        self.nodes[i].inplace_workspace_compute.append(
                            parent_node.inplace_workspace_compute[p]
                            - parent_node.inplace_workspace_compute[0])
                elif "loss::" in opname:
                    # For now assuming loss calculation is operation-less
                    self.nodes[i].workspace_compute.append(0)
                    self.nodes[i].workspace_mem.append(0)

        for i in self.nodes:
            if len(self.nodes[i].recompute_workspace_mem) == 0:
                self.nodes[i].recompute_workspace_mem = self.nodes[i].workspace_mem
                self.nodes[i].recompute_workspace_compute = self.nodes[i].workspace_compute
            if len(self.nodes[i].inplace_workspace_mem) == 0:
                self.nodes[i].inplace_workspace_mem = self.nodes[i].workspace_mem
                self.nodes[i].inplace_workspace_compute = self.nodes[i].workspace_compute

        for i in self.nodes:
            workspace_compute[i] = self.nodes[i].workspace_compute
            workspace_mem[i] = self.nodes[i].workspace_mem
            recompute_workspace_compute[i] = self.nodes[i].recompute_workspace_compute
            recompute_workspace_mem[i] = self.nodes[i].recompute_workspace_mem
            inplace_workspace_compute[i] = self.nodes[i].inplace_workspace_compute
            inplace_workspace_mem[i] = self.nodes[i].inplace_workspace_mem

        with open("../data/cost_" + self.data_path + "_pipeline.pkl", 'wb') as cost_file:
            pickle.dump([workspace_mem, workspace_compute,
                         recompute_workspace_mem, recompute_workspace_compute,
                         inplace_workspace_mem, inplace_workspace_compute], cost_file)