コード例 #1
0
def test_id_match(patrn_ugraph):
    """Matching a pattern graph against itself yields an identity-like mapping.

    The two pattern inputs may be bound in either order, but must be bound to
    two distinct ops; ``add0`` and ``output`` must map onto the op of the same
    name. Both the forward (pattern -> subject) and the backward
    (subject -> pattern) op maps are checked, then the tensor maps are checked
    for the graph's input and output tensors.
    """
    matcher = uTensorGraphMatcher(
        patrn_ugraph, op_equality_delegate=uTensorOpEqualityDelegate)
    matches = matcher.match(patrn_ugraph)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))
    first_match = matches[0]

    input_names = ['input0', 'input1']
    for op_map in (first_match.patrn2subj_op_map,
                   first_match.subj2patrn_op_map):
        # The inputs may be swapped, but must never collapse onto one op.
        assert op_map['input0'].name in input_names
        assert op_map['input1'].name in input_names
        assert op_map['input0'].name != op_map['input1'].name
        # The remaining ops have exactly one possible counterpart.
        assert op_map['add0'].name == 'add0'
        assert op_map['output'].name == 'output'

    for in_tensor in patrn_ugraph.input_tensors:
        assert in_tensor.name in first_match.patrn2subj_tensor_map, \
            '{} is missing'.format(in_tensor.name)
    for out_tensor in patrn_ugraph.output_tensors:
        assert out_tensor.name in first_match.subj2patrn_tensor_map, \
            '{} is missing'.format(out_tensor.name)
コード例 #2
0
def test_replace_fc_with_add(subj_graph_1, patrn_fc_1):
    """Replacing a matched FC pattern must remove every matched subject op.

    The replacement subgraph is a single ``tf.add`` op named ``fused_node``.
    Its two inputs are replaced with null input tensors so that they can be
    wired to the pattern's ``a_prime`` inputs via ``input_map``. Its output is
    wired to the pattern's ``r_prime`` output via ``output_map``.
    """
    def callback(match):
        # Build the replacement graph: fused_node = a + b
        graph = tf.Graph()
        with graph.as_default():
            a = tf.placeholder(dtype=tf.float32, name='a')
            b = tf.placeholder(dtype=tf.float32, name='b')
            out = tf.add(a, b, name='fused_node')
        ugraph = GraphDefParser(config={}).parse(
            graph.as_graph_def(), output_nodes=[out.op.name])
        # Detach both placeholder inputs from fused_node; presumably
        # prune_graph then drops the now-unconsumed placeholders (verify).
        ugraph.ops_info['fused_node'].replace_with_null_input_tensor(0)
        ugraph.ops_info['fused_node'].replace_with_null_input_tensor(1)
        topologic_order_graph(ugraph)
        ugraph = prune_graph(ugraph)
        patrn_ugraph = match.pattern_ugraph

        # Look the op up again on the pruned graph returned by prune_graph.
        fused_op = ugraph.ops_info['fused_node']
        pattern_in_op = patrn_ugraph.ops_info['a_prime']
        input_map = {
            pattern_in_op.input_tensors[0]: fused_op.input_tensors[0],
            pattern_in_op.input_tensors[1]: fused_op.input_tensors[1],
        }
        output_map = {
            patrn_ugraph.ops_info['r_prime'].output_tensors[0]:
                fused_op.output_tensors[0],
        }
        return ugraph, input_map, output_map

    matcher = uTensorGraphMatcher(
        patrn_fc_1, op_equality_delegate=uTensorOpEqualityDelegate)
    matches = matcher.match(subj_graph_1)
    assert matches, 'no match found'
    match = matches[0]
    new_ugraph = match.replace_with(callback)

    # Every subject op that took part in the match must be gone.
    leftover_op_names = [
        op_name for op_name in match.subj2patrn_op_map
        if op_name in new_ugraph.ops_info
    ]
    assert not leftover_op_names, \
        'these ops should not be found in the new ugraph: {}'.format(
            leftover_op_names)
コード例 #3
0
 def transform(self, ugraph):
     """Replace matches of ``self.pattern_ugraph`` until none remain.

     Each iteration asks the matcher for at most one match and rewrites it
     with ``match.replace_with(callback=self)``, so this object is used as
     the replacement callback.

     :param ugraph: graph to transform
     :returns: the graph after the last replacement (``ugraph`` itself if
         no match was found)
     """
     # FIXME: should use a generic op_equality_delegate
     matcher = uTensorGraphMatcher(
         pattern_ugraph=self.pattern_ugraph,
         op_equality_delegate=uTensorOpEqualityDelegate)
     while True:
         found = matcher.match(ugraph, 1)
         if not found:
             break
         ugraph = found[0].replace_with(callback=self)
     return ugraph
コード例 #4
0
 def _transform_tf(self, ugraph):
     """Repeatedly rewrite matches of ``self.pattern_ugraph`` in ``ugraph``.

     Each match is handed to ``self._handle_match_tf``, whose return value
     becomes the graph searched next. The loop stops when the matcher finds
     no more matches.

     :param ugraph: graph to transform
     :returns: the graph returned by the last ``_handle_match_tf`` call, or
         ``ugraph`` itself if nothing matched
     """
     # FIXME: should use a generic op_equality_delegate
     matcher = uTensorGraphMatcher(
         pattern_ugraph=self.pattern_ugraph,
         op_equality_delegate=uTensorOpEqualityDelegate)
     matches = matcher.match(ugraph, n=1)
     while matches:
         match = matches[0]
         ugraph = self._handle_match_tf(match)
         # Only the first match is ever used, so cap the search at one
         # match here too, consistent with the initial call above.
         matches = matcher.match(ugraph, n=1)
     return ugraph
コード例 #5
0
 def transform(self, ugraph):
     """Replace matches of ``self.pattern_ugraph`` in a TensorFlow graph.

     At most one match is requested per iteration and rewritten with
     ``replace_with(callback=self)``. The loop stops once no match is found.

     :param ugraph: graph to transform; ``ugraph.lib_name`` must be
         ``'tensorflow'``
     :returns: the transformed graph
     :raises ValueError: if ``ugraph.lib_name`` is not ``'tensorflow'``
     """
     if ugraph.lib_name != 'tensorflow':
         raise ValueError('only support tensorflow graph')
     # FIXME: should use a generic op_equality_delegate
     matcher = uTensorGraphMatcher(
         pattern_ugraph=self.pattern_ugraph,
         op_equality_delegate=uTensorOpEqualityDelegate)
     current = ugraph
     found = matcher.match(current, n=1)
     while found:
         current = found[0].replace_with(callback=self)
         found = matcher.match(current, n=1)
     return current
コード例 #6
0
def test_match_sub1_2(patrn_ugraph, subject_ugraph1_2):
    """The pattern must match the ``sub_*`` ops of ``subject_ugraph1_2``.

    The pattern inputs may bind to the two subject inputs in either order,
    but they must bind to two distinct ops. ``add0`` and ``output`` bind to
    ``sub_add0`` and ``sub_add1`` respectively.
    """
    matcher = uTensorGraphMatcher(
        patrn_ugraph, op_equality_delegate=uTensorOpEqualityDelegate)
    matches = matcher.match(subject_ugraph1_2)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))
    match = matches[0]
    assert match.patrn2subj_op_map['input0'].name in [
        'sub_input0', 'sub_input1'
    ]
    assert match.patrn2subj_op_map['input1'].name in [
        'sub_input0', 'sub_input1'
    ]
    # Without this, both pattern inputs binding to the same subject op would
    # slip through the two membership checks above.
    assert match.patrn2subj_op_map['input0'].name != match.patrn2subj_op_map[
        'input1'].name
    assert match.patrn2subj_op_map['add0'].name == 'sub_add0'
    assert match.patrn2subj_op_map['output'].name == 'sub_add1'
コード例 #7
0
def test_match_sub1(patrn_ugraph, subject_ugraph1):
    """``match_all`` must find exactly two matches in ``subject_ugraph1``.

    The first match must bind the pattern inputs to the two distinct
    ``sub_input*`` ops (in either order). It must also bind ``add0`` to
    ``sub_add0`` and ``output`` to ``sub_add1``.
    """
    matcher = uTensorGraphMatcher(
        patrn_ugraph, op_equality_delegate=uTensorOpEqualityDelegate)
    matches = matcher.match_all(subject_ugraph1)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))
    # Check the count before inspecting any individual match.
    assert len(matches) == 2, 'should be exactly two matches, get {}'.format(
        len(matches))
    match = matches[0]
    assert match.patrn2subj_op_map['input0'].name in [
        'sub_input0', 'sub_input1'
    ], match
    assert match.patrn2subj_op_map['input1'].name in [
        'sub_input0', 'sub_input1'
    ], match
    assert match.patrn2subj_op_map['input0'].name != match.patrn2subj_op_map[
        'input1'].name
    assert match.patrn2subj_op_map['add0'].name == 'sub_add0', match
    assert match.patrn2subj_op_map['output'].name == 'sub_add1', match