Пример #1
0
def read_binary_vector(file_desc: io.BufferedReader,
                       read_token: bool = True,
                       dtype=np.float32):
    """
    Read a length-prefixed vector from a binary stream.

    :param file_desc: binary stream positioned at the vector.
    :param read_token: when true, ``read_placeholder`` is called on the
        stream before the element count is read.
    :param dtype: element type forwarded to ``read_blob``.
    :return: the result of ``read_blob`` for the announced number of elements.
    """
    if read_token:
        read_placeholder(file_desc)
    # The element count is stored as a 32-bit integer token ahead of the data.
    return read_blob(file_desc, read_binary_integer32_token(file_desc), dtype)
Пример #2
0
def read_learning_info(pb: io.BufferedReader) -> None:
    """
    Consume consecutive learning-related tokens from the stream.

    The loop peeks at one byte to decide which token to expect, reads that
    token's value with ``read_token_value`` and repeats. It stops at the first
    byte that does not announce one of the handled tokens, at an unexpected
    tag starting with ``L``, or when ``read_token_value`` raises ``Error`` (in
    which case the stream is rewound to the position recorded for that
    iteration).

    :param pb: binary stream to read from; its position is advanced in place.
    :return: None, the read values are discarded.
    """
    while True:
        # NOTE(review): read_placeholder is not visible here; presumably it
        # skips one byte before the tag - confirm against its definition.
        read_placeholder(pb, 1)
        # Peek at a single byte, then step back two bytes from the current
        # position (the peeked byte plus the one before it).
        first_char = pb.read(1)
        pb.seek(-2, os.SEEK_CUR)
        # Rewind target used if reading the token value fails below.
        position = pb.tell()
        if first_char == b'L':
            # Two different tokens start with 'L', so look at the whole tag
            # and then restore the stream position.
            cur_pos = pb.tell()
            token = find_next_tag(pb)
            pb.seek(cur_pos)
            if token in ['<LearnRateCoef>', '<LearningRate>']:
                # read_token_value is given the token as bytes.
                token = bytes(token, 'ascii')
            else:
                log.debug('Unexpected tag: {}'.format(token))
                break
        elif first_char == b'B':
            token = b'<BiasLearnRateCoef>'
        elif first_char == b'M':
            token = b'<MaxNorm>'
        elif first_char == b'!':  # token = b'<EndOfComponent>'
            break
        else:
            # Any other byte: nothing more to consume here.
            break
        try:
            read_token_value(pb, token)
        except Error:
            # Leave the stream where this iteration started and stop.
            pb.seek(position)
            break
Пример #3
0
    def extract(cls, node):
        """
        Read pooling parameters from ``node.parameters`` and store them on the node.

        The stream must contain ``<PoolSize>`` followed by either
        ``<PoolStep>`` (then ``<PoolStride>``) or ``<PoolStride>`` alone.

        :param node: node whose ``parameters`` stream holds the layer data.
        :return: ``cls.enabled``.
        :raises Error: if the tag after the pool size is neither
            ``<PoolStep>`` nor ``<PoolStride>``.
        """
        stream = node.parameters
        collect_until_token(stream, b'<PoolSize>')
        kernel = read_binary_integer32_token(stream)

        tag = find_next_tag(stream)
        if tag not in ('<PoolStep>', '<PoolStride>'):
            raise Error('Can not extract parameters for {}'.format(node))

        # Both layouts continue with a placeholder and a 32-bit integer.
        read_placeholder(stream, 1)
        value = read_binary_integer32_token(stream)
        if tag == '<PoolStep>':
            stride = pool_step = value
            pool_stride = read_token_value(stream, b'<PoolStride>')
        else:
            stride, pool_step, pool_stride = 1, None, value

        attrs = {
            'window': np.array([1, 1, 1, kernel], dtype=np.int64),
            'stride': np.array([1, 1, stride, stride], dtype=np.int64),
            'pool_stride': pool_stride,
            'pool_step': pool_step,
            'pad': np.zeros((4, 2), dtype=np.int64),
            'pad_spatial_shape': np.zeros((2, 2), dtype=np.int64),
            'pool_method': 'max',
        }
        attrs.update(layout_attrs())
        Pooling.update_node_stat(node, attrs)
        return cls.enabled
Пример #4
0
def load_kaldi_model(nnet_path):
    """
    Load a Kaldi model, dispatching to the nnet1 or nnet2 loader.

    Structure of the file is the following:
    magic-number(16896)<Nnet> <Next Layer Name> weights etc.

    :param nnet_path: path to the model file (``str``) or an already opened
        binary stream (an ``IOBase`` instance).
    :return: whatever the selected loader returns for the stream.
    :raises Error: if ``nnet_path`` has an unsupported type, or if the first
        tag is neither ``<Nnet>`` nor ``<TransitionModel>``.
    """
    # Only a file opened here is ours to clean up; a caller-supplied stream
    # is never closed by this function.
    owns_file = isinstance(nnet_path, str)
    if owns_file:
        file_desc = open(nnet_path, "rb")
    elif isinstance(nnet_path, IOBase):
        file_desc = nnet_path
    else:
        raise Error('Unsupported type of Kaldi model')

    try:
        nnet_name = get_name_from_path(nnet_path) if owns_file else None

        name = find_next_tag(file_desc)
        # start new model / submodel
        if name == '<Nnet>':
            load_function = load_kalid_nnet1_model
        elif name == '<TransitionModel>':
            load_function = load_kalid_nnet2_model
        else:
            raise Error(
                'Kaldi model should start with <Nnet> or <TransitionModel> tag. ',
                refer_to_faq_msg(89))
        read_placeholder(file_desc, 1)

        return load_function(file_desc, nnet_name)
    except BaseException:
        # Do not leak the handle opened above when loading fails.
        # NOTE(review): on success the stream stays open as before, since it
        # is not visible here whether the loader's result still reads from it.
        if owns_file:
            file_desc.close()
        raise
Пример #5
0
 def extract(node):
     """
     Read the splice context from ``node.parameters`` and store it on the node.

     Two layouts are accepted: ``<LeftContext>`` followed by
     ``<RightContext>``, which expands to every offset from ``-left`` to
     ``right`` inclusive, or ``<Context>`` followed by an int32 vector.

     :param node: node whose ``parameters`` stream holds the layer data.
     :return: the ``enabled`` attribute of the enclosing class.
     :raises Error: on any other tag.
     """
     unknown_token = 'Unknown token {} in SpliceComponent node {}'
     stream = node.parameters
     attrs = {'context': list()}

     first_tag = find_next_tag(stream)
     if first_tag == '<Context>':
         collect_until_whitespace(stream)
         attrs['context'] = read_binary_vector(stream, False, dtype=np.int32)
     elif first_tag == '<LeftContext>':
         read_placeholder(stream, 1)
         left = read_binary_integer32_token(stream)
         second_tag = find_next_tag(stream)
         if second_tag != '<RightContext>':
             raise Error(unknown_token.format(second_tag, node.id))
         read_placeholder(stream, 1)
         right = read_binary_integer32_token(stream)
         # Offsets run from -left up to and including +right.
         attrs['context'].extend(range(-left, right + 1))
     else:
         raise Error(unknown_token.format(first_tag, node.id))

     Splice.update_node_stat(node, attrs)
     return __class__.enabled
Пример #6
0
def read_binary_matrix(file_desc: io.BufferedReader, read_token: bool = True):
    """
    Read a matrix stored as row count, column count and data.

    :param file_desc: binary stream positioned at the matrix.
    :param read_token: when true, ``read_placeholder`` is called on the
        stream before the dimensions are read.
    :return: a pair ``(data, (rows, cols))`` where ``data`` is the result of
        ``read_blob`` for ``rows * cols`` elements.
    """
    if read_token:
        read_placeholder(file_desc)
    n_rows = read_binary_integer32_token(file_desc)
    n_cols = read_binary_integer32_token(file_desc)
    data = read_blob(file_desc, n_rows * n_cols)
    return data, (n_rows, n_cols)
Пример #7
0
    def extract(cls, node):
        """
        Read the ``<Params>`` vector and register it as the node's weights.

        :param node: node whose ``parameters`` stream holds the layer data.
        :return: ``cls.enabled``.
        """
        stream = node.parameters
        collect_until_token(stream, b'<Params>')
        scale = read_binary_vector(stream)
        # Consume the following tag and its placeholder; the tag is not used.
        find_next_tag(stream)
        read_placeholder(stream, 1)

        attrs = dict(layout='NCHW')
        embed_input(attrs, 1, 'weights', scale)
        ScaleShiftOp.update_node_stat(node, attrs)
        return cls.enabled
    def extract(cls, node):
        """
        Read the ``<Bias>`` vector and register it as the node's biases.

        The ``out-size`` attribute is taken from the first dimension of the
        vector that was read.

        :param node: node whose ``parameters`` stream holds the layer data.
        :return: ``cls.enabled``.
        """
        stream = node.parameters
        collect_until_token(stream, b'<Bias>')
        shift = read_binary_vector(stream)
        # Consume the following tag and its placeholder; the tag is not used.
        find_next_tag(stream)
        read_placeholder(stream, 1)

        attrs = {'layout': 'NCHW', 'bias_term': True}
        attrs['out-size'] = shift.shape[0]
        embed_input(attrs, 2, 'biases', shift)

        ScaleShiftOp.update_node_stat(node, attrs)
        return cls.enabled
    def extract(node):
        """
        Read ``<LinearParams>`` and ``<BiasParams>`` and register them on the node.

        :param node: node whose ``parameters`` stream holds the layer data.
        :return: the ``enabled`` attribute of the enclosing class.
        :raises Error: if the tag after the weight matrix is not
            ``<BiasParams>``.
        """
        stream = node.parameters
        collect_until_token(stream, b'<LinearParams>')
        matrix, matrix_shape = read_binary_matrix(stream)

        # The placeholder is consumed before the tag is validated.
        next_tag = find_next_tag(stream)
        read_placeholder(stream, 1)
        if next_tag != '<BiasParams>':
            raise Error('FixedAffineComponent must contain BiasParams')
        bias_vector = read_binary_vector(stream)

        attrs = {'out-size': matrix_shape[0]}
        attrs['layout'] = 'NCHW'
        embed_input(attrs, 1, 'weights', matrix)
        embed_input(attrs, 2, 'biases', bias_vector)

        InnerProduct.update_node_stat(node, attrs)
        return __class__.enabled
Пример #10
0
    def extract(cls, node):
        """
        Read ``<LinearParams>`` and ``<BiasParams>`` for a fully connected node.

        The node is marked with ``transpose_weights`` and its ``out-size`` is
        the first dimension of the weight matrix.

        :param node: node whose ``parameters`` stream holds the layer data.
        :return: ``cls.enabled``.
        :raises Error: if the tag after the weight matrix is not
            ``<BiasParams>``.
        """
        stream = node.parameters
        collect_until_token(stream, b'<LinearParams>')
        matrix, matrix_shape = read_binary_matrix(stream)

        # The placeholder is consumed before the tag is validated.
        next_tag = find_next_tag(stream)
        read_placeholder(stream, 1)
        if next_tag != '<BiasParams>':
            raise Error('FixedAffineComponent must contain BiasParams')
        bias_vector = read_binary_vector(stream)

        attrs = {'out-size': matrix_shape[0]}
        attrs['transpose_weights'] = True
        embed_input(attrs, 1, 'weights', matrix)
        embed_input(attrs, 2, 'biases', bias_vector)

        FullyConnected.update_node_stat(node, attrs)
        return cls.enabled