Ejemplo n.º 1
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """Collapse a matched NearestNeighborUpsampling sub-graph into one Resample op.

        The target size is computed from two sources:

        * the static input height/width, i.e. the ``value`` of inputs 1 and 3 of
          ``pack_1``;
        * the scale factors, i.e. dimensions -4 and -2 of ``mul_const``'s shape.

        If any of these cannot be read, a warning is logged and the graph is left
        untouched.

        :param graph: graph being transformed; modified in place.
        :param match: mapping from pattern-node name to matched graph node.
        """
        log.debug('Matched NearestNeighborUpsampling pattern: {}'.format(
            [node.id for node in match.values()]))
        try:
            input_height = match['pack_1'].in_node(1).value.item()
            input_width = match['pack_1'].in_node(3).value.item()

            height_scale = match['mul_const'].shape[-4]
            width_scale = match['mul_const'].shape[-2]
        except Exception as ex:
            # Deliberately best-effort: the pattern is only applicable when these
            # parameters can be read. Skip it, but keep the reason for debugging
            # instead of silently discarding the exception.
            log.warning(
                'Failed to determine scaling parameters from the topology. Do not apply pattern.'
            )
            log.debug('NearestNeighborUpsampling pattern is not applied: {}'.format(ex))
            return

        resample_op = ResampleOp(
            graph, {
                'width': input_width * width_scale,
                'height': input_height * height_scale,
                'name': 'Resample_',
                'antialias': 0,
                'resample_type': 'caffe.ResampleParameter.NEAREST'
            })
        # The matched 'op' node becomes the only input of the new Resample node.
        resample_node = resample_op.create_node([match['op']])

        replace_node(match['reshape_2'], resample_node)
        # Drop the rest of the matched sub-graph. 'op' is kept because it now feeds Resample.
        kept_id = match['op'].id
        graph.remove_nodes_from(
            [node.id for node in match.values() if node.id != kept_id])
Ejemplo n.º 2
0
 def replace_op(self, graph: nx.MultiDiGraph, node: Node):
     """Replace ``node`` with a Mul operation over its first two inputs.

     The new node is named ``<node.id>/mul_``, and the same name is stored in
     its ``symbol_dict``.

     :param graph: graph being transformed; modified in place.
     :param node: node to be replaced.
     :return: a one-element list holding the id of the created Mul node.
     """
     mul_name = node.id + '/mul_'
     op = Mul(graph, dict(name=mul_name, symbol_dict={'name': mul_name}))
     lhs = node.in_node(0)
     rhs = node.in_node(1)
     created = op.create_node(inputs=[lhs, rhs])
     replace_node(node, created)
     return [created.id]
Ejemplo n.º 3
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """Put the CTCGreedyDecoder node in place of the matched SparseToDense node.

        Steps:

        1. Disconnect the decoder from the matched ``sparse_to_dense`` and ``cast`` nodes.
        2. Let the decoder take over ``sparse_to_dense``'s place in the graph.
        3. Swap the decoder's infer function for this class's
           ``tf_greedy_decoder_infer``, keeping the previous one under ``old_infer``.

        :return: an empty dict.
        """
        decoder = match['decoder']
        dense = match['sparse_to_dense']
        graph.remove_edge(decoder.id, dense.id)
        graph.remove_edge(decoder.id, match['cast'].id)
        replace_node(dense, decoder)

        # Remember the original TF infer function before overriding it, so the new
        # infer can adjust the handling of the second input.
        previous_infer = decoder.infer
        decoder['old_infer'] = previous_infer
        decoder.infer = __class__.tf_greedy_decoder_infer
        return {}
Ejemplo n.º 4
0
 def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
     """Collapse a matched PReLU sub-graph into a single PReLU op.

     The replacement is skipped with a warning when any matched node other than
     ``mul``, ``op`` and ``add`` is reported by ``check_node_usages_out_of_match``
     as having consumers outside the pattern.

     The slope (gamma) is whichever input of ``mul`` is not ``neg_1``. The new
     PReLU node takes ``op`` and gamma as inputs and replaces ``add``.
     """
     exempt = ('mul', 'op', 'add')
     leaking = [name for name in match
                if name not in exempt and not check_node_usages_out_of_match(match, name)]
     if leaking:
         leaking_ids = ', '.join(match[name].id for name in leaking)
         log.warning("PReLU pattern was detected. Non pattern consumers of nodes: \"{}\" were found. "
                     "Won't replace".format(leaking_ids))
         return

     mul = match['mul']
     # If input 1 is the negation branch, the slope must be input 0, and vice versa.
     slope_port = 0 if mul.in_node(1).id == match['neg_1'].id else 1
     gamma = mul.in_node(slope_port)

     source = match['op']
     prelu_node = PreluOp(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([source, gamma])
     replace_node(match['add'], prelu_node)
     log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(source.id, prelu_node.id))
Ejemplo n.º 5
0
    def replace_sub_graph(graph: nx.MultiDiGraph, match: dict):
        """Replace a matched mean/variance normalization sub-graph with an MVN op.

        NOTE(review): the signature has no ``self``; presumably this is declared as
        a ``@staticmethod`` just above this line. Confirm.

        The MVN node's inputs, in order, are:

        * the data input of ``mean``;
        * the second inputs of ``mean``, ``variance`` and ``pow``;
        * whichever input of ``add`` is not ``variance`` (the epsilon).

        It replaces ``truediv``. Reduction indices are [1, 2] when the graph layout
        is 'NHWC' and [2, 3] otherwise.
        """
        mvn_class = Op.get_op_class_by_name('MVN')

        mvn_name = match['truediv'].name + '/MVN_'
        reduction_axes = [1, 2] if graph.graph['layout'] == 'NHWC' else [2, 3]
        mvn = mvn_class(graph, dict(name=mvn_name, required_reduction_indices=reduction_axes))
        # Keep the op's own infer function and install this class's infer in its place.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        add_node = match['add']
        # epsilon is the 'add' input that does not come from 'variance'
        eps_port = 1 if add_node.in_node(0).id == match['variance'].id else 0
        mvn_inputs = [
            match['mean'].in_node(0),
            match['mean'].in_node(1),
            match['variance'].in_node(1),
            match['pow'].in_node(1),
            add_node.in_node(eps_port),
        ]

        replace_node(match['truediv'], mvn.create_node(mvn_inputs))
Ejemplo n.º 6
0
    def test_replace_node_one_consumer(self):
        # Topology: input_1 -> old <- input_2 and old -> output. 'old' carries is_output.
        def op_attrs(node_type, **extra):
            attrs = {'type': node_type, 'value': None, 'kind': 'op'}
            attrs.update(extra)
            return attrs

        nodes = {
            'input_1': op_attrs('Placeholder'),
            'input_2': op_attrs('Placeholder'),
            'old': op_attrs('Identity', is_output=True),
            'output': op_attrs('OpOutput'),
        }
        edges = [('input_1', 'old'), ('input_2', 'old'), ('old', 'output')]
        graph = build_graph(nodes, edges)

        inputs = [Node(graph, 'input_1'), Node(graph, 'input_2')]
        new_node = Const(graph, {'name': 'new'}).create_node(inputs)
        replace_node(Node(graph, 'old'), new_node)

        # 'old' is gone and 'new' inherits its is_output flag and its single consumer.
        self.assertEqual(len(graph.nodes()), 4)
        self.assertEqual(len(graph.edges()), 3)
        self.assertEqual(new_node['is_output'], True)
        self.assertListEqual(list(graph.out_edges('new')), [('new', 'output')])
Ejemplo n.º 7
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """Replace a matched FusedBatchNorm-based MVN pattern with MVN * gamma + beta.

        The pattern is applied only when the ``mean`` and ``sqdiff`` nodes read the
        same data tensor as ``fbn``; otherwise the graph is left unchanged.

        The replacement is ``Add(Mul(MVN(x, mean_axes, variance_axes), gamma), beta)``.
        Here ``x``, gamma and beta are inputs 0, 1 and 2 of ``fbn``. The reduction
        axes inputs come from ``mean`` and ``variance``, and ``eps`` is copied from
        ``fbn``.

        :param graph: graph being transformed; modified in place.
        :param match: mapping from pattern-node name to matched graph node.
        """
        fbn = match['fbn']
        data_node = fbn.in_node(0)  # renamed from 'input' to avoid shadowing the builtin
        log.debug('Found potential MVN pattern after {} with name {}'.format(data_node.op, data_node.name))
        if data_node.id != match['mean'].in_node(0).id or data_node.id != match['sqdiff'].in_node(0).id:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(data_node.op, data_node.name))
        MVN = Op.get_op_class_by_name('MVN')

        # data_format was previously matched only as bytes (b'NHWC'). Also accept a
        # str value, comparing like with like so no bytes/str comparison is made.
        data_format = fbn.data_format
        nhwc = b'NHWC' if isinstance(data_format, bytes) else 'NHWC'
        mvn = MVN(graph, dict(
            name=fbn.name + '/MVN_',
            eps=fbn.eps,
            required_reduction_indices=[1, 2] if data_format == nhwc else [2, 3]
        ))
        # Keep the op's own infer function and install this class's infer in its place.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Eltwise(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
        add = Eltwise(graph, dict(operation='sum', name=fbn.name + '/Add_'))

        input_gamma = fbn.in_node(1)
        input_beta = fbn.in_node(2)

        mean_reduction = match['mean'].in_node(1)
        variance_reduction = match['variance'].in_node(1)

        new_subgraph = add.create_node([
            mul.create_node([
                mvn.create_node([data_node, mean_reduction, variance_reduction]),
                input_gamma
            ]),
            input_beta
        ])

        replace_node(fbn, new_subgraph)
Ejemplo n.º 8
0
    def test_replace_node_several_consumers(self):
        # 'old' has two inputs and three consumers. After replace_node, all three
        # consumers must hang off 'new' with the attributes of the original edges.
        nodes = {name: {'type': 'Placeholder', 'value': None, 'kind': 'op'}
                 for name in ('input_1', 'input_2')}
        nodes.update({name: {'type': 'Identity', 'value': None, 'kind': 'op'}
                      for name in ('old', 'output_1', 'output_2', 'output_3')})

        edges = [('input_1', 'old'), ('input_2', 'old')]
        edges += [('old', consumer) for consumer in ('output_3', 'output_2', 'output_1')]
        graph = build_graph(nodes, edges)

        inputs = [Node(graph, 'input_1'), Node(graph, 'input_2')]
        new_node = Const(graph, {'name': 'new'}).create_node(inputs)
        replace_node(Node(graph, 'old'), new_node)

        self.assertEqual(len(graph.nodes()), 6)
        self.assertEqual(len(graph.edges()), 5)

        consumers = ['output_1', 'output_2', 'output_3']
        self.assertListEqual(sorted(graph.out_edges('new')),
                             [('new', consumer) for consumer in consumers])
        # Expected 'out' ports: output_1 -> 2, output_2 -> 1, output_3 -> 0.
        # Presumably build_graph assigns them in edge declaration order.
        expected_result = [
            ('new', consumer, {'in': 0, 'out': out_port, 'name': 'old'})
            for consumer, out_port in zip(consumers, (2, 1, 0))
        ]
        self.assertListEqual(sorted(graph.out_edges('new', data=True)),
                             expected_result)
Ejemplo n.º 9
0
    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
        """Replace a matched TensorFlow LSTMCell sub-graph with a single LSTMCell op.

        On the second conversion round (``__class__.second_round``), only instances
        whose anchor node name is in ``__class__.instances_supported_by_IE`` are
        replaced. All others are skipped.

        Each node named in ``__class__.outputs`` is replaced by the corresponding
        output of the new node. Fake Output consumers are added for outputs 0 and 1
        if they are not present in ``lstm_node.out_nodes()``.

        :param graph: graph being transformed; modified in place.
        :param match: mapping from pattern-node name to matched graph node. It is
            extended in place with 'input_op', 'input_hidden_state' and
            'input_cell_state'.
        """
        # The anchor node identifies this pattern application instance when switching
        # between supported and unsupported LSTMCell sub-graphs. Its name is looked up
        # in __class__.instances_supported_by_IE.
        anchor_node = match[__class__.anchor()]
        assert anchor_node.has_valid('name'), \
            'LSTMCell anchor node {} doesn\'t have attribute name; such nodes are not supported.'.format(
                anchor_node.id)

        if __class__.second_round and anchor_node.name not in __class__.instances_supported_by_IE:
            # On the second round the pattern is applied selectively: only instances from
            # __class__.instances_supported_by_IE are converted; all others are skipped.
            return

        match['input_op'] = match['concat'].in_node(0)
        match['input_hidden_state'] = match['concat'].in_node(1)
        # The cell state is whichever input of mul_0 is not the sigmoid_0 output.
        match['input_cell_state'] = match['mul_0'].in_node(0) if match['mul_0'].in_node(0).id != match['sigmoid_0'].id \
            else match['mul_0'].in_node(1)

        # Copy the pattern edges before extending them, so a pattern definition that may
        # be shared between calls is never mutated.
        pattern_edges = list(self.pattern()['edges'])
        pattern_edges.extend([('input_op', 'concat'),
                              ('input_cell_state', 'mul_0'),
                              ('input_hidden_state', 'concat')])
        inputs = get_inputs_with_ports(
            graph, match, pattern_edges,
            __class__.inputs + __class__.extra_inputs)

        lstm_op = LSTMCell(
            graph,
            dict(
                name=match['concat'].name + '/LSTMCell',
                mark_supported_by_IE=__class__.mark_supported_by_IE,
                original_name=anchor_node.name,
                finalize_first_round=__class__.finalize_first_round,
            ))
        lstm_node = lstm_op.create_node(inputs)
        lstm_node['old_infer'] = lstm_node.infer
        lstm_node.infer = __class__.infer

        # tanh_1 consumes one of the resulting LSTMCell outputs. It must be removed before
        # reconnecting the nodes; otherwise it would be reconnected to the new cell output.
        graph.remove_node(match['tanh_1'].id)

        for i, output in enumerate(__class__.outputs):
            replace_node(match[output], lstm_node, i)

        # Per the LSTMCell specification this layer MUST have 2 outputs, so create
        # fake Output consumers for any of outputs 0 and 1 that are missing.
        for i in [0, 1]:
            if i not in lstm_node.out_nodes():
                fake_output_node = Output(
                    graph, dict(name=lstm_node.name + "/Output_{}".format(i)))
                fake_output_node.create_node(inputs=[lstm_node],
                                             edge_attrs={
                                                 'out': i,
                                                 'in': 0
                                             })

        lstm_node['tf'] = True
        lstm_node['extra_inputs'] = {
            name: match[name].id
            for name in __class__.extra_inputs
        }
        lstm_node['inputs'] = {
            name: match[name].id
            for name in __class__.inputs
        }