Beispiel #1
0
    def extract(cls, node):
        """Read SpliceComponent attributes from ``node.parameters`` and store them on ``node``.

        Two layouts are accepted for the context:

        * ``<LeftContext>`` followed by ``<RightContext>``: one integer is read
          after each tag and the context becomes the list ``[-left, ..., right]``.
        * ``<Context>``: the context is whatever ``read_binary_vector`` returns
          when called with ``dtype=np.int32``.

        If the next tag is ``<ConstComponentDim>``, one more integer is read and
        stored under ``const_dim``. The collected attributes are then handed to
        ``Splice.update_node_stat``.

        NOTE(review): ``cls`` suggests this is meant to be a classmethod, but no
        decorator is visible in this view -- confirm.

        :param node: node whose ``parameters`` attribute holds the serialized data.
        :return: ``cls.enabled``.
        :raises Error: if the first tag is neither ``<LeftContext>`` nor
            ``<Context>``, or if ``<LeftContext>`` is not followed by
            ``<RightContext>``.
        """
        stream = node.parameters
        unknown_token = 'Unknown token {} in SpliceComponent node {}'

        token = find_next_tag(stream)
        if token == '<LeftContext>':
            # Explicit bounds: an integer after each of the two tags.
            read_placeholder(stream, 1)
            left = read_binary_integer32_token(stream)

            token = find_next_tag(stream)
            if token != '<RightContext>':
                raise Error(unknown_token.format(token, node.id))
            read_placeholder(stream, 1)
            right = read_binary_integer32_token(stream)

            # Inclusive range of offsets from -left up to +right.
            context = list(range(-left, right + 1))
        elif token == '<Context>':
            # The offsets are stored directly as an int32 vector.
            collect_until_whitespace(stream)
            context = read_binary_vector(stream, False, dtype=np.int32)
        else:
            raise Error(unknown_token.format(token, node.id))

        attrs = {'context': context}

        # Optional trailing attribute.
        if find_next_tag(stream) == '<ConstComponentDim>':
            read_placeholder(stream, 1)
            attrs['const_dim'] = read_binary_integer32_token(stream)

        Splice.update_node_stat(node, attrs)
        return cls.enabled
    def replace_pattern(graph: Graph, match: dict):
        """Replace a matched node and its paired node with a Splice followed by Crop nodes.

        ``match['op']`` is the matched node; its partner is looked up in ``graph``
        by ``node.pair_name``. If the partner has ``has_default`` set, the graph
        is left unchanged.

        Otherwise:

        * a Splice node is created with the matched node's ``name`` and ``id``
          values and a context of ``[t, ..., 0]`` for negative ``node.t`` or
          ``[0, ..., t]`` otherwise, and is connected to the pair's input port;
        * a Crop on axis 1 of width ``in_shape[1]`` is attached to the Splice
          output and wired to the pair's downstream consumers;
        * every other consumer of the input port whose ``op`` is neither
          ``'MemoryOffset'`` nor ``'Splice'`` is re-routed through its own Crop
          of the Splice output;
        * the node stored in ``op_output_id``, the matched node and its partner
          are removed from the graph.

        NOTE(review): the signature has no ``self``/``cls`` and no decorator is
        visible in this view; presumably this is meant to be a staticmethod --
        confirm.

        :param graph: graph that is modified in place.
        :param match: pattern match; only the ``'op'`` entry is read here.
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        # Pairs whose partner is flagged with has_default are not rewritten.
        if pair_node.has_default:
            return

        # Whichever half of the pair has a source on input port 0 provides the
        # data input and the id of the single node attached to its output
        # (removed at the end); the other half provides the consumers that must
        # be re-wired to the new Crop.
        # NOTE(review): assumes exactly one half of the pair has a connected
        # input -- confirm.
        if node.in_port(0).get_source() is not None:
            input_node_out_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_node_in_ports = pair_node.out_port(0).get_destinations()
        else:
            input_node_out_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_node_in_ports = node.out_port(0).get_destinations()

        # Work on a copy of the shape reported by the input port.
        in_shape = input_node_out_port.data.get_shape().copy()

        # Read these before any rewiring or node removal happens below.
        node_id = node.id
        node_name = node.name
        node_t = node.t

        # Context is the inclusive range between node_t and 0, in ascending order.
        splice = Splice(graph, {'name': node_name,
                                'id': node_id,
                                'context': int64_array(range(node_t, 1))
                                if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
        splice.in_port(0).connect(input_node_out_port)

        # offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
        crop = Crop(graph, {'name': 'Splice_Crop',
                            'axis': int64_array([1]),
                            'offset': int64_array([max(0, in_shape[1] * node_t)]),
                            'dim': int64_array([in_shape[1]])}).create_node()

        splice.out_port(0).connect(crop.in_port(0))
        # The Splice output is (abs(node_t) + 1) times as wide as the input along axis 1.
        splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(node_t) + 1) * in_shape[1]]))

        # Re-route the remaining consumers of the input port through the Splice.
        # NOTE(review): this list is read after the Splice was connected above, so
        # it presumably contains the Splice's own input port, which the 'Splice'
        # check below skips -- confirm.
        outs = input_node_out_port.get_destinations()
        for in_port in outs:
            out_ = in_port.node
            if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
                # Mirror image of the Crop above: the offset is 0 when
                # in_shape[1] * node_t >= 0, otherwise the absolute value of
                # that product.
                crop_input = Crop(graph, {'name': 'Splice_Crop',
                                          'axis': int64_array([1]),
                                          'offset': int64_array([-min(0, in_shape[1] * node_t)]),
                                          'dim': int64_array([in_shape[1]])}).create_node()
                splice.out_port(0).connect(crop_input.in_port(0))

                # Detach the consumer from the original input and feed it from
                # the new Crop, which reports the original input shape.
                in_port.disconnect()
                crop_input.out_port(0).connect(in_port)
                crop_input.out_port(0).data.set_shape(in_shape)

        # Point the pair's downstream consumers at the Crop output.
        for dest_port in out_node_in_ports:
            dest_port.connect(crop.out_port(0))

        # Drop the node recorded in op_output_id and both halves of the pair.
        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)