Exemplo n.º 1
0
    def _prune_graph(cls, ugraph):
        """Return a pruned deep copy of `ugraph`.

        Only the ops reachable from `ugraph.output_nodes` by following input
        tensors back to the ops that produce them are kept. Every other op
        listed in the copy's `topo_order` is removed from its `ops_info`, and
        `topo_order` is filtered accordingly. The graph passed in is not
        modified.

        Raises KeyError if a traversed op name is missing from `ops_info`.
        """
        # Local import so the block does not depend on a file-level import.
        from collections import deque

        new_ugraph = deepcopy(ugraph)
        # BFS backwards from the outputs to find all ops you need.
        # `ops_in_need` doubles as the "already enqueued" set, so every op is
        # queued and expanded once, and `popleft` keeps each dequeue O(1).
        ops_in_need = set(ugraph.output_nodes)
        queue = deque(ugraph.output_nodes)
        while queue:
            op_name = queue.popleft()
            op_info = new_ugraph.ops_info[op_name]
            for t_info in op_info.input_tensors:
                in_op_name = parse_tensor_name(t_info.name)[0]
                if in_op_name not in ops_in_need:
                    ops_in_need.add(in_op_name)
                    queue.append(in_op_name)

        for op_name in new_ugraph.topo_order:
            if op_name not in ops_in_need:
                # remove ops not needed from ops_info
                new_ugraph.ops_info.pop(op_name)
        new_ugraph.topo_order = [
            op_name for op_name in new_ugraph.topo_order
            if op_name in new_ugraph.ops_info
        ]
        return new_ugraph
Exemplo n.º 2
0
        def visit(node_name):
            """Depth-first visit of `node_name` and the ops feeding it.

            A node found in `visited` but not yet in `perm_visit` is still
            being expanded further up the call stack, which means the graph
            contains a cycle; ValueError is raised in that case.
            """
            if node_name not in perm_visit:
                if node_name in visited:
                    raise ValueError("Input graph is not a DAG")

                # mark as in progress before descending into the producers
                visited.add(node_name)
                current_op = self.ops_info[node_name]

                for in_tensor in current_op.input_tensors:
                    visit(parse_tensor_name(in_tensor.name)[0])

                # all producers are done: finalize the node and put it at the
                # front of the ordering
                perm_visit.add(node_name)
                ops_torder.insert(0, node_name)
Exemplo n.º 3
0
 def transform(self, ugraph):
     """Build a new graph from `ugraph` with the dropout nodes left out.

     Every node whose name matches `TARGET_NODENAME_PATTERN` is skipped. For
     each remaining node, any input tensor produced by a matching node is
     swapped for the tensor that `_find_input` recorded for that name scope.
     The new graph takes `output_nodes` and `_backend` from `ugraph`.
     """
     stripped_graph = uTensorGraph()
     scope_inputs = self._find_input(ugraph)
     pattern = self.TARGET_NODENAME_PATTERN
     kept_ops = {}
     for node_name in ugraph.topo_order:
         if pattern.match(node_name):
             # dropout nodes never make it into the new graph
             continue
         src_op = ugraph.ops_info[node_name]
         # copy the tensors and attributes for the new graph
         inputs = [
             deepcopy(tensor, {'ugraph': stripped_graph})
             for tensor in src_op.input_tensors
         ]
         outputs = [
             deepcopy(tensor, {'ugraph': stripped_graph})
             for tensor in src_op.output_tensors
         ]
         attrs = deepcopy(src_op.op_attr)
         # rewire inputs that come out of a dropout scope
         for idx, tensor in enumerate(inputs):
             producer = parse_tensor_name(tensor.name)[0]
             hit = pattern.match(producer)
             if not hit:
                 continue
             # assumes a dropout has only one input besides keep_prob
             # NOTE(review): the substituted tensor info is the object
             # returned by `_find_input`, not a copy made for the new
             # graph -- confirm this is intended
             inputs[idx] = scope_inputs[hit.group(1)]
         kept_ops[node_name] = OperationInfo(
             name=src_op.name,
             input_tensors=inputs,
             output_tensors=outputs,
             op_type=src_op.op_type,
             backend=src_op.backend,
             op_attr=attrs,
             ugraph=stripped_graph,
         )
     stripped_graph.ops_info = kept_ops
     stripped_graph.output_nodes = ugraph.output_nodes
     stripped_graph._backend = ugraph._backend
     return stripped_graph
  def _find_input(self, ugraph):
    """Map each dropout name scope to the tensor info feeding it.

    Nodes are scanned in `ugraph.topo_order`. For a node whose name matches
    `TARGET_NODENAME_PATTERN`, the first input tensor that is produced by an
    op outside the node's dropout cluster, and whose producing op name does
    not start with 'keep_prob', is stored under the name scope (group 1 of
    the match). A later node of the same scope overwrites the entry.
    """
    clusters = self._find_dropout_clusters(ugraph)
    scope_to_input = {}
    for node_name in ugraph.topo_order:
      hit = self.TARGET_NODENAME_PATTERN.match(node_name)
      if not hit:
        continue
      scope = hit.group(1)
      members = clusters[scope]
      node_op = ugraph.ops_info[node_name]
      for tensor in node_op.input_tensors:
        producer = parse_tensor_name(tensor.name)[0]
        if producer in members or producer.startswith('keep_prob'):
          # produced inside the dropout itself, or the keep_prob input
          continue
        # assumes a dropout has a single outside input, so stop at the first
        scope_to_input[scope] = tensor
        break
    return scope_to_input