def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Collapse a matched NearestNeighborUpsampling pattern into one Resample node.

    Output size:
    - Height and width are read from the constant values on ports 1 and 3 of
      the 'pack_1' node.
    - Each is multiplied by a scale taken from the shape of 'mul_const'
      (dimensions -4 and -2 respectively).

    If any of these values cannot be read, the pattern is left unchanged.
    A warning is logged, and the underlying exception is logged at debug
    level.

    Otherwise:
    - A nearest-neighbour Resample node is created with 'op' as its only input.
    - It is put in place of 'reshape_2' via ``replace_node``.
    - Every other matched node except 'op' is removed from the graph.
    """
    log.debug('Matched NearestNeighborUpsampling pattern: {}'.format([node.id for node in match.values()]))
    try:
        input_height = match['pack_1'].in_node(1).value.item()
        input_width = match['pack_1'].in_node(3).value.item()

        height_scale = match['mul_const'].shape[-4]
        width_scale = match['mul_const'].shape[-2]
    except Exception as ex:
        # Deliberate best effort: when the parameters are not available the
        # replacement is skipped. The reason is still surfaced for debugging
        # instead of being silently discarded.
        log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
        log.debug('NearestNeighborUpsampling scaling parameters lookup failed with {!r}'.format(ex))
        return

    resample_op = ResampleOp(graph, {
        'width': input_width * width_scale,
        'height': input_height * height_scale,
        'name': 'Resample_',
        'antialias': 0,
        'resample_type': 'caffe.ResampleParameter.NEAREST',
    })
    resample_node = resample_op.create_node([match['op']])

    replace_node(match['reshape_2'], resample_node)
    # 'op' is the data input of the new Resample node, so it must survive the cleanup
    graph.remove_nodes_from([node.id for node in match.values() if node.id != match['op'].id])
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Substitute ``node`` with a newly created ``Mul`` operation.

    The ``Mul`` node:
    - is named ``<node.id>/mul_`` (the same name goes into ``symbol_dict``);
    - takes the inputs on ports 0 and 1 of ``node``;
    - is then given ``node``'s place in the graph through ``replace_node``.

    :param graph: graph that owns ``node``.
    :param node: operation to be replaced.
    :return: a one-element list holding the id of the new ``Mul`` node.
    """
    mul_name = node.id + '/mul_'
    mul_op = Mul(graph, dict(name=mul_name, symbol_dict={'name': mul_name}))
    lhs, rhs = node.in_node(0), node.in_node(1)
    mul_node = mul_op.create_node(inputs=[lhs, rhs])
    replace_node(node, mul_node)
    return [mul_node.id]
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Fold the matched CTCGreedyDecoder sub-graph onto the decoder node itself.

    Steps:
    - Drop the decoder's edge to the matched 'sparse_to_dense' node and its
      edge to the matched 'cast' node.
    - Let the decoder take the place of 'sparse_to_dense'.
    - Chain the decoder's infer function with ``tf_greedy_decoder_infer``.

    :return: an empty dict.
    """
    decoder = match['decoder']
    for consumer_key in ('sparse_to_dense', 'cast'):
        graph.remove_edge(decoder.id, match[consumer_key].id)
    replace_node(match['sparse_to_dense'], decoder)

    # Keep the original TensorFlow infer reachable as 'old_infer'; the
    # replacement infer adjusts the handling of the second input.
    decoder['old_infer'] = decoder.infer
    decoder.infer = __class__.tf_greedy_decoder_infer
    return {}
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Collapse the matched PReLU pattern into a single ``PreluOp`` node.

    Any matched node other than 'mul', 'op' and 'add' may fail
    ``check_node_usages_out_of_match``. If one does, the graph is left
    untouched and a warning lists the offending node ids.

    Otherwise:
    - A PReLU node is created from 'op' and the gamma operand of 'mul'.
      Gamma is whichever 'mul' input is not 'neg_1'; port 1 is checked first.
    - The new node replaces 'add'.
    """
    pattern_core = ('mul', 'op', 'add')
    leaking = []
    for name in match:
        if name in pattern_core:
            continue
        if not check_node_usages_out_of_match(match, name):
            leaking.append(name)

    if leaking:
        leaking_ids = ', '.join(match[name].id for name in leaking)
        log.warning('PReLU pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t replace'
                    .format(leaking_ids))
        return

    mul = match['mul']
    if mul.in_node(1).id == match['neg_1'].id:
        gamma = mul.in_node(0)
    else:
        gamma = mul.in_node(1)

    add = match['add']
    prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(add.id)}).create_node([match['op'], gamma])
    replace_node(add, prelu_node)
    log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))
def replace_sub_graph(graph: nx.MultiDiGraph, match: dict):
    """Replace the matched sub-graph ending in 'truediv' with one MVN node.

    ``required_reduction_indices`` is [1, 2] when ``graph.graph['layout']``
    is 'NHWC' and [2, 3] otherwise.

    The MVN node receives, in order:
    - port 0 of 'mean';
    - port 1 of 'mean' and port 1 of 'variance' (presumably their reduction
      axes);
    - port 1 of 'pow';
    - the epsilon operand: port 0 of 'add', unless that is 'variance', in
      which case port 1.

    :param graph: graph holding the match; its 'layout' entry selects the axes.
    :param match: mapping from pattern node name to matched node.
    """
    mvn_class = Op.get_op_class_by_name('MVN')
    truediv = match['truediv']
    mvn_name = truediv.name + '/MVN_'
    if graph.graph['layout'] == 'NHWC':
        reduction_axes = [1, 2]
    else:
        reduction_axes = [2, 3]
    mvn = mvn_class(graph, dict(name=mvn_name, required_reduction_indices=reduction_axes))
    # chain the infer functions: the class default is preserved as 'old_infer'
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mean = match['mean']
    mean_axes = mean.in_node(1)
    variance = match['variance']
    variance_axes = variance.in_node(1)
    power = match['pow'].in_node(1)
    add_node = match['add']
    eps = add_node.in_node(1 if add_node.in_node(0).id == variance.id else 0)

    mvn_node = mvn.create_node([mean.in_node(0), mean_axes, variance_axes, power, eps])
    replace_node(truediv, mvn_node)
def test_replace_node_one_consumer(self):
    """Check replace_node on a node with a single consumer.

    'old' is marked as output and feeds only 'output'. After the
    replacement:
    - 'new' must feed 'output';
    - 'new' must inherit ``is_output``;
    - the graph must keep 4 nodes and 3 edges.
    """
    def op_attrs(node_type, **extra):
        return dict(type=node_type, value=None, kind='op', **extra)

    nodes = {
        'input_1': op_attrs('Placeholder'),
        'input_2': op_attrs('Placeholder'),
        'old': op_attrs('Identity', is_output=True),
        'output': op_attrs('OpOutput'),
    }
    edges = [('input_1', 'old'), ('input_2', 'old'), ('old', 'output')]
    graph = build_graph(nodes, edges)

    const_op = Const(graph, {'name': 'new'})
    new_node = const_op.create_node([Node(graph, name) for name in ('input_1', 'input_2')])
    replace_node(Node(graph, 'old'), new_node)

    self.assertEqual(len(graph.nodes()), 4)
    self.assertEqual(len(graph.edges()), 3)
    self.assertEqual(new_node['is_output'], True)
    self.assertListEqual(list(graph.out_edges('new')), [('new', 'output')])
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Replace a matched FusedBatchNorm ('fbn') with MVN -> Mul(gamma) -> Add(beta).

    The replacement requires the data input (port 0) of 'fbn' to also be the
    port-0 input of both the matched 'mean' and 'sqdiff' nodes. Otherwise the
    graph is left untouched.

    The new sub-graph is
    ``Add(Mul(MVN(data, mean_axes, variance_axes), gamma), beta)``, where:
    - gamma and beta are ports 1 and 2 of 'fbn';
    - the axes inputs are port 1 of 'mean' and port 1 of 'variance'.

    MVN reduces over axes [1, 2] when ``fbn.data_format`` is NHWC (given as
    bytes or str) and over [2, 3] otherwise.
    """
    fbn = match['fbn']
    # deliberately not named `input`, which would shadow the builtin
    data_node = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(data_node.op, data_node.name))
    if data_node.id != match['mean'].in_node(0).id or data_node.id != match['sqdiff'].in_node(0).id:
        return

    log.debug('Confirmed MVN pattern after {} with name {}'.format(data_node.op, data_node.name))
    MVN = Op.get_op_class_by_name('MVN')

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        # The attribute was compared as raw bytes; a decoded str must not
        # silently fall back to the NCHW axes.
        required_reduction_indices=[1, 2] if fbn.data_format in (b'NHWC', 'NHWC') else [2, 3]
    ))
    # chain the infer functions: the class default is preserved as 'old_infer'
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Eltwise(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Eltwise(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    input_gamma = fbn.in_node(1)
    input_beta = fbn.in_node(2)

    mean_reduction = match['mean'].in_node(1)
    variance_reduction = match['variance'].in_node(1)

    new_subgraph = add.create_node([
        mul.create_node([
            mvn.create_node([data_node, mean_reduction, variance_reduction]),
            input_gamma
        ]),
        input_beta
    ])

    replace_node(fbn, new_subgraph)
def test_replace_node_several_consumers(self):
    """Check replace_node on a node with three consumers.

    Every consumer of 'old' must be reconnected to 'new'. Each edge keeps
    its original attributes. The out ports were assigned in edge-declaration
    order: output_3 -> 0, output_2 -> 1, output_1 -> 2.
    """
    def op_attrs(node_type):
        return {'type': node_type, 'value': None, 'kind': 'op'}

    nodes = {'input_1': op_attrs('Placeholder'), 'input_2': op_attrs('Placeholder')}
    nodes.update((name, op_attrs('Identity')) for name in ('old', 'output_1', 'output_2', 'output_3'))

    # The declaration order of the outgoing edges is significant:
    # it determines the 'out' ports checked below.
    edges = [('input_1', 'old'), ('input_2', 'old')]
    edges += [('old', 'output_{}'.format(idx)) for idx in (3, 2, 1)]
    graph = build_graph(nodes, edges)

    const_op = Const(graph, {'name': 'new'})
    new_node = const_op.create_node([Node(graph, name) for name in ('input_1', 'input_2')])
    replace_node(Node(graph, 'old'), new_node)

    expected_edges = [
        ('new', 'output_1', {'in': 0, 'out': 2, 'name': 'old'}),
        ('new', 'output_2', {'in': 0, 'out': 1, 'name': 'old'}),
        ('new', 'output_3', {'in': 0, 'out': 0, 'name': 'old'}),
    ]

    self.assertEqual(len(graph.nodes()), 6)
    self.assertEqual(len(graph.edges()), 5)
    self.assertListEqual(sorted(graph.out_edges('new')), [edge[:2] for edge in expected_edges])
    self.assertListEqual(sorted(graph.out_edges('new', data=True)), expected_edges)
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """Replace a matched TensorFlow LSTMCell sub-graph with a single LSTMCell node.

    Selection:
    - When ``__class__.second_round`` is set, only instances whose anchor-node
      name is in ``__class__.instances_supported_by_IE`` are converted; the
      rest are skipped.

    Inputs:
    - The data, hidden-state and cell-state inputs are added to ``match``.
    - All inputs are collected with ``get_inputs_with_ports``.

    Graph surgery:
    - The matched 'tanh_1' node is removed.
    - Each node named in ``__class__.outputs`` is replaced by the matching
      LSTMCell output port.
    - Fake ``Output`` consumers are added for ports 0/1 that end up without
      consumers.

    Resulting node:
    - The LSTMCell node is tagged with 'tf', 'extra_inputs' and 'inputs'.
    """
    # node that is used to identify this pattern application instance for switching between supported
    # and not supported LSTMCell sub-graphs; this value will be searched in __class__.instances_supported_by_IE.
    anchor_node = match[__class__.anchor()]
    # The message previously had an unfilled '{}' placeholder; fill it with the node id.
    assert anchor_node.has_valid('name'), \
        'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(anchor_node.id)

    if __class__.second_round and anchor_node.name not in __class__.instances_supported_by_IE:
        # at the second round of conversion we apply pattern selectively: only instances from
        # __class__.instances_supported_by_IE are allowed for conversion; all others should be skipped
        return

    match['input_op'] = match['concat'].in_node(0)
    match['input_hidden_state'] = match['concat'].in_node(1)
    match['input_cell_state'] = match['mul_0'].in_node(0) if match['mul_0'].in_node(0).id != match['sigmoid_0'].id \
        else match['mul_0'].in_node(1)

    # Copy before extending so the list returned by pattern() is never mutated
    # (it would otherwise accumulate these extra edges if it is shared or cached).
    pattern_edges = list(self.pattern()['edges'])
    pattern_edges.extend([('input_op', 'concat'), ('input_cell_state', 'mul_0'), ('input_hidden_state', 'concat')])
    inputs = get_inputs_with_ports(graph, match, pattern_edges, __class__.inputs + __class__.extra_inputs)

    lstm_op = LSTMCell(graph, dict(
        name=match['concat'].name + '/LSTMCell',
        mark_supported_by_IE=__class__.mark_supported_by_IE,
        original_name=anchor_node.name,
        finalize_first_round=__class__.finalize_first_round,
    ))
    lstm_node = lstm_op.create_node(inputs)
    lstm_node['old_infer'] = lstm_node.infer
    lstm_node.infer = __class__.infer

    # this node consumes one of the resulting LSTMCell outputs,
    # it should be removed before reconnecting the nodes,
    # otherwise it will be reconnected to the new cell output
    graph.remove_node(match['tanh_1'].id)

    for i, output in enumerate(__class__.outputs):
        replace_node(match[output], lstm_node, i)

    # Because of LSTMCell specification, this layer MUST have 2 outputs.
    # => we need to create fake consumers for every LSTMCell output port
    # (0 and 1) that ended up without a consumer.
    for i in [0, 1]:
        if i not in lstm_node.out_nodes():
            fake_output_node = Output(graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
            fake_output_node.create_node(inputs=[lstm_node], edge_attrs={'out': i, 'in': 0})

    lstm_node['tf'] = True
    lstm_node['extra_inputs'] = {name: match[name].id for name in __class__.extra_inputs}
    lstm_node['inputs'] = {name: match[name].id for name in __class__.inputs}