Example #1
    def replace_pattern(graph: Graph, match: dict):
        log.debug('================== SimpleBackEdgeInputFind ===============')

        assert len(match['BackEdge'].in_nodes()) == 3
        condition = match['BackEdge'].in_node(2)
        init_input = match['BackEdge'].in_node(0)
        cycle_input = match['BackEdge'].in_node(1)

        # We need to create a new TensorIteratorInput node only if it doesn't already exist.
        if len(init_input.in_nodes()) == 0 or\
           (len(init_input.in_nodes()) == 1 and init_input.has_valid('value')):

            input_node = TensorIteratorInput(
                graph,
                dict(external_port_id=None,
                     internal_layer_id=None,
                     name=match['BackEdge'].name + '/TensorIteratorInput_'))

            # In case the data node has a Constant producer
            if len(init_input.in_nodes()) == 1:
                graph.remove_edge(init_input.in_node(0).id, init_input.id)

            input_data_node = input_node.create_node_with_data(
                inputs=[init_input])
            input_data_node.shape = np.array(init_input.shape, dtype=np.int64)
            graph.remove_edges_from([(init_input.id, match['BackEdge'].id)])
            graph.add_edges_from([(input_data_node.id, match['BackEdge'].id, {
                'in': 0,
                'out': 0
            })])
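The snippets in these examples are methods excerpted from Model Optimizer transformation classes; on their own they lack the surrounding class, the pattern that produces the `match` dict, and the imports they rely on. Below is a minimal sketch of that surrounding boilerplate for Example #1, assuming the MiddleReplacementPattern interface of the OpenVINO Model Optimizer; the import paths, class name, and the pattern attributes (op='TensorIteratorBackEdge') are illustrative assumptions, not the authoritative source.

import logging as log

import numpy as np

from mo.graph.graph import Graph                                  # assumed import path
from mo.middle.replacement import MiddleReplacementPattern        # assumed import path
from extensions.ops.TensorIterator_ops import TensorIteratorInput # assumed import path


class BackEdgeSimpleInputMatcher(MiddleReplacementPattern):
    enabled = True

    @staticmethod
    def pattern():
        # The matching engine searches the graph for every sub-graph that satisfies
        # this description and calls replace_pattern() once per occurrence.
        # The node labels listed here ('BackEdge') become the keys of the `match` dict.
        return dict(
            nodes=[('BackEdge', dict(kind='op', op='TensorIteratorBackEdge'))],
            edges=[],
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        # Body as shown in Example #1 above.
        ...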
Example #2
    def replace_pattern(graph: Graph, match: dict):
        log.debug('================== SimpleInputFind ===============')

        input_node = TensorIteratorInput(graph, dict(external_port_id=None,
                                                     internal_layer_id=None,
                                                     name=match['Enter'].name + '/TensorIteratorInput_'
                                                     ))
        input_node.create_node_with_data(inputs=[match['Enter'].in_node()], data_nodes=[match['Enter'].out_node()])

        # Delete useless nodes
        graph.remove_nodes_from([match['Enter'].id])
Example #3
    def replace_pattern(graph: Graph, match: dict):
        log.debug('================== SmartInputFind ===============')

        assert match['Enter_data'].value is not None
        assert match['stack_data']['value'][0] == 0 and match['stack_1_data']['value'][0] == 1 and \
               match['stack_2_data']['value'][0] == 1
        assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1

        ta_size_data = match['TensorArray'].in_node()
        ta_size = ta_size_data.in_node()
        value = match['TensorArrayScatter'].in_node(2)

        start, end = None, None
        if 0 in ta_size.in_nodes():
            shape = match['StridedSlice'].in_node(0).in_node(0)
            # Case when the input to StridedSlice is a Const, not a Shape
            if shape['kind'] == 'op' and shape['op'] == 'Const':
                start = 0
                end = shape.value[0]
                log.warning(
                    "Your network cannot be reshaped since the shapes of its placeholders are constants. "
                    "Please provide non-constant shapes.")

        # Create the input node with its parameters.
        # axis == 0 because TensorArray always iterates over axis 0; the other parameters will be
        # filled in later (together with the condition).
        input_node = TensorIteratorInput(
            graph,
            dict(axis=0,
                 start=start,
                 stride=None,
                 part_size=None,
                 external_port_id=str(match['Enter_data'].value),
                 internal_layer_id=match['TensorArrayRead_data'].id,
                 name=match['TensorArrayRead'].name + '/TensorIteratorInput_'))
        input_node.create_node_with_data(
            inputs=[ta_size_data, value, match['Condition_data']],
            data_nodes=[match['TensorArrayRead_data']])
        # Delete useless nodes
        safe_nodes = ['TensorArrayRead_data', 'Condition', 'Condition_data']

        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)
Example #4
    def replace_pattern(graph: Graph, match: dict):
        # retrieve attribute values for TensorIteratorInput node
        init_time = match['InitIndex'].value.item(0)
        time_step = match['IndexDelta'].value.item(0)
        axis = match['Axis'].value.item(0)

        # retrieve input and output nodes for TensorIteratorInput node
        initial_input_node = match['EnterInput']
        current_index_node = match['IdentityIndex']
        size_node = match['EnterMaxIndex']
        resulted_slice_node = match['SqueezeSlice']
        resulted_slice_node_name = resulted_slice_node.soft_get('name', resulted_slice_node.id)

        # create TensorIteratorInput node that reflects slicing of input for each time step along axis
        ti_input_node = TensorIteratorInput(graph, dict(axis=axis, start=init_time, stride=time_step,
                                                        name=resulted_slice_node_name + '/TensorIteratorInput')
                                            ).create_node()
        size_node.in_port(0).get_connection().add_destination(ti_input_node.in_port(0))
        initial_input_node.in_port(0).get_connection().set_destination(ti_input_node.in_port(1))
        current_index_node.out_port(0).connect(ti_input_node.in_port(2))
        resulted_slice_node.out_port(0).get_connection().set_source(ti_input_node.out_port(0))

        # delete no longer needed nodes responsible for slicing of input in the original graph
        node_names_for_remove = ['EnterInput', 'MergeInput', 'SwitchInput',
                                 'IdentityInput', 'NextIterationInput', 'SqueezeSlice', 'UnsqueezeIndex', 'Gather']
        graph.remove_nodes_from([match[node_name].id for node_name in node_names_for_remove])
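For context on what replace_pattern() receives, here is a self-contained toy sketch of the idea behind the `match` dictionary: each label used in pattern() maps to the concrete graph node that matched it. The sketch deliberately uses plain networkx (which the Model Optimizer graph is built on) rather than the Model Optimizer's own matcher, and all node names and attributes below are made up for illustration.

import networkx as nx


def match_single_op(graph: nx.MultiDiGraph, label: str, op_type: str):
    """Yield a {label: node_id} dict for every op node whose 'op' attribute equals op_type."""
    for node_id, attrs in graph.nodes(data=True):
        if attrs.get('kind') == 'op' and attrs.get('op') == op_type:
            yield {label: node_id}


# Toy graph: one hypothetical 'Enter' op node with its output data node.
g = nx.MultiDiGraph()
g.add_node('enter_0', kind='op', op='Enter')
g.add_node('enter_0_data', kind='data')
g.add_edge('enter_0', 'enter_0_data')

for match in match_single_op(g, 'Enter', 'Enter'):
    print(match)  # {'Enter': 'enter_0'}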