Пример #1
0
    def extract(cls, node):
        """
        Fill Splice attributes from the serialized parameters of a node.

        ``node.parameters`` is consumed sequentially, so the order of the reads
        below matters. The context is given in one of two forms:

        * ``<LeftContext>`` followed by ``<RightContext>``: expanded to the
          consecutive integers from ``-left`` to ``right`` inclusive;
        * ``<Context>``: an explicit int32 vector read from the stream.

        If the next tag is ``<ConstComponentDim>`` its value is stored as
        ``const_dim``.

        :param node: node whose ``parameters`` stream is parsed
        :return: value of ``cls.enabled``
        :raises Error: if the first tag is neither ``<LeftContext>`` nor
            ``<Context>``, or ``<LeftContext>`` is not followed by
            ``<RightContext>``
        """
        stream = node.parameters
        attrs = {'context': list()}

        first_tag = find_next_tag(stream)
        if first_tag == '<Context>':
            # Explicit list of offsets stored as a binary vector.
            collect_until_whitespace(stream)
            attrs['context'] = read_binary_vector(stream, False, dtype=np.int32)
        elif first_tag == '<LeftContext>':
            read_placeholder(stream, 1)
            left = read_binary_integer32_token(stream)

            second_tag = find_next_tag(stream)
            if second_tag != '<RightContext>':
                raise Error('Unknown token {} in SpliceComponent node {}'.format(second_tag, node.id))
            read_placeholder(stream, 1)
            right = read_binary_integer32_token(stream)

            # Symmetric-style window: every offset in [-left, right].
            attrs['context'].extend(range(-left, right + 1))
        else:
            raise Error('Unknown token {} in SpliceComponent node {}'.format(first_tag, node.id))

        # Optional trailing attribute.
        if find_next_tag(stream) == '<ConstComponentDim>':
            read_placeholder(stream, 1)
            attrs['const_dim'] = read_binary_integer32_token(stream)

        Splice.update_node_stat(node, attrs)
        return cls.enabled
Пример #2
0
    def extract(cls, node: Node) -> bool:
        """
        Build Convolution attributes from the serialized parameters of a node.

        ``node.parameters`` behaves like a file descriptor object and is read
        sequentially: ``<PatchDim>``, ``<PatchStep>``, ``<PatchStride>``, the
        learning info, the weights matrix and finally the biases vector.

        :param node: Convolution node
        :return: value of ``cls.enabled``
        :raises Error: if ``patch_stride - kernel`` is not a multiple of the
            stride, or the first weights dimension differs from the number of
            biases
        """
        stream = node.parameters
        kernel_size = read_token_value(stream, b'<PatchDim>')
        step = read_token_value(stream, b'<PatchStep>')
        patch_stride = read_token_value(stream, b'<PatchStride>')

        read_learning_info(stream)

        # Each blob is preceded by a token that has to be skipped first.
        collect_until_whitespace(stream)
        weights, weights_shape = read_binary_matrix(stream)

        collect_until_whitespace(stream)
        biases = read_binary_vector(stream)

        remainder = (patch_stride - kernel_size) % step
        if remainder != 0:
            message = ('Kernel size and stride does not correspond to `patch_stride` attribute of Convolution layer. ' +
                       refer_to_faq_msg(93))
            raise Error(message)

        # The number of output channels is taken from the biases length.
        output = biases.shape[0]
        if weights_shape[0] != output:
            message = ('Weights shape does not correspond to the `output` attribute of Convolution layer. ' +
                       refer_to_faq_msg(93))
            raise Error(message)

        attrs = dict(
            output=output,
            patch_stride=patch_stride,
            bias_term=None,
            pad=np.zeros((4, 2), dtype=np.int64),
            pad_spatial_shape=np.zeros((2, 2), dtype=np.int64),
            dilation=np.ones(4, dtype=np.int64),
            kernel=np.array([1, 1, 1, kernel_size], dtype=np.int64),
            stride=np.array([1, 1, 1, step], dtype=np.int64),
            kernel_spatial=np.array([1, kernel_size], dtype=np.int64),
            input_feature_channel=1,
            output_feature_channel=0,
            kernel_spatial_idx=[2, 3],
            group=1,
            reshape_kernel=True,
        )

        attrs.update(layout_attrs())
        for port, (blob_name, blob) in enumerate((('weights', weights), ('biases', biases)), start=1):
            embed_input(attrs, port, blob_name, blob)

        attrs['bias_addable'] = len(biases) > 0

        Convolution.update_node_stat(node, attrs)
        return cls.enabled
    def extract(cls, node):
        """
        Fill LSTMCell attributes from the serialized parameters of a node.

        ``node.parameters`` is read sequentially. If the first token is
        ``<CellClip>`` the following 4 bytes give the clip value; otherwise the
        default of 50 is kept. After the ``FM`` token the blobs are read in
        this order: gifo_x weights, gifo_r weights, gifo biases, input gate
        weights, forget gate weights, output gate weights, projection weights.
        They are embedded as inputs 1..7 in that same order.

        :param node: node whose ``parameters`` stream is parsed
        :return: value of ``cls.enabled``
        """
        stream = node.parameters

        # Fallback used when the stream does not start with <CellClip>.
        clip = 50
        if collect_until_whitespace(stream) == b'<CellClip>':
            clip = get_uint32(stream.read(4))
        collect_until_token(stream, b'FM')

        x_weights, x_shape = read_binary_matrix(stream, False)
        r_weights, r_shape = read_binary_matrix(stream)
        gifo_bias = read_binary_vector(stream)
        in_gate = read_binary_vector(stream)
        forget_gate = read_binary_vector(stream)
        out_gate = read_binary_vector(stream)

        proj_weights, proj_shape = read_binary_matrix(stream)

        attrs = {
            'gifo_x_weights_shape': x_shape,
            'gifo_r_weights_shape': r_shape,
            'projection_weights_shape': proj_shape,
            'clip_value': clip,
            'format': 'kaldi',
        }

        # Port numbers start at 1 and follow the order of this tuple.
        blobs = (
            ('gifo_x_weights', x_weights),
            ('gifo_r_weights', r_weights),
            ('gifo_biases', gifo_bias),
            ('input_gate_weights', in_gate),
            ('forget_gate_weights', forget_gate),
            ('output_gate_weights', out_gate),
            ('projection_weights', proj_weights),
        )
        for port, (blob_name, blob) in enumerate(blobs, start=1):
            embed_input(attrs, port, blob_name, blob)

        LSTMCell.update_node_stat(node, attrs)
        return cls.enabled