示例#1
0
 def __call__(self, match):
     """Build a replacing graph that computes ``relu(max_pool(x))``.

     The pooling attributes ``ksize``, ``strides`` and ``padding`` are read
     from the subject op matched by the pattern's ``max_pool`` op, so the
     replacing ``max_pool`` uses the same settings.

     :param match: the pattern match; ``match.pattern_ugraph`` is expected to
       contain ops named ``relu`` and ``max_pool``
     :returns: ``(ugraph, input_map, output_map)`` -- the replacing graph, a
       dict mapping the input of the pattern's ``relu`` to the input of the
       replacing ``max_pool``, and a dict mapping the output of the pattern's
       ``max_pool`` to the output of the replacing ``relu``
     """
     subj_op_name = match.patrn2subj_op_map['max_pool'].name
     subj_op = match.subject_ugraph[subj_op_name]
     # slice copies of the attribute lists so the new graph gets its own lists
     pool_kwargs = dict(
         ksize=subj_op.op_attr['ksize'].value.ints_value[:],
         strides=subj_op.op_attr['strides'].value.ints_value[:],
         padding=subj_op.op_attr['padding'].value,
     )

     tf_graph = tf.Graph()
     with tf_graph.as_default():
         # the placeholder is a dummy: max_pool's input is replaced with a
         # null tensor right after parsing
         placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=[None, 128, 128, 3])
         pooled = tf.nn.max_pool(placeholder, name='max_pool', **pool_kwargs)
         tf.nn.relu(pooled, name='relu')

     repl_ugraph = GraphDefParser(config={}).parse(tf_graph.as_graph_def(),
                                                   output_nodes=['relu'])
     repl_ugraph['max_pool'].replace_with_null_input_tensor(0)
     repl_ugraph = prune_graph(repl_ugraph)
     topologic_order_graph(repl_ugraph)

     input_map = {
         match.pattern_ugraph['relu'].input_tensors[0]:
         repl_ugraph['max_pool'].input_tensors[0],
     }
     output_map = {
         match.pattern_ugraph['max_pool'].output_tensors[0]:
         repl_ugraph['relu'].output_tensors[0],
     }
     return repl_ugraph, input_map, output_map
示例#2
0
 def pattern_ugraph(self):
     """Return the quantized ``conv2d -> max_pool`` pattern graph.

     A float conv2d/max_pool graph is built and quantized with
     ``QuantizeTransformer``. A deep copy of the result is then cut down:
     every input of ``conv/eightbit`` becomes a null tensor,
     ``maxpool/eightbit`` is made the only output node, and the graph is
     pruned and topologically ordered.

     :rtype: the pattern ``uTensorGraph``
     """
     tf_graph = tf.Graph()
     with tf_graph.as_default():
         in_placeholder = tf.placeholder(dtype=tf.float32,
                                         shape=[None, 128, 128, 3],
                                         name='dummy_input')
         weight = tf.zeros([32, 32, 3, 10],
                           dtype=tf.float32,
                           name='dummy_weight')
         conv_out = tf.nn.conv2d(in_placeholder, weight,
                                 strides=[1, 2, 2, 1],
                                 padding='VALID',
                                 name='conv')
         pool_out = tf.nn.max_pool(conv_out,
                                   ksize=[1, 2, 2, 1],
                                   strides=[1, 2, 2, 1],
                                   padding='VALID',
                                   name='maxpool')
     float_ugraph = GraphDefParser(config={}).parse(
         tf_graph.as_graph_def(), output_nodes=[pool_out.op.name])
     # deep-copy the quantized graph before editing it
     patrn_ugraph = deepcopy(QuantizeTransformer().transform(float_ugraph))

     conv_op = patrn_ugraph['conv/eightbit']
     # every input of the quantized conv becomes an open (null) input
     for input_idx, _tensor in enumerate(conv_op.input_tensors):
         conv_op.replace_with_null_input_tensor(input_idx)

     patrn_ugraph.output_nodes = ['maxpool/eightbit']
     patrn_ugraph = prune_graph(patrn_ugraph)
     topologic_order_graph(patrn_ugraph)
     return patrn_ugraph
示例#3
0
def fully_connect_pattern1():
    """Build the ``matmul -> relu`` pattern graph.

    The graph holds a MatMul op named ``a_prime`` followed by a ReLU op named
    ``r_prime``. Both MatMul inputs are replaced with null tensors, and the
    graph is then pruned and topologically ordered.

    :rtype: the pattern ``uTensorGraph``
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        z_in = tf.placeholder(name='z_prime', dtype=tf.float32)
        # its value should not matter: this input is nulled out below
        w_in = tf.constant(np.random.rand(3, 3), name='w_prime', dtype=tf.float32)
        matmul_out = tf.matmul(z_in, w_in, name='a_prime')
        relu_out = tf.nn.relu(matmul_out, name='r_prime')
    graph_def = tf_graph.as_graph_def()
    patrn_ugraph = GraphDefParser(config={}).parse(
        graph_def, output_nodes=[relu_out.op.name])
    matmul_op = patrn_ugraph.ops_info['a_prime']
    for input_idx in (0, 1):
        matmul_op.replace_with_null_input_tensor(input_idx)
    patrn_ugraph = prune_graph(patrn_ugraph)
    topologic_order_graph(patrn_ugraph)
    return patrn_ugraph
示例#4
0
 def _handle_match_tf(self, match):
     """Remove a matched TensorFlow dropout subgraph from the subject graph.

     Every consumer of the dropout output (the first output of the subject op
     matched by ``dropout/mul``) is rewired to read the tensor that fed the
     dropout (the first input of the op matched by ``dropout/truediv``).
     Graph output nodes that named the ``dropout/mul`` op are redirected to
     the producer of that tensor. The graph is then pruned and topologically
     ordered.

     :param match: a match of the dropout pattern against
       ``match.subject_ugraph``
     :returns: the pruned subject graph, which is also stored back on
       ``match.subject_ugraph``
     """
     subj_ugraph = match.subject_ugraph
     dropout_in = match.patrn2subj_op_map['dropout/truediv'].input_tensors[0]
     producer = dropout_in.op
     # Use the producer output whose name matches the dropout input, i.e. the
     # exact output slot feeding the dropout, instead of always its first
     # output (wrong for multi-output producers). Fall back to the first
     # output if no name matches.
     subj_in_tensor = next(
         (tensor for tensor in producer.output_tensors
          if tensor.name == dropout_in.name),
         producer.output_tensors[0],
     )
     subj_out_op = match.patrn2subj_op_map['dropout/mul']
     subj_out_tensor = subj_out_op.output_tensors[0]
     # bypass the dropout: consumers of its output now read its input
     for consumer in subj_out_op.output_nodes:
         for idx, tensor in enumerate(consumer.input_tensors):
             if tensor.name == subj_out_tensor.name:
                 consumer.input_tensors[idx] = subj_in_tensor
     # graph outputs that pointed at the dropout now point at its input's op
     for idx, op_name in enumerate(subj_ugraph.output_nodes):
         if op_name == subj_out_op.name:
             subj_ugraph.output_nodes[idx] = subj_in_tensor.op_name
     match.subject_ugraph = prune_graph(subj_ugraph)
     topologic_order_graph(match.subject_ugraph)
     return match.subject_ugraph
示例#5
0
 def pattern_ugraph(self):
     """Return the ``relu -> max_pool`` pattern graph.

     The graph holds a ReLU op named ``relu`` feeding ``tf.nn.max_pool``
     (``ksize``/``strides`` ``[1, 2, 2, 1]``, ``SAME`` padding) named
     ``max_pool``. The ReLU's input is replaced with a null tensor, and the
     graph is pruned and topologically ordered.
     """
     tf_graph = tf.Graph()
     with tf_graph.as_default():
         placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=[None, 128, 128, 3])
         relu_out = tf.nn.relu(placeholder, name='relu')
         tf.nn.max_pool(relu_out,
                        ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1],
                        padding='SAME',
                        name='max_pool')
     parsed = GraphDefParser(config={}).parse(tf_graph.as_graph_def(),
                                              output_nodes=['max_pool'])
     parsed['relu'].replace_with_null_input_tensor(0)
     parsed = prune_graph(parsed)
     topologic_order_graph(parsed)
     return parsed
示例#6
0
 def pattern_ugraph(self):
     """Return the pattern graph of ``tf.nn.dropout`` with a placeholder rate.

     After parsing, three inputs are replaced with null tensors: the
     ``dummy_x`` input of ``dropout/truediv``, the ``dummy_rate`` input of
     ``dropout/sub``, and the Shape-op input of
     ``dropout/random_uniform/RandomUniform``. The graph is then pruned and
     topologically ordered.
     """
     tf_graph = tf.Graph()
     with tf_graph.as_default():
         x_const = tf.constant(np.random.rand(10, 10),
                               dtype=tf.float32,
                               name='dummy_x')
         rate_in = tf.placeholder(dtype=tf.float32, name='dummy_rate')
         dropout_out = tf.nn.dropout(x_const, rate=rate_in, name='dropout')
     patrn_ugraph = GraphDefParser(config={}).parse(
         tf_graph.as_graph_def(), output_nodes=[dropout_out.op.name])
     # (op name, input index) pairs replaced with null tensors, in order
     null_inputs = [
         ('dropout/truediv', 0),  # dummy_x
         ('dropout/sub', 1),  # dummy_rate
         ('dropout/random_uniform/RandomUniform', 0),  # Shape op
     ]
     for op_name, input_idx in null_inputs:
         patrn_ugraph[op_name].replace_with_null_input_tensor(input_idx)
     patrn_ugraph = prune_graph(patrn_ugraph)
     topologic_order_graph(patrn_ugraph)
     return patrn_ugraph
示例#7
0
 def callback(match):
     """Build a replacing graph made of a single two-input Add op ``fused_node``.

     Both inputs of ``fused_node`` are replaced with null tensors. The
     pattern graph of ``match`` is expected to contain an op ``a_prime`` with
     two inputs and an op ``r_prime``.

     :returns: ``(ugraph, input_map, output_map)`` -- the replacing graph, a
       dict mapping ``a_prime``'s inputs 0 and 1 to ``fused_node``'s inputs 0
       and 1, and a dict mapping ``r_prime``'s first output to
       ``fused_node``'s first output
     """
     tf_graph = tf.Graph()
     with tf_graph.as_default():
         lhs = tf.placeholder(dtype=tf.float32, name='a')
         rhs = tf.placeholder(dtype=tf.float32, name='b')
         fused = tf.add(lhs, rhs, name='fused_node')
     ugraph = GraphDefParser(config={}).parse(tf_graph.as_graph_def(),
                                              output_nodes=[fused.op.name])
     for input_idx in (0, 1):
         ugraph.ops_info['fused_node'].replace_with_null_input_tensor(input_idx)
     topologic_order_graph(ugraph)
     ugraph = prune_graph(ugraph)

     # look ops up only after pruning, on the graph that is returned
     patrn_ops = match.pattern_ugraph.ops_info
     fused_op = ugraph.ops_info['fused_node']
     input_map = {
         patrn_ops['a_prime'].input_tensors[input_idx]: fused_op.input_tensors[input_idx]
         for input_idx in (0, 1)
     }
     output_map = {
         patrn_ops['r_prime'].output_tensors[0]: fused_op.output_tensors[0]
     }
     return ugraph, input_map, output_map
示例#8
0
    def replace_with(self, callback, suffix=None):
        """
        Replace the matched subgraph with the ugraph given by ``callback``, **not** in-place

        :param callback: a callable object which takes a :py:class:`.uTensorGraphMatch` and
          returns three values -- a :py:class:`.uTensorGraph` object to replace the matched
          subgraph with (the ``replacing graph``), an ``input_map`` (dict) which maps input
          tensors in the pattern graph to the input tensors in the replacing graph, and an
          ``output_map`` (dict) which maps the output tensors in the same way
        :type callback: callable

        :param suffix: (optional) the suffix to add to the name of ops/tensors in the replacing
          graph returned by ``callback``. If not given, it will be a random string
        :type suffix: str

        :raises ValueError: if ``self._is_replacible`` rejects the graph and maps returned by
          ``callback``; the message includes the reasons it reports

        :rtype: :py:class:`.uTensorGraph`, a **new** graph with matched subgraph replaced
        """
        # build a matched subgraph and pass it to callback
        # input/output_map (dict):
        #  {
        #     tensor in pattern graph : tensor in replacing graph
        #  }
        replace_ugraph, input_map, output_map = callback(self)
        replaceable, reasons = self._is_replacible(replace_ugraph, input_map,
                                                   output_map)
        if not replaceable:
            raise ValueError(
                'matched subgraph can not be replaced with the ugraph given: {}'
                .format(reasons))
        # apply the suffix to the replacing graph and both maps
        replace_ugraph, input_map, output_map = self.new_ugraph_with_suffix(
            replace_ugraph, input_map, output_map, suffix)
        # operate on a deep copy so the replacement is not done in-place
        new_ugraph = deepcopy(self.subject_ugraph)
        # make replace_ugraph be a subgraph in the new_ugraph
        replace_ugraph.unsafe_merge_into(new_ugraph)
        for tensor in input_map.values():
            tensor.move_into(new_ugraph)
        for tensor in output_map.values():
            tensor.move_into(new_ugraph)
        subj_graph_view = self.subject_graph_view
        # replacing output tensors: ops reading an output of the matched
        # subgraph are rewired to the corresponding replacing tensor
        for out_tensor in subj_graph_view.output_tensors:
            repl_out_tensor = output_map[self.subj2patrn_tensor_map[
                out_tensor.name]]
            out_ops = [
                new_ugraph[op.name] for op in out_tensor.op.output_nodes
            ]
            for op in out_ops:
                for i, tensor in enumerate(op.input_tensors):
                    if tensor.name == out_tensor.name:
                        op.input_tensors[i] = repl_out_tensor
            # graph output nodes naming the replaced op are redirected as well
            for i, node_name in enumerate(new_ugraph.output_nodes):
                if node_name == out_tensor.op.name:
                    new_ugraph.output_nodes[i] = repl_out_tensor.op.name
        # replacing input tensors: each input of the replacing graph's input
        # ops is mapped replacing -> pattern -> subject tensor
        inv_input_map = {v: k for k, v in input_map.items()}
        for op in replace_ugraph.input_ops:
            for i, repl_in_tensor in enumerate(op.input_tensors):
                patrn_in_tensor = inv_input_map[repl_in_tensor]
                subj_in_tensor = self.patrn2subj_tensor_map[
                    patrn_in_tensor.name]
                op.input_tensors[i] = subj_in_tensor
        new_ugraph.ops_info.update(replace_ugraph.ops_info)
        topologic_order_graph(new_ugraph)
        new_ugraph = prune_graph(new_ugraph)
        return new_ugraph