Example #1
import copy
import os

import numpy as np

# WordVocabulary and CharEmbeddingsModel are project-specific classes,
# assumed to be importable from the surrounding package.


def get_char_model(char_layer,
                   max_word_length: int,
                   vocabulary: WordVocabulary,
                   char_set: str,
                   embeddings: np.ndarray,
                   model_weights_path: str,
                   model_config_path: str,
                   batch_size: int = 128,
                   val_part: float = 0.2,
                   seed: int = 42):
    """
    Обучение или загрузка char-level функции.

    :param char_layer: заданная char-level функция, которую и обучаем.
    :param max_word_length: максимальная длина слова, по которой идёт обрезка.
    :param vocabulary: список слов.
    :param char_set: набор символов, для которых строятся эмбеддинги.
    :param embeddings: матрица эмбеддингов.
    :param batch_size: размер батча.
    :param model_weights_path: путь, куда сохранять веса модели.
    :param model_config_path: путь, куда сохранять конфиг модели.
    :param val_part: доля val выборки.
    :param seed: seed для ГПСЧ.
    """
    model = CharEmbeddingsModel()
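    # Reuse a previously saved model when both config and weights exist on disk.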
    if model_config_path is not None and os.path.exists(model_config_path):
        assert model_weights_path is not None and os.path.exists(
            model_weights_path)
        model.load(model_config_path, model_weights_path)
    else:
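        # Train from scratch: work on a copy of the vocabulary and shrink it
        # to the number of rows in the embedding matrix.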
        vocabulary = copy.deepcopy(vocabulary)
        vocabulary.shrink(embeddings.shape[0])
        model.build(vocabulary_size=vocabulary.size(),
                    word_embeddings_dimension=embeddings.shape[1],
                    max_word_length=max_word_length,
                    word_embeddings=embeddings.T,
                    char_layer=char_layer)
        model.train(vocabulary, char_set, val_part, seed, batch_size,
                    max_word_length)
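        # Persist the trained model so later calls can simply load it.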
        if model_config_path is not None and model_weights_path is not None:
            model.save(model_config_path, model_weights_path)
    return model.char_layer
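A minimal usage sketch under stated assumptions: WordVocabulary is assumed to
expose a load(path) method, make_char_layer() is a hypothetical factory for
whatever layer object CharEmbeddingsModel.build expects, and all file names
are placeholders. Only get_char_model and load_embeddings (Example #2) come
from the source.

vocabulary = WordVocabulary()
vocabulary.load("vocab.txt")  # hypothetical loading step
embeddings = load_embeddings("news.vec", vocabulary, word_count=100000)
char_layer = get_char_model(char_layer=make_char_layer(),  # hypothetical factory
                            max_word_length=30,
                            vocabulary=vocabulary,
                            char_set="abcdefghijklmnopqrstuvwxyz'-",
                            embeddings=embeddings,
                            model_weights_path="char_model_weights.h5",
                            model_config_path="char_model_config.json",
                            batch_size=128,
                            val_part=0.2,
                            seed=42)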
Example #2
import numpy as np

# WordVocabulary is a project-specific class, assumed to be importable.


def load_embeddings(embeddings_file_name: str, vocabulary: WordVocabulary,
                    word_count: int):
    """
    Load pretrained embeddings from a word2vec-style text file (header line
    "<count> <dimension>", then one "word v1 v2 ..." per line) for the first
    word_count words of the vocabulary.
    """
    with open(embeddings_file_name, "r", encoding='utf-8') as f:
        # Header line: "<word count> <dimension>".
        line = next(f)
        dimension = int(line.strip().split()[1])
        # Start from small random vectors so that words without a pretrained
        # embedding still get a usable initialization.
        matrix = np.random.rand(min(vocabulary.size(), word_count + 1),
                                dimension) * 0.05
        # Map the first word_count vocabulary words to their matrix rows.
        words = {
            word: i
            for i, word in enumerate(vocabulary.words[:word_count])
        }
        for line in f:
            try:
                parts = line.strip().split()
                word = parts[0]
                embedding = [float(value) for value in parts[1:]]
                index = words.get(word)
                if index is not None:
                    matrix[index] = embedding
            except (ValueError, UnicodeDecodeError):
                # Skip malformed lines.
                continue
        return matrix
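A self-contained sketch of load_embeddings on a tiny word2vec-format file.
SimpleVocabulary is a hypothetical stand-in that mimics only the two members
the function actually uses (.size() and .words); the real WordVocabulary is
project-specific.

class SimpleVocabulary:
    def __init__(self, words):
        self.words = words

    def size(self):
        return len(self.words)


# Write a toy embeddings file: header "<count> <dimension>", then vectors.
with open("tiny.vec", "w", encoding="utf-8") as f:
    f.write("2 3\n")
    f.write("cat 0.1 0.2 0.3\n")
    f.write("dog 0.4 0.5 0.6\n")

vocab = SimpleVocabulary(["cat", "dog", "bird"])
matrix = load_embeddings("tiny.vec", vocab, word_count=3)
print(matrix.shape)  # (3, 3): "bird" keeps its small random initialization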