Esempio n. 1
0
    def replace_sub_graph(self, graph: Graph, external_match: dict):
        """Run the Keras RNN output-concatenation transformation on a matched Loop, if it fits.

        Searches the body of ``external_match['while']`` for the sub-graph returned by
        ``KerasRNNOutputConcatenation.get_body_pattern()``. Only a unique match is
        handled. The transformation
        (``KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation``) is
        invoked only when the Loop node's wiring agrees with that match:

        * a back edge goes from ``concatenation_result`` to ``container``;
        * a back edge goes from ``increment_iteration_result`` to ``current_iteration``;
        * the Loop input port computed for ``external_match['reserve']`` maps to ``container``;
        * the Loop output port feeding ``external_match['stack']`` maps to ``concatenation_result``.

        If any of these conditions fails, the method returns without calling the
        transformation.

        :param graph: the graph being processed (not referenced by this method)
        :param external_match: nodes matched by the outer pattern; the keys
            ``'while'``, ``'reserve'`` and ``'stack'`` are read here
        """
        loop = external_match['while']
        body = loop['body']
        pattern = KerasRNNOutputConcatenation.get_body_pattern()

        body_matches = find_subgraph_match_to_pattern(body, pattern)

        # Zero or several matches leave the graph untouched: only a unique match is processed.
        if len(body_matches) != 1:
            return
        body_match = body_matches[0]

        reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop)
        # Index of the Loop output port whose value is consumed by the 'stack' node.
        stack_port_idx = external_match['stack'].in_port(0).get_source().idx

        # Each guard below is evaluated lazily and in order, so a failing check
        # prevents the later checks (and their node lookups) from running.

        # Back edge: concatenation_result (Result) -> container (Parameter).
        if not Loop.back_edge_exists(loop.back_edges,
                                     body_match['concatenation_result'].internal_layer_id,
                                     body_match['container'].internal_layer_id):
            return
        # Back edge: increment_iteration_result (Result) -> current_iteration (Parameter).
        if not Loop.back_edge_exists(loop.back_edges,
                                     body_match['increment_iteration_result'].internal_layer_id,
                                     body_match['current_iteration'].internal_layer_id):
            return
        # External 'reserve' input port must be mapped onto the body 'container' node.
        if not Loop.inter_edge_exists(loop.input_port_map, reserve_port_idx,
                                      body_match['container'].internal_layer_id):
            return
        # External 'stack' output port must be mapped from the body 'concatenation_result' node.
        if not Loop.inter_edge_exists(loop.output_port_map, stack_port_idx,
                                      body_match['concatenation_result'].internal_layer_id):
            return

        KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(
            external_match, body_match)
Esempio n. 2
0
    def replace_sub_graph(self, graph: Graph, external_match: dict):
        """Run the Keras RNN input-slicing transformation on a matched Loop, if it fits.

        Searches the body of ``external_match['while']`` for the sub-graph returned by
        ``KerasRNNInputSlicing.get_body_pattern()``. The transformation
        (``KerasRNNInputSlicing.transform_keras_rnn_input_slicing``) is invoked only
        when there is exactly one match and the Loop node's wiring agrees with it:

        * a back edge goes from ``increment_iteration_result`` to ``current_iteration``;
        * the Loop input port index computed by ``compute_input_port_idx`` for
          ``external_match['unstack']`` maps to the body's ``tensor_list`` node.

        If any of these conditions fails, the method returns without calling the
        transformation.

        :param graph: the graph being processed (not referenced by this method)
        :param external_match: nodes matched by the outer pattern; the keys
            ``'while'`` and ``'unstack'`` are read here
        """
        loop_node = external_match['while']
        body_graph = loop_node['body']
        body_pattern = KerasRNNInputSlicing.get_body_pattern()
        internal_matches = find_subgraph_match_to_pattern(
            body_graph, body_pattern)

        # a case of multiple matches is not handled since it is not clear how to select corresponding match
        if len(internal_matches) != 1:
            return
        internal_match = internal_matches[0]

        unstack_port_idx = compute_input_port_idx(
            external_match['unstack'], loop_node)
        # check that back edges connect correct Parameter and Result nodes in the body
        # check connections between body input ports and external inputs ports of Loop node
        # (the `and` short-circuits, so the port-map check runs only if the back edge exists)
        if Loop.back_edge_exists(loop_node.back_edges,
                                 internal_match['increment_iteration_result'].internal_layer_id,
                                 internal_match['current_iteration'].internal_layer_id) and \
                Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
                                       internal_match['tensor_list'].internal_layer_id):
            # only if inter-graph match passed it starts to process the sub-graph
            KerasRNNInputSlicing.transform_keras_rnn_input_slicing(
                external_match, internal_match)