def test_id_match(patrn_ugraph):
    """Match the pattern graph against itself and check the op/tensor maps.

    The two placeholder inputs may be mapped in either order, but they must
    map to distinct ops; ``add0`` and ``output`` must map to themselves.
    """
    matcher = uTensorGraphMatcher(
        patrn_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    matches = matcher.match(patrn_ugraph)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))
    match = matches[0]

    # Same set of checks for both directions of the op mapping,
    # pattern -> subject first, then subject -> pattern.
    input_names = ['input0', 'input1']
    for op_map in (match.patrn2subj_op_map, match.subj2patrn_op_map):
        assert op_map['input0'].name in input_names
        assert op_map['input1'].name in input_names
        assert op_map['input0'].name != op_map['input1'].name
        assert op_map['add0'].name == 'add0'
        assert op_map['output'].name == 'output'

    # Every input tensor of the pattern must appear in the pattern -> subject
    # tensor map.
    patrn2subj_tensors = match.patrn2subj_tensor_map
    for in_tensor in patrn_ugraph.input_tensors:
        assert in_tensor.name in patrn2subj_tensors, \
            '{} is missing'.format(in_tensor.name)

    # Every output tensor of the pattern must appear in the subject -> pattern
    # tensor map.
    subj2patrn_tensors = match.subj2patrn_tensor_map
    for out_tensor in patrn_ugraph.output_tensors:
        assert out_tensor.name in subj2patrn_tensors, \
            '{} is missing'.format(out_tensor.name)
def test_replace_fc_with_add(subj_graph_1, patrn_fc_1):
    """Replace the first match of ``patrn_fc_1`` with a single add node.

    After the replacement, none of the matched subject ops may remain in the
    resulting graph.
    """

    def callback(match):
        # Build the replacement graph: one add op fed by two placeholders.
        tf_graph = tf.Graph()
        with tf_graph.as_default():
            lhs = tf.placeholder(dtype=tf.float32, name='a')
            rhs = tf.placeholder(dtype=tf.float32, name='b')
            fused = tf.add(lhs, rhs, name='fused_node')
        ugraph = GraphDefParser(config={}).parse(
            tf_graph.as_graph_def(),
            output_nodes=[fused.op.name],
        )

        # Swap both inputs of the fused node for null input tensors, then
        # re-sort and prune the graph.
        fused_op = ugraph.ops_info['fused_node']
        fused_op.replace_with_null_input_tensor(0)
        fused_op.replace_with_null_input_tensor(1)
        topologic_order_graph(ugraph)
        ugraph = prune_graph(ugraph)

        # Look the ops up again: `ugraph` was rebound by prune_graph above.
        fused_op = ugraph.ops_info['fused_node']
        patrn_ops = match.pattern_ugraph.ops_info
        a_prime = patrn_ops['a_prime']
        input_map = {
            a_prime.input_tensors[0]: fused_op.input_tensors[0],
            a_prime.input_tensors[1]: fused_op.input_tensors[1],
        }
        output_map = {
            patrn_ops['r_prime'].output_tensors[0]: fused_op.output_tensors[0],
        }
        return ugraph, input_map, output_map

    matcher = uTensorGraphMatcher(
        patrn_fc_1,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    matches = matcher.match(subj_graph_1)
    assert matches, 'no match found'
    match = matches[0]
    new_ugraph = match.replace_with(callback)

    # Collect matched subject ops that are still present after replacement.
    missed_op_names = [
        op_name
        for op_name in match.subj2patrn_op_map
        if op_name in new_ugraph.ops_info
    ]
    assert not missed_op_names, \
        'these ops should not be found in the new ugrah: {}'.format(missed_op_names)
def transform(self, ugraph):
    """Replace matches of ``self.pattern_ugraph`` until none is found.

    Each round asks the matcher for at most one match and replaces it using
    ``self`` as the replacement callback. Returns the resulting graph.
    """
    # FIXME: should use a generic op_equality_delegate
    matcher = uTensorGraphMatcher(
        pattern_ugraph=self.pattern_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    while True:
        found = matcher.match(ugraph, 1)
        if not found:
            return ugraph
        ugraph = found[0].replace_with(callback=self)
def _transform_tf(self, ugraph):
    """Rewrite matches of ``self.pattern_ugraph`` until none is found.

    Each round takes the first match, passes it to ``self._handle_match_tf``
    and continues with the graph that call returns.

    :param ugraph: the graph to transform
    :return: the graph after no further match is found
    """
    # FIXME: should use a generic op_equality_delegate
    matcher = uTensorGraphMatcher(
        pattern_ugraph=self.pattern_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    matches = matcher.match(ugraph, n=1)
    while matches:
        match = matches[0]
        ugraph = self._handle_match_tf(match)
        # Only the first match is consumed per round, so request a single
        # match here too (same as the initial call) instead of all of them.
        matches = matcher.match(ugraph, n=1)
    return ugraph
def transform(self, ugraph):
    """Replace matches of ``self.pattern_ugraph`` in a tensorflow graph.

    Matches are requested one at a time and replaced using ``self`` as the
    replacement callback until the matcher finds nothing more.

    :raises ValueError: if ``ugraph.lib_name`` is not ``'tensorflow'``
    """
    if ugraph.lib_name != 'tensorflow':
        raise ValueError('only support tensorflow graph')

    # FIXME: should use a generic op_equality_delegate
    matcher = uTensorGraphMatcher(
        pattern_ugraph=self.pattern_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    current = ugraph
    found = matcher.match(current, n=1)
    while found:
        current = found[0].replace_with(callback=self)
        found = matcher.match(current, n=1)
    return current
def test_match_sub1_2(patrn_ugraph, subject_ugraph1_2):
    """Check the pattern -> subject op mapping of the first match."""
    matcher = uTensorGraphMatcher(
        patrn_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    matches = matcher.match(subject_ugraph1_2)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))

    op_map = matches[0].patrn2subj_op_map
    # Either pattern input may map to either subject input.
    subj_inputs = ['sub_input0', 'sub_input1']
    assert op_map['input0'].name in subj_inputs
    assert op_map['input1'].name in subj_inputs
    assert op_map['add0'].name == 'sub_add0'
    assert op_map['output'].name == 'sub_add1'
def test_match_sub1(patrn_ugraph, subject_ugraph1):
    """Expect exactly two matches from ``match_all`` and check the first one."""
    matcher = uTensorGraphMatcher(
        patrn_ugraph,
        op_equality_delegate=uTensorOpEqualityDelegate,
    )
    matches = matcher.match_all(subject_ugraph1)
    assert matches, 'expecting matches, get {} matches'.format(len(matches))
    match = matches[0]
    n_matches = len(matches)
    assert n_matches == 2, 'should be exactly two match, get {}'.format(
        len(matches))

    op_map = match.patrn2subj_op_map
    # The two pattern inputs may map in either order, but to distinct ops.
    subj_inputs = ['sub_input0', 'sub_input1']
    assert op_map['input0'].name in subj_inputs, match
    assert op_map['input1'].name in subj_inputs, match
    assert op_map['input0'].name != op_map['input1'].name
    assert op_map['add0'].name == 'sub_add0', match
    assert op_map['output'].name == 'sub_add1', match