コード例 #1
0
    def replace_pattern(graph: Graph, match: dict):
        """Replace a matched ``Enter`` node with a TensorIteratorInput op.

        A new TensorIteratorInput node is created whose input is the Enter node's
        input data node and whose output data node is the Enter node's output data
        node; the Enter node itself is then removed from the graph.

        :param graph: graph being transformed (modified in place).
        :param match: mapping from pattern node names to matched graph nodes;
                      must contain the key 'Enter'.
        """
        log.debug('================== SimpletInputFind ===============')

        enter = match['Enter']
        attrs = dict(
            external_port_id=None,
            internal_layer_id=None,
            name=enter.name + '/TensorIteratorInput_',
        )
        TensorIteratorInput(graph, attrs).create_node_with_data(
            inputs=[enter.in_node()],
            data_nodes=[enter.out_node()],
        )

        # The Enter op has been bypassed by the new node, so drop it.
        graph.remove_nodes_from([enter.id])
コード例 #2
0
    def replace_pattern(graph: Graph, match: dict):
        """Insert a TensorIteratorInput op in front of a BackEdge's initial-value input.

        The rewrite only happens when the BackEdge's port-0 data node either has no
        producer, or has exactly one producer and carries a valid ``value`` (e.g. a
        Constant producer). In that case the producer edge (if any) is removed, a new
        TensorIteratorInput op is fed from that data node, and the BackEdge is rewired
        to read port 0 from the new op's output data node.

        :param graph: graph being transformed (modified in place).
        :param match: mapping from pattern node names to matched graph nodes;
                      must contain the key 'BackEdge'.
        """
        log.debug('================== SimpleBackEdgeInputFind ===============')

        back_edge = match['BackEdge']
        assert len(back_edge.in_nodes()) == 3
        # Ports 2 and 1 are looked up but not used further below.
        _condition = back_edge.in_node(2)
        init_data = back_edge.in_node(0)
        _cycle_input = back_edge.in_node(1)

        producer_count = len(init_data.in_nodes())
        # Only create a new TensorIteratorInput if one does not exist already.
        if producer_count != 0 and not (producer_count == 1 and init_data.has_valid('value')):
            return

        ti_input = TensorIteratorInput(
            graph,
            dict(external_port_id=None,
                 internal_layer_id=None,
                 name=back_edge.name + '/TensorIteratorInput_'))

        # Detach the single (Constant-like) producer of the init data node, if present.
        if producer_count == 1:
            graph.remove_edge(init_data.in_node(0).id, init_data.id)

        new_data = ti_input.create_node_with_data(inputs=[init_data])
        new_data.shape = np.array(init_data.shape, dtype=np.int64)

        # Reroute the BackEdge's port 0 from the old init data node to the new one.
        graph.remove_edges_from([(init_data.id, back_edge.id)])
        edge_attrs = {
            'in': 0,
            'out': 0
        }
        graph.add_edges_from([(new_data.id, back_edge.id, edge_attrs)])
コード例 #3
0
    def replace_pattern(graph: Graph, match: dict):
        """Collapse a matched TensorArray input subgraph into a TensorIteratorInput op.

        A TensorIteratorInput node (iterating over axis 0) is created with inputs
        ``[TensorArray size data, TensorArrayScatter port-2 input, Condition_data]``
        and output ``TensorArrayRead_data``. Every other matched node except
        'TensorArrayRead_data', 'Condition' and 'Condition_data' is then removed.

        :param graph: graph being transformed (modified in place).
        :param match: mapping from pattern node names to matched graph nodes.
        """
        log.debug('================== SmartInputFind ===============')

        # Pattern preconditions on the matched constant data nodes.
        assert match['Enter_data'].value is not None
        assert match['stack_data']['value'][0] == 0 and match['stack_1_data']['value'][0] == 1 and \
               match['stack_2_data']['value'][0] == 1
        assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1

        ta_size_data = match['TensorArray'].in_node()
        ta_size = ta_size_data.in_node()
        value = match['TensorArrayScatter'].in_node(2)

        start, end = None, None
        if 0 in ta_size.in_nodes():
            shape = match['StridedSlice'].in_node(0).in_node(0)
            # Case when value for Strided slice is Const, not Shape
            if shape['kind'] == 'op' and shape['op'] == 'Const':
                start = 0
                # NOTE(review): `end` is computed but never used below.
                end = shape.value[0]
                # The two literals are concatenated implicitly; the trailing space keeps
                # the sentences separated in the emitted message.
                log.warning(
                    "Your network cannot be reshaped since shapes of placeholders are constants. "
                    "Please, provide non-constant shapes. ")

        # Create input node with params
        # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be filled later
        # (with condition)
        input_node = TensorIteratorInput(
            graph,
            dict(axis=0,
                 start=start,
                 stride=None,
                 part_size=None,
                 external_port_id=str(match['Enter_data'].value),
                 internal_layer_id=match['TensorArrayRead_data'].id,
                 name=match['TensorArrayRead'].name + '/TensorIteratorInput_'))
        input_node.create_node_with_data(
            inputs=[ta_size_data, value, match['Condition_data']],
            data_nodes=[match['TensorArrayRead_data']])

        # Delete useless nodes: everything matched except the new op's output data
        # node and the loop condition nodes, which are still referenced.
        safe_nodes = {'TensorArrayRead_data', 'Condition', 'Condition_data'}
        graph.remove_nodes_from([node.id for name, node in match.items() if name not in safe_nodes])