Example #1
    def extract(self, graph: Graph, *args):
        for i, n in enumerate(graph.nodes):
            if isinstance(n, ComputeNode) and n.has_backward:
                self.solver_to_graph.append(i)
                self.graph_to_solver[i] = len(self.solver_to_graph) - 1
                snodeid = len(self.nodes)
                self.nodes[snodeid] = FwdNode(n)
                inbound = []
                for (dep, grad) in n.dependencies:
                    if isinstance(graph.nodes[dep], ComputeNode):
                        assert grad == True
                        inbound.append(self.graph_to_solver[dep])
                self.nodes[snodeid].inbound_nodes = inbound

                if n.op == "aten::max_pool2d":
                    op = lm_ops.list_ops(self.mode, n.op)[-1]
                    storage = op.backward_storage
                    if not isinstance(storage, list):
                        storage = [storage]
                    for storetype in storage:
                        if isinstance(storetype, lm_ops.IntermediateStorage):
                            self.nodes[snodeid].has_intermediates = True
                            self.nodes[
                                snodeid].intermediate_size = storetype.size(
                                    n.shape)
                    ibnode = self.nodes[snodeid].inbound_nodes[0]
                    if "relu" in self.nodes[ibnode].op:
                        op = lm_ops.list_ops(self.mode,
                                             self.nodes[ibnode].op)[-1]
                        storage = op.backward_storage
                        if not isinstance(storage, list):
                            storage = [storage]
                        for storetype in storage:
                            if isinstance(storetype,
                                          lm_ops.IntermediateStorage):
                                self.nodes[ibnode].has_intermediates = True
                                self.nodes[
                                    ibnode].intermediate_size = storetype.size(
                                        self.nodes[ibnode].gnode.shape)

        self.loss = len(self.nodes)
        nloss = ComputeNode(self.nodes[self.loss - 1].gnode.shape, -2,
                            "loss::loss", [], False)
        self.nodes[self.loss] = FwdNode(nloss)
        self.nodes[self.loss].inbound_nodes = [self.loss - 1]
        self.extract_deps(graph)
        self.get_mem()
        self.get_local_memory()
        self.get_fixed_memory(graph)
        self.get_conv_info()
        self.get_inplace_info()
        self.get_workspace_cost(graph, *args)
        self.size = len(self.nodes)
        # Write solver_info
        import pickle
        si_file = open("../data/si_" + self.data_path + "_pipeline.pkl", 'wb')
        pickle.dump(self, si_file)
        si_file.close()
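The index bookkeeping in extract() above relies on a pair of structures: solver_to_graph (a dense list mapping solver indices back to graph indices) and graph_to_solver (a dict for the reverse lookup), built while filtering out nodes that carry no backward pass. A minimal standalone sketch of that mapping, using a hypothetical keep() predicate in place of the ComputeNode/has_backward check:

# Illustrative sketch only: dense <-> sparse index mapping as used in extract().
graph_ops = ["conv", "relu", "param", "pool", "param", "linear"]
keep = lambda op: op != "param"   # hypothetical stand-in for the ComputeNode/has_backward filter

solver_to_graph = []   # solver index -> original graph index
graph_to_solver = {}   # original graph index -> solver index
for gi, op in enumerate(graph_ops):
    if keep(op):
        solver_to_graph.append(gi)
        graph_to_solver[gi] = len(solver_to_graph) - 1

assert all(solver_to_graph[si] == gi for gi, si in graph_to_solver.items())
print(solver_to_graph)   # [0, 1, 3, 5]
print(graph_to_solver)   # {0: 0, 1: 1, 3: 2, 5: 3}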
Example #2
 def get_conv_info(self):
     self.conv_list = defaultdict()
     convop = lm_ops.list_ops(self.mode, 'aten::_convolution')[0]
     self.num_conv_algos = 1
     if self.select_conv_algo:
         self.num_conv_algos = max(convop.n_fwd_algos(),
                                   convop.n_bwd_ip_algos(),
                                   convop.n_bwd_wt_algos())
         # self. = True
         for i in self.nodes:
             if "convolution" in self.nodes[i].op:
                 self.conv_list[i] = len(self.conv_list)
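When per-algorithm selection is enabled, the solver sizes one algorithm dimension to the largest of the forward, backward-input, and backward-weight algorithm counts, and gives every convolution node a dense index in conv_list. A small sketch of that bookkeeping, with a hypothetical stub standing in for the lm_ops convolution implementation:

from collections import defaultdict

class StubConvOp:
    # Hypothetical stand-in for the lm_ops convolution op used above.
    def n_fwd_algos(self): return 8
    def n_bwd_ip_algos(self): return 6
    def n_bwd_wt_algos(self): return 5

node_ops = {0: "aten::_convolution", 1: "aten::relu_", 2: "aten::_convolution"}
convop = StubConvOp()

# One algorithm dimension wide enough for all three algorithm families.
num_conv_algos = max(convop.n_fwd_algos(),
                     convop.n_bwd_ip_algos(),
                     convop.n_bwd_wt_algos())

# Dense index per convolution node, mirroring get_conv_info().
conv_list = defaultdict()
for i, op in node_ops.items():
    if "convolution" in op:
        conv_list[i] = len(conv_list)

print(num_conv_algos)   # 8
print(dict(conv_list))  # {0: 0, 2: 1}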
Example #3
    def extract(self, graph: Graph, *args):
        # Create solver graph of forward and backward pass annotated with memory and compute info
        for i, n in enumerate(graph.nodes):
            # Create solver forward nodes
            if isinstance(n, ComputeNode) and n.has_backward:
                self.solver_to_graph.append(i)
                self.graph_to_solver[i] = len(self.solver_to_graph)-1
                snodeid = len(self.nodes)
                self.nodes[snodeid] = FwdNode(n)
                inbound = []
                for (dep, grad) in n.dependencies:
                    if isinstance(graph.nodes[dep], ComputeNode):
                        assert grad == True
                        inbound.append(self.graph_to_solver[dep])
                self.nodes[snodeid].inbound_nodes = inbound

                for idx, op in enumerate(lm_ops.list_ops(self.mode, n.op)):
                    storage = op.backward_storage
                    if not isinstance(storage, list):
                        storage = [storage]
                    for storetype in storage:
                        if isinstance(storetype, lm_ops.IntermediateStorage):
                            # Create intermediate node
                            newnode_op = "int::" + op.__name__
                            ni = ComputeNode([storetype.size(n.shape)], -1, newnode_op, [], False)
                            self.solver_to_graph.append(-1)
                            sintnodeid = len(self.nodes)
                            self.nodes[sintnodeid] = IntNode(ni, snodeid, idx)
                            inbound = []
                            inbound.append(snodeid)
                            self.nodes[sintnodeid].inbound_nodes = inbound
                            self.newnodes.append(sintnodeid)
                            self.nodes[snodeid].has_intermediates = True
                            self.nodes[snodeid].intermediates.append((idx,sintnodeid))

        self.loss = len(self.nodes)
        nloss = ComputeNode(self.nodes[self.loss-1].gnode.shape, -2, "loss::loss", [], False)
        self.nodes[self.loss] = FwdNode(nloss)
        self.nodes[self.loss].inbound_nodes = [self.loss-1]
        self.extract_deps(graph)
        self.get_mem()
        self.get_local_memory()
        self.get_fixed_memory(graph)
        self.get_conv_info()
        self.get_inplace_info()
        self.get_workspace_cost(graph, *args)
        self.size = len(self.nodes)
        # Write solver_info
        import pickle
        si_file = open("../data/si_" + self.data_path + ".pkl", 'wb')
        pickle.dump(self, si_file)
        si_file.close()
Example #4
 def get_inplace_info(self):
     self.inplace_list = defaultdict()
     if self.do_inplace:
         for v in range(self.loss):
             if isinstance(self.nodes[v], IntNode):
                 continue
             if hasattr(lm_ops.list_ops(self.mode, self.nodes[v].op)[0],'inplace'):
                 for u in self.nodes[v].args[0]:
                     if self.nodes[u].last_used == v:
                         self.inplace_list[v] = u
                         if self.nodes[v].has_intermediates:
                             for (_, intmd) in self.nodes[v].intermediates:
                                 self.inplace_list[intmd] = u
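The in-place decision above hinges on last_used: an operator may overwrite an input tensor only if it is that tensor's final consumer, and any intermediate nodes attached to it inherit the same mapping. A toy sketch of the core eligibility check, with hypothetical node records:

from collections import defaultdict

# Hypothetical records: inputs consumed by each node, and the last consumer of each tensor.
node_inputs = {2: [0], 3: [1, 2]}
last_used = {0: 2, 1: 3, 2: 4}

inplace_list = defaultdict()
for v, ins in node_inputs.items():
    for u in ins:
        if last_used[u] == v:       # v is the final consumer of u
            inplace_list[v] = u     # so v may overwrite u in place

print(dict(inplace_list))  # {2: 0, 3: 1}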
Example #5
 def get_conv_info(self):
     self.conv_list = defaultdict()
     convop = lm_ops.list_ops(self.mode, 'aten::_convolution')[0]
     self.num_conv_algos = 1
     if self.select_conv_algo:
         self.num_conv_algos = max(convop.n_fwd_algos(), convop.n_bwd_ip_algos(), convop.n_bwd_wt_algos())
         # self. = True
         for i in self.nodes:
             if isinstance(self.nodes[i], BwdNode):
                 if "convolution" in self.nodes[i].op and not self.nodes[i].fwd_node.is_depthwise:
                     self.conv_list[i] = len(self.conv_list)
             else:
                 if "convolution" in self.nodes[i].op and not self.nodes[i].gnode.is_depthwise:
                     self.conv_list[i] = len(self.conv_list)
Example #6
    def extract_deps(self, graph):
        # Create solver backward nodes
        for i in sorted(self.nodes, reverse=True):
            if i > self.loss:
                continue
            if self.nodes[i].gnode.has_backward:
                num_bwd = len(lm_ops.list_ops(self.mode, self.nodes[i].op))
                nbwd = BwdNode(self.nodes[i].gnode, i, num_bwd)
                self.fwd_to_bwd[i] = len(self.nodes)
                self.bwd_to_fwd[len(self.nodes)] = i
                self.nodes[len(self.nodes)] = nbwd
        self.fwd_to_bwd[self.loss] = self.loss

        # get deps for backward
        for i in self.nodes:
            if isinstance(self.nodes[i], BwdNode):
                continue
            if self.nodes[i].gnode.has_backward:
                for ni in self.nodes[i].inbound_nodes:
                    num_bwd_prev = len(
                        lm_ops.list_ops(self.mode, self.nodes[ni].op))
                    for p in range(num_bwd_prev):
                        self.nodes[self.fwd_to_bwd[ni]].dep_list_bwd[p].append(
                            self.fwd_to_bwd[i])

                ops = lm_ops.list_ops(self.mode, self.nodes[i].op)
                num_bwd = len(ops)

                for p in range(num_bwd):
                    deps = ops[p]().backward_storage
                    l = []
                    if not isinstance(deps, list):
                        deps = [deps]
                    for dep in deps:
                        if isinstance(dep, lm_ops.InputStorage):
                            for inids in dep.ids:
                                arg_in = self.nodes[i].gnode.args[inids]
                                if isinstance(arg_in, ComputeNode.D):
                                    innode = graph.nodes[arg_in.index]
                                    if isinstance(innode, ComputeNode):
                                        l.append(
                                            self.graph_to_solver[arg_in.index])
                        if isinstance(dep, lm_ops.OutputStorage):
                            l.append(i)
                        if isinstance(dep, lm_ops.IntermediateStorage):
                            added_int = False
                            assert self.nodes[i].has_intermediates
                            for p_option, nint_id in self.nodes[
                                    i].intermediates:
                                nint = self.nodes[nint_id]
                                assert isinstance(nint, IntNode)
                                if p_option == p:
                                    l.append(nint_id)
                                    added_int = True
                            assert added_int == True
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p] = l

            if i == self.loss - 1:  # inject loss node assuming we are at output node
                for p in range(num_bwd):
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p].append(
                        self.loss)

        for i in range(self.loss + 1):
            self.nodes[i].make_args()
        for i in range(self.loss + 1, len(self.nodes)):
            self.nodes[i].make_args()
            fwd_node = self.nodes[i].fwd_node
            stored = []
            output_shapes = []
            # Need to have a list because of nodes like add and cat
            # NOTE for both aten::add and cat, we consider output of bwd as both inputs of fwd
            for (dep, rgrad) in (fwd_node.dependencies):
                nin = graph.nodes[dep]
                if isinstance(nin, Param) and rgrad == True:
                    stored.append(
                        nin
                    )  # Params have gradients which will be stored in backward
                elif isinstance(nin, Param):
                    pass
                elif isinstance(nin, ComputeNode) and rgrad == True:
                    output_shapes.append(list(nin.shape))
                elif isinstance(nin, Input):
                    if rgrad:
                        output_shapes.append(list(nin.shape))
                else:
                    sys.exit("Unknown node encountered ")

            self.nodes[i].output_shapes = output_shapes
            self.nodes[i].stored = stored

        # Create edge_list
        self.edge_list = []
        for v in self.nodes:
            for k, vdeps in enumerate(self.nodes[v].args):
                for u in vdeps:
                    edge = (u, v, k)
                    self.edge_list.append(edge)

        self.last_use_bwd = defaultdict(dict)
        for i in self.nodes:
            if isinstance(self.nodes[i], BwdNode):
                assert len(self.nodes[i].dep_list_fwd) == 1
                for j in self.nodes[i].dep_list_fwd[0]:
                    if isinstance(self.nodes[j], IntNode):
                        pj = self.nodes[j].solver_parent_id
                        self.last_use_bwd[self.solver_to_graph[pj]][
                            "int"] = self.solver_to_graph[self.bwd_to_fwd[i]]
                    else:
                        if j == self.loss:
                            sj = graph._outputs[0]
                        else:
                            sj = self.solver_to_graph[j]
                        b_graphi = self.solver_to_graph[self.bwd_to_fwd[i]]
                        self.last_use_bwd[sj]["ip"] = b_graphi
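extract_deps() pairs every forward node with a backward node appended at the end of the node table, keeping the two mirrored dictionaries fwd_to_bwd and bwd_to_fwd so either direction resolves in constant time. A minimal sketch of that pairing over a hypothetical three-node forward graph:

nodes = {0: "fwd:conv", 1: "fwd:relu", 2: "fwd:linear"}   # hypothetical forward nodes

fwd_to_bwd = {}
bwd_to_fwd = {}
for i in sorted(nodes, reverse=True):     # backward nodes appended in reverse order
    bwd_id = len(nodes)
    fwd_to_bwd[i] = bwd_id
    bwd_to_fwd[bwd_id] = i
    nodes[bwd_id] = "bwd:" + nodes[i].split(":", 1)[1]

print(fwd_to_bwd)   # {2: 3, 1: 4, 0: 5}
print(bwd_to_fwd)   # {3: 2, 4: 1, 5: 0}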
Example #7
    def init_schedule(self, solution: Solution, mode):
        self.solution = solution
        T = len(self.si.nodes) - self.si.loss
        # Create main structures
        self._op = [[] for i in range(T)]  # List of fwd operators
        self._bwd_op = []  # List of bwd operators
        self._fwd_schedule = [[] for i in range(T)]  # Forward schedule

        # Initialize forward pass structures
        for t in range(T):
            for i, n in enumerate(self._nodes):
                if isinstance(n, ComputeNode) and n.op != "aten::t":
                    j = self.si.graph_to_solver[i]
                    p = 0
                    if (self.si.compute_newnode
                            and self.si.nodes[j].has_intermediates):
                        for (p_option,
                             intid) in self.si.nodes[j].intermediates:
                            if solution.r[t][intid]:
                                assert solution.r[t][j]
                                p = p_option
                    op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                    if self.si.select_conv_algo and j in self.si.conv_list:
                        num_fwd_algos = op.n_fwd_algos()
                        for c in range(num_fwd_algos):
                            if solution.rf[t, self.si.conv_list[j], c]:
                                op.algorithm = c
                                break
                    if n.is_depthwise:
                        op.is_depthwise = True
                    if self.si.do_inplace and j in self.si.inplace_list:
                        if solution.ip[t][j]:
                            op.inplace = True
                    self._op[t].append(op)
                    schedule_intermediate = False
                    storage = op.backward_storage
                    if not isinstance(storage, list):
                        storage = [storage]
                    for store in storage:
                        if isinstance(store, lm_ops.IntermediateStorage):
                            schedule_intermediate = True
                    if t < T - 1:
                        s = solution.s[t + 1][j]
                        r = solution.r[t][j]
                        self._fwd_schedule[t].append(
                            ScheduleType(r, s, schedule_intermediate))
                    else:
                        r = solution.r[t][j]
                        self._fwd_schedule[t].append(
                            ScheduleType(r, False, schedule_intermediate))
                else:
                    # Node represents a parameter
                    self._op[t].append(None)
                    self._fwd_schedule[t].append(None)

        # Initialize backward pass structures
        for t in range(T):
            bwd_t = self.si.loss + t
            if t != 0 and isinstance(self.si.nodes[bwd_t], BwdNode):
                n = self.nodes[self.si.solver_to_graph[self.si.bwd_to_fwd[
                    bwd_t]]]  # Backportability to when si didn't support depthwise
                options = len(solution.m[t])
                p = 0
                for o in range(options):
                    if solution.m[t][o] == 1:
                        p = o
                op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                if n.is_depthwise:
                    op.is_depthwise = True
                if (n.op == "aten::_convolution"
                        and not n.is_depthwise) or n.op == "aten::addmm":
                    algo = 0
                    if self.si.nodes[bwd_t].bwd_op == "param_grad":
                        algo_type = 0
                    else:
                        algo_type = 1
                    if self.si.select_conv_algo and bwd_t in self.si.conv_list:
                        algo = -1
                        num_algos = solution.rf.shape[2]
                        for c in range(num_algos):
                            if solution.rf[t, self.si.conv_list[bwd_t], c]:
                                algo = c
                                break
                        if algo == -1:
                            raise RuntimeError("Algorithm not decided", t,
                                               bwd_t)
                    algo = algo_type * 10 + algo
                    op.algorithm = algo
                self._bwd_op.append(op)
            else:
                self._bwd_op.append(None)
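The backward branch above packs two choices into a single op.algorithm value: the tens digit records whether the node computes the parameter gradient or the input gradient (algo_type), and the ones digit is the selected algorithm index. This stays reversible only while there are fewer than ten algorithms per family, which the code assumes. A small sketch of that encode/decode convention:

def encode(algo_type, algo):
    # algo_type: 0 = param_grad, 1 = ip_grad; algo must stay below 10.
    assert 0 <= algo < 10
    return algo_type * 10 + algo

def decode(packed):
    return packed // 10, packed % 10

assert decode(encode(0, 3)) == (0, 3)
assert decode(encode(1, 7)) == (1, 7)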
Example #8
    def init_schedule(self, mode):
        T = len(self.si.nodes) - self.si.loss
        # Create main structures
        self._op = [[] for i in range(T)]  # List of fwd operators
        self._bwd_op = []  # List of bwd operators
        self._fwd_schedule = [[] for i in range(T)]  # Forward schedule

        t = 0

        to_store_graph = [False for i in range(len(self._nodes))]
        for k, n in enumerate(self._nodes):
            if isinstance(n, ComputeNode) and n.op != "aten::t":
                p = 0
                op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                storage_list = op.backward_storage
                if n.op == "aten::gist__convolution" or n.op == "aten::nosave_relu_":
                    continue  # do not store the input nosave_relu because already storing its compressed version
                if not isinstance(storage_list, list):
                    storage_list = [storage_list]
                for storage in storage_list:
                    if isinstance(storage, lm_ops.InputStorage):
                        for i in storage.ids:
                            to_store_graph[n.args[i].index] = True
                    elif isinstance(storage, lm_ops.OutputStorage):
                        to_store_graph[k] = True

        for i, n in enumerate(self._nodes):
            if isinstance(n, ComputeNode) and n.op != "aten::t":
                j = self.si.graph_to_solver[i]
                p = 0
                op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                if "relu_" in n.op or "hardtanh_" in n.op or "add_" in n.op:
                    op.inplace = True
                if n.is_depthwise:
                    op.is_depthwise = True
                self._op[t].append(op)
                schedule_intermediate = False
                storage = op.backward_storage
                if not isinstance(storage, list):
                    storage = [storage]
                for store in storage:
                    if isinstance(store, lm_ops.IntermediateStorage):
                        schedule_intermediate = True
                self._fwd_schedule[0].append(
                    ScheduleType(True, to_store_graph[i],
                                 schedule_intermediate))
            else:
                # Node represents a parameter
                self._op[t].append(None)
                self._fwd_schedule[t].append(None)

        # Initialize backward pass structures
        for t in range(T):
            bwd_t = self.si.loss + t
            if t != 0 and isinstance(self.si.nodes[bwd_t], BwdNode):
                n = self.nodes[self.si.solver_to_graph[self.si.bwd_to_fwd[
                    bwd_t]]]  # Backportability to when si didn't support depthwise
                p = 0
                op = lm_ops.list_ops(self.si.mode, n.op)[p]()
                if n.is_depthwise:
                    op.is_depthwise = True
                self._bwd_op.append(op)
            else:
                self._bwd_op.append(None)
Example #9
    def init_schedule(self, solution: CheckmateSolution, mode):
        T = len(self.si.nodes)
        # Create main structures
        self._op = [None for i in range(self.lennodes)]  # List of operations
        self._fwd_schedule = [[] for i in range(T)]  # Forward schedule
        self._bwd_schedule = [[] for i in range(T)]  # Backward schedule
        self.fwdargs = [None for i in range(self.lennodes)
                        ]  # Index to forward node input tensor
        self.bwdargs = [None for i in range(self.lennodes)
                        ]  # Index to backward node input tensors

        # Initialize forward pass structures
        for t in range(T):
            for i, n in enumerate(self.nodes):
                if isinstance(n, ComputeNode) and n.op != "aten::t":
                    j = self.si.graph_to_solver[i]
                    ops_list = lm_ops.list_ops(self.si.mode, n.op)
                    if isinstance(self.si, PipelinedSolverInfo
                                  ) and self.si.nodes[j].has_intermediates:
                        op = ops_list[-1](
                        )  # Select intermediate-computing and intermediate-activated operator implementation
                    else:
                        op = ops_list[0](
                        )  # Select the default operator implementations
                    if n.is_depthwise:
                        op.is_depthwise = True
                    s = solution.s[t + 1][j] if t < T - 1 else False
                    r = solution.r[t][j]
                    f = solution.f[t][j]
                    schedule_intermediate = False
                    storage = op.backward_storage
                    if not isinstance(storage, list):
                        storage = [storage]
                    for store in storage:
                        if isinstance(store, lm_ops.IntermediateStorage):
                            schedule_intermediate = True
                    if r or len(f) or s:
                        self._fwd_schedule[t].append(
                            (i, ScheduleType(r, s, f,
                                             schedule_intermediate), n.op))
                        self._op[i] = op
                        self.fwdargs[i] = [
                            (a.value,
                             None) if isinstance(a, ComputeNode.V) else
                            (a.index, a.requires_grad) for a in n.args
                        ]
                elif isinstance(n, ComputeNode) and n.op == "aten::t":
                    pass
                else:
                    # Node represents a parameter
                    self._fwd_schedule[t].append((i, None, None))
                    self._op[i] = None

        # Initialize backward pass structures
        for k, m in reversed(list(enumerate(self.nodes))):
            # Create backward dependencies
            if isinstance(m, ComputeNode) and m.op != "aten::t":
                j = self.si.fwd_to_bwd[self.si.graph_to_solver[k]]
                n = self.si.nodes[j]
                assert isinstance(n, BwdNode)
                self.bwdargs[k] = {'param': [], 'ip': []}
                storage_list = self._op[k].backward_storage
                if not isinstance(storage_list, list):
                    storage_list = [storage_list]
                for storage in storage_list:
                    if isinstance(storage, lm_ops.InputStorage):
                        for posi, i in enumerate(storage.ids):
                            idx = m.args[i].index
                            if (((m.op == "aten::_convolution" and
                                  not m.is_depthwise) or m.op == "aten::addmm")
                                    and n.bwd_op == "ip_grad"):
                                self.bwdargs[k]['param'].append(
                                    (idx, True, False))
                                if posi == 0:
                                    self.bwdargs[k]['ip'].append(
                                        (idx, False, False)
                                    )  # Input tensor for conv/addmm ip grad need not be stored
                                else:
                                    self.bwdargs[k]['ip'].append(
                                        (idx, True, False))
                            else:
                                self.bwdargs[k]['ip'].append(
                                    (idx, True, False))
                    elif isinstance(storage, lm_ops.OutputStorage):
                        self.bwdargs[k]['ip'].append((k, True, False))
                    elif isinstance(storage, lm_ops.IntermediateStorage):
                        self.bwdargs[k]['ip'].append((k, True, True))

            # Create backward schedule
            for t in range(T):
                if isinstance(m, ComputeNode) and m.op != "aten::t":
                    j = self.si.fwd_to_bwd[self.si.graph_to_solver[k]]
                    n = self.si.nodes[j]
                    assert isinstance(n, BwdNode)
                    s = solution.s[t + 1][j] if t < T - 1 else False
                    r = solution.r[t][j]
                    f = solution.f[t][j]
                    if (((m.op == "aten::_convolution" and not m.is_depthwise)
                         or m.op == "aten::addmm") and n.bwd_op == "ip_grad"):
                        s1 = solution.s[t + 1][j - 1] if t < T - 1 else False
                        if solution.r[t][j - 1] or len(
                                solution.f[t][j - 1]) or s1:
                            self._bwd_schedule[t].append(
                                (k,
                                 ScheduleType(solution.r[t][j - 1], s1,
                                              solution.f[t][j - 1],
                                              False), "param"))
                    if r or len(f) or s:
                        self._bwd_schedule[t].append(
                            (k, ScheduleType(r, s, f, False), "ip"))
                elif isinstance(m, ComputeNode) and m.op == "aten::t":
                    pass
                else:
                    self._bwd_schedule[t].append((k, None, "grad"))

        self.opshapes = defaultdict()
        for k in self._outputs:
            self.opshapes[k] = [
                self.bs if dim == -1 else dim for dim in self._nodes[k].shape
            ]
Example #10
    def extract_deps(self, graph):
        # create bwd nodes
        for i in sorted(self.nodes, reverse=True):
            if i > self.loss:
                continue
            if self.nodes[i].gnode.has_backward:
                # print(nodes[i].op)
                num_bwd = 1
                if self.nodes[i].op == "aten::addmm" or self.nodes[
                        i].op == "aten::_convolution":
                    # First param_grad, then ip_grad
                    # fwd_to_bwd points to ip_grad
                    nbwd1 = BwdNode(self.nodes[i].gnode, i, num_bwd,
                                    "param_grad")
                    nbwd2 = BwdNode(self.nodes[i].gnode, i, num_bwd, "ip_grad")
                    self.bwd_to_fwd[len(self.nodes)] = i
                    self.nodes[len(self.nodes)] = nbwd1
                    self.fwd_to_bwd[i] = len(self.nodes)
                    self.bwd_to_fwd[len(self.nodes)] = i
                    self.nodes[len(self.nodes)] = nbwd2
                else:
                    nbwd = BwdNode(self.nodes[i].gnode, i, num_bwd)
                    if self.nodes[i].has_intermediates:
                        nbwd.has_intermediates = True
                    self.fwd_to_bwd[i] = len(self.nodes)
                    self.bwd_to_fwd[len(self.nodes)] = i
                    self.nodes[len(self.nodes)] = nbwd

        self.fwd_to_bwd[self.loss] = self.loss

        # get deps for backward
        for i in self.nodes:
            if isinstance(self.nodes[i], BwdNode):
                continue
            if self.nodes[i].gnode.has_backward:
                for ni in self.nodes[i].inbound_nodes:
                    num_bwd_prev = 1
                    for p in range(num_bwd_prev):
                        self.nodes[self.fwd_to_bwd[ni]].dep_list_bwd[p].append(
                            self.fwd_to_bwd[i])
                        if self.nodes[ni].op == "aten::addmm" or self.nodes[
                                ni].op == "aten::_convolution":
                            # Add bwd deps to param_grad too
                            self.nodes[self.fwd_to_bwd[ni] -
                                       1].dep_list_bwd[p].append(
                                           self.fwd_to_bwd[i])

                ops_all = lm_ops.list_ops(self.mode, self.nodes[i].op)
                if self.nodes[i].has_intermediates:
                    ops = [ops_all[-1]]
                else:
                    ops = [ops_all[0]]
                num_bwd = len(ops)

                if self.nodes[i].op == "aten::addmm" or self.nodes[
                        i].op == "aten::_convolution":
                    assert len(self.nodes[i].inbound_nodes) <= 1
                    if len(self.nodes[i].inbound_nodes) == 1:
                        for p in range(num_bwd):
                            self.nodes[self.fwd_to_bwd[i] -
                                       1].dep_list_fwd[p].append(
                                           self.nodes[i].inbound_nodes[0])
                else:
                    for p in range(num_bwd):
                        deps = ops[p]().backward_storage
                        l = []
                        if not isinstance(deps, list):
                            deps = [deps]
                        for dep in deps:
                            if isinstance(dep, lm_ops.InputStorage):
                                for inids in dep.ids:
                                    arg_in = self.nodes[i].gnode.args[inids]
                                    if isinstance(arg_in, ComputeNode.D):
                                        innode = graph.nodes[arg_in.index]
                                        if isinstance(innode, ComputeNode):
                                            l.append(self.graph_to_solver[
                                                arg_in.index])
                            if isinstance(dep, lm_ops.OutputStorage):
                                l.append(i)
                            if isinstance(dep, lm_ops.IntermediateStorage):
                                added_int = False
                                assert self.nodes[i].has_intermediates
                                # No fwd dependency
                                # for p_option, nint_id in self.nodes[i].intermediates:
                                #     nint = self.nodes[nint_id]
                                #     assert isinstance(nint, IntNode)
                                #     if p_option == p:
                                #         l.append(nint_id)
                                #         added_int = True
                                # assert added_int == True
                        self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p] = l

            if i == self.loss - 1:  # inject loss node assuming we are at output node
                for p in range(num_bwd):
                    self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p].append(
                        self.loss)
                    if self.nodes[i].op == "aten::addmm" or self.nodes[
                            i].op == "aten::_convolution":
                        self.nodes[self.fwd_to_bwd[i] -
                                   1].dep_list_fwd[p].append(self.loss)

        for i in range(self.loss + 1):
            self.nodes[i].make_args()
        for i in range(self.loss + 1, len(self.nodes)):
            self.nodes[i].make_args()
            fwd_node = self.nodes[i].fwd_node
            stored = []
            output_shapes = []
            # Need to have a list because of nodes like add and cat
            # NOTE for both aten::add and cat, we consider output of bwd as both inputs of fwd
            for (dep, rgrad) in (fwd_node.dependencies):
                nin = graph.nodes[dep]
                if isinstance(nin, Param) and rgrad == True:
                    stored.append(
                        nin
                    )  # Params have gradients which will be stored in backward
                elif isinstance(nin, Param):
                    pass
                elif isinstance(nin, ComputeNode) and rgrad == True:
                    output_shapes.append(list(nin.shape))
                elif isinstance(nin, Input):
                    if rgrad:
                        output_shapes.append(list(nin.shape))
                else:
                    sys.exit("Unknown node encountered ")

            if (fwd_node.op == "aten::_convolution" or fwd_node.op
                    == "aten::addmm") and self.nodes[i].bwd_op == "ip_grad":
                stored = []
            if (fwd_node.op == "aten::_convolution" or fwd_node.op
                    == "aten::addmm") and self.nodes[i].bwd_op == "param_grad":
                output_shapes = []
            self.nodes[i].output_shapes = output_shapes
            self.nodes[i].stored = stored

        # Create edge_list
        self.edge_list = []
        for v in self.nodes:
            for k, vdeps in enumerate(self.nodes[v].args):
                for u in vdeps:
                    edge = (u, v, k)
                    self.edge_list.append(edge)
Example #11
    def get_workspace_cost(self, graph, *args):
        from pathlib import Path
        import pickle
        workspace_mem = defaultdict(list)
        workspace_compute = defaultdict(list)
        recompute_workspace_mem = defaultdict(list)
        recompute_workspace_compute = defaultdict(list)
        inplace_workspace_mem = defaultdict(list)
        inplace_workspace_compute = defaultdict(list)
        cost = Path("../data/cost_" + self.data_path + "_pipeline.pkl")
        if cost.is_file():
            pkl_cost = open("../data/cost_" + self.data_path + "_pipeline.pkl",
                            'rb')
            workspace_mem, workspace_compute, recompute_workspace_mem, recompute_workspace_compute, inplace_workspace_mem, inplace_workspace_compute = pickle.load(
                pkl_cost)
            pkl_cost.close()
            assert len(workspace_compute) == len(self.nodes)
            assert len(workspace_mem) == len(self.nodes)
            for i in self.nodes:
                self.nodes[i].workspace_mem = workspace_mem[i]
                self.nodes[i].workspace_compute = workspace_compute[i]
                self.nodes[
                    i].recompute_workspace_mem = recompute_workspace_mem[i]
                self.nodes[
                    i].recompute_workspace_compute = recompute_workspace_compute[
                        i]
                self.nodes[i].inplace_workspace_mem = inplace_workspace_mem[i]
                self.nodes[
                    i].inplace_workspace_compute = inplace_workspace_compute[i]
        # Recomputation memory
        else:
            for i in sorted(self.nodes.keys(), reverse=True):
                n = self.nodes[i]
                if isinstance(n, BwdNode):
                    b = n
                    n = b.fwd_node
                    op_impls_all = lm_ops.list_ops(self.mode, n.op)
                    if self.nodes[i].has_intermediates:
                        op_impls = [op_impls_all[-1]]
                    else:
                        op_impls = [op_impls_all[0]]

                    for op_impl in op_impls:
                        fwd_working_memory, bwd_working_memory, fwd_working_memory_recompute = meminfo(
                            n, op_impl(), graph, self.bs, b.bwd_op,
                            self.select_conv_algo, self.do_inplace, *args)
                        runtime_fwd, runtime_bwd, runtime_fwd_recompute = computeinfo(
                            n, op_impl(), graph, self.bs, b.bwd_op,
                            self.select_conv_algo, self.do_inplace, *args)
                        print(b, fwd_working_memory, bwd_working_memory,
                              fwd_working_memory_recompute, runtime_fwd,
                              runtime_bwd, runtime_fwd_recompute)
                        if self.select_conv_algo and self.bwd_to_fwd[
                                i] in self.conv_list:
                            for cbwd in range(len(bwd_working_memory)):
                                self.nodes[i].workspace_mem.append(
                                    bwd_working_memory[cbwd])
                                self.nodes[i].workspace_compute.append(
                                    runtime_bwd[cbwd])
                            if not b.bwd_op == "ip_grad":
                                for cfwd in range(len(fwd_working_memory)):
                                    self.nodes[self.bwd_to_fwd[
                                        i]].workspace_mem.append(
                                            fwd_working_memory[cfwd])
                                    self.nodes[self.bwd_to_fwd[
                                        i]].workspace_compute.append(
                                            runtime_fwd[cfwd])
                        elif self.do_inplace and self.bwd_to_fwd[
                                i] in self.inplace_list:
                            self.nodes[i].workspace_mem.append(
                                bwd_working_memory[0])
                            self.nodes[i].workspace_compute.append(
                                runtime_bwd[0])
                            self.nodes[
                                self.bwd_to_fwd[i]].workspace_mem.append(
                                    fwd_working_memory[0])
                            self.nodes[
                                self.bwd_to_fwd[i]].workspace_compute.append(
                                    runtime_fwd[0])
                            self.nodes[self.bwd_to_fwd[
                                i]].inplace_workspace_mem.append(
                                    fwd_working_memory[1])
                            self.nodes[self.bwd_to_fwd[
                                i]].inplace_workspace_compute.append(
                                    runtime_fwd[1])
                        else:
                            self.nodes[i].workspace_mem.append(
                                bwd_working_memory[0])
                            self.nodes[i].workspace_compute.append(
                                runtime_bwd[0])
                            if not ((n.op == "aten::_convolution"
                                     or n.op == "aten::addmm")
                                    and b.bwd_op == "ip_grad"):
                                if n.op == "aten::batch_norm":
                                    self.nodes[self.bwd_to_fwd[
                                        i]].recompute_workspace_mem.append(
                                            fwd_working_memory_recompute[0])
                                    self.nodes[self.bwd_to_fwd[
                                        i]].recompute_workspace_compute.append(
                                            runtime_fwd_recompute[0])
                                self.nodes[
                                    self.bwd_to_fwd[i]].workspace_mem.append(
                                        fwd_working_memory[0])
                                self.nodes[self.bwd_to_fwd[
                                    i]].workspace_compute.append(
                                        runtime_fwd[0])
                else:
                    opname = n.op
                    if "int::" in opname:
                        assert isinstance(n, IntNode)
                        parent_node = self.nodes[n.solver_parent_id]
                        p = n.op_idx
                        self.nodes[i].workspace_mem.append(
                            parent_node.workspace_mem[p])
                        self.nodes[i].workspace_compute.append(
                            parent_node.workspace_compute[p] -
                            parent_node.workspace_compute[0])
                        if self.do_inplace and n.solver_parent_id in self.inplace_list:
                            self.nodes[i].inplace_workspace_mem.append(
                                parent_node.inplace_workspace_mem[p])
                            self.nodes[i].inplace_workspace_compute.append(
                                parent_node.inplace_workspace_compute[p] -
                                parent_node.inplace_workspace_compute[0])
                    elif "loss::" in opname:  # For now assuming loss calculation is operation-less
                        self.nodes[i].workspace_compute.append(0)
                        self.nodes[i].workspace_mem.append(0)
            for i in self.nodes:
                if len(self.nodes[i].recompute_workspace_mem) == 0:
                    self.nodes[i].recompute_workspace_mem = self.nodes[
                        i].workspace_mem
                    self.nodes[i].recompute_workspace_compute = self.nodes[
                        i].workspace_compute
                if len(self.nodes[i].inplace_workspace_mem) == 0:
                    self.nodes[i].inplace_workspace_mem = self.nodes[
                        i].workspace_mem
                    self.nodes[i].inplace_workspace_compute = self.nodes[
                        i].workspace_compute
            for i in self.nodes:
                workspace_compute[i] = self.nodes[i].workspace_compute
                workspace_mem[i] = self.nodes[i].workspace_mem
                recompute_workspace_compute[i] = self.nodes[
                    i].recompute_workspace_compute
                recompute_workspace_mem[i] = self.nodes[
                    i].recompute_workspace_mem
                inplace_workspace_compute[i] = self.nodes[
                    i].inplace_workspace_compute
                inplace_workspace_mem[i] = self.nodes[i].inplace_workspace_mem
            cost_file = open(
                "../data/cost_" + self.data_path + "_pipeline.pkl", 'wb')
            pickle.dump([
                workspace_mem, workspace_compute, recompute_workspace_mem,
                recompute_workspace_compute, inplace_workspace_mem,
                inplace_workspace_compute
            ], cost_file)
            cost_file.close()
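Taken as a whole, get_workspace_cost() is a compute-once cache: if the cost pickle already exists it is loaded and copied onto the nodes, otherwise the workspace costs are profiled and written back out for the next run. A stripped-down sketch of that pattern; the file name and the measure() callable are hypothetical:

import pickle
from pathlib import Path

def load_or_measure(cache_path, measure):
    # Return cached costs if the pickle exists, else profile and cache them.
    cache = Path(cache_path)
    if cache.is_file():
        with cache.open("rb") as f:
            return pickle.load(f)
    costs = measure()                       # expensive profiling pass
    with cache.open("wb") as f:
        pickle.dump(costs, f)
    return costs

# Usage with a hypothetical measurement function:
costs = load_or_measure("../data/cost_example.pkl", lambda: {"node0": 1.5})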