def preprocess(config_path,
               file_path,
               save_path,
               bert_path,
               max_src_tokens,
               max_tgt_tokens,
               lower=False,
               nrows=None):
    """Convert (text, summary) pairs into BERT-ready records and save them.

    Each pair produced by the configured ``DatasetReader.parse_set`` is split
    into sentences with ``sentenize``; every sentence is (optionally
    lowercased and) whitespace-split into words. The word lists go through
    ``BertData.preprocess`` and the results are stored as dicts. The whole
    list is written with ``torch.save``.

    Args:
        config_path: params file; its optional "reader" section configures
            the dataset reader.
        file_path: dataset file handed to ``reader.parse_set``.
        save_path: destination passed to ``torch.save``.
        bert_path: forwarded to ``BertData``.
        max_src_tokens: forwarded to ``BertData``.
        max_tgt_tokens: forwarded to ``BertData``.
        lower: lowercase sentence text before splitting (also forwarded to
            ``BertData``).
        nrows: if not None, process at most this many pairs.
    """
    bert = BertData(bert_path, lower, max_src_tokens, max_tgt_tokens)
    params = Params.from_file(config_path)
    reader = DatasetReader.from_params(params.pop("reader", default=Params({})))

    def split_into_words(raw_text):
        # One whitespace-split word list per sentence found by sentenize.
        words_per_sentence = []
        for sentence in sentenize(raw_text):
            chunk = sentence.text
            if lower:
                chunk = chunk.lower()
            words_per_sentence.append(chunk.split())
        return words_per_sentence

    records = []
    for row_number, (text, summary) in enumerate(reader.parse_set(file_path)):
        if nrows is not None and row_number >= nrows:
            break
        (src_indices, tgt_indices, segments_ids,
         cls_ids, src_txt, tgt_txt) = bert.preprocess(split_into_words(text),
                                                       split_into_words(summary))
        records.append({
            "src": src_indices,
            "tgt": tgt_indices,
            "segs": segments_ids,
            'clss': cls_ids,
            'src_txt': src_txt,
            "tgt_txt": tgt_txt
        })
    torch.save(records, save_path)
예제 #2
0
    def text_to_instance(self,
                         text: str,
                         sentences: List[str] = None,
                         tags: List[int] = None) -> Instance:
        """Build an AllenNLP ``Instance`` from a document.

        Args:
            text: raw document. Used only when ``sentences`` is None; it is
                split with razdel when ``self._language == "ru"`` and with
                NLTK's ``sent_tokenize`` otherwise.
            sentences: optional pre-split sentences.
            tags: optional per-sentence labels, truncated like the sentences.

        Returns:
            Instance with a ``source_sentences`` ListField. Each element is a
            TextField of at most ``self._sentence_max_tokens`` tokens, wrapped
            in START/END symbols. A ``sentences_tags`` SequenceLabelField is
            added when ``tags`` is truthy.
        """
        if sentences is None:
            sentences = ([s.text for s in razdel.sentenize(text)]
                         if self._language == "ru"
                         else nltk.tokenize.sent_tokenize(text))

        limit = self._max_sentences_count
        sentence_fields = []
        for raw_sentence in sentences[:limit]:
            if self._lowercase:
                raw_sentence = raw_sentence.lower()
            tokens = self._tokenizer.tokenize(raw_sentence)[:self._sentence_max_tokens]
            wrapped = [Token(START_SYMBOL)] + tokens + [Token(END_SYMBOL)]
            sentence_fields.append(TextField(wrapped, self._source_token_indexers))

        sentence_list_field = ListField(sentence_fields)
        fields = {'source_sentences': sentence_list_field}
        if tags:
            fields["sentences_tags"] = SequenceLabelField(
                tags[:limit], sentence_list_field)
        return Instance(fields)
예제 #3
0
def ru_tokenizer(text: str) -> list:
    """Tokenize Russian text into flair ``Token`` objects with character offsets.

    The text is processed in three levels:
    - split into paragraphs with ``split_newline``;
    - each paragraph split into sentences with ``sentenize``;
    - each sentence split into words with ``tokenize``.

    Every word is then located in ``text``, searching forward from the end of
    the previous word, to obtain its ``start_position``.

    Args:
        text (str): input text

    Returns:
        list: flair Token objects in text order, all created with
        ``whitespace_after=False``.

    Raises:
        ValueError: if a produced word cannot be found in ``text`` after the
            previous word.
    """
    words = []
    for paragraph in split_newline(text):
        for sentence in sentenize(paragraph):
            words.extend(token.text for token in tokenize(sentence.text))

    tokens = []
    cursor = 0
    for word in words:
        # str.index with a start offset avoids re-slicing ``text`` for every
        # word, which made the original loop quadratic in the text length.
        start_position = text.index(word, cursor)
        tokens.append(Token(text=word,
                            start_position=start_position,
                            whitespace_after=False))
        cursor = start_position + len(word)
    return tokens
예제 #4
0
    def __call__(self, text: str):
        """Tokenize ``text`` and split it into sentences using razdel.

        Args:
            text(str): text.

        Returns:
            Dictionary with:
            1. 'tokens' - list of ``ann.Token`` built from razdel token offsets.
            2. 'sentences' - list of ``ann.Sentence``. Each one's begin/end come
               from ``ProcessorRazdel.offset_to_tokens`` applied to the
               character span of a razdel sentence.
        """
        ann_tokens = []
        for piece in razdel.tokenize(text):
            ann_tokens.append(
                ann.Token(text=piece.text, begin=piece.start, end=piece.stop))

        token_bounds = [
            ProcessorRazdel.offset_to_tokens(span.start, span.stop, ann_tokens)
            for span in razdel.sentenize(text)
        ]

        ann_sentences = []
        for begin, end in token_bounds:
            ann_sentences.append(ann.Sentence(begin, end))

        return {'tokens': ann_tokens, 'sentences': ann_sentences}
예제 #5
0
def skill_title(doc):
    """Build a display title for a skill document.

    The first sentence of ``doc['description']`` (split with razdel) is
    appended in parentheses when ``jaccard(name, sentence)`` is below 0.5.
    Otherwise, or when the description has no sentences, the bare name is
    returned.

    Args:
        doc: mapping with ``'name'`` and ``'description'`` string entries.

    Returns:
        str: ``'name (first sentence)'`` or ``'name'``.
    """
    title = doc['name']
    first_sentence = next(iter(razdel.sentenize(doc['description'])), None)
    # An empty/whitespace-only description yields no sentences; the original
    # code raised IndexError on ``sents[0]`` in that case.
    if first_sentence is None:
        return title
    description_short = first_sentence.text
    if jaccard(title, description_short) < 0.5:
        return '{} ({})'.format(title, description_short)
    return title
예제 #6
0
 def sents(self, fileids=None, categories=None):
     """
     Yield every news item as a list of its sentences.

     Iterates over ``self.listcolumns(fileids, categories)`` and yields, per
     item, the texts of the sentences found by ``sentenize``.
     """
     for news_item in self.listcolumns(fileids, categories):
         yield [piece.text for piece in sentenize(news_item)]
예제 #7
0
def process_text_file(text_file, mongo=None):
    """Run Natasha syntax analysis on every sentence of a UTF-8 text file.

    For each line, sentences are found with ``sentenize``. Each sentence is
    segmented, morph-tagged and syntax-parsed. Its dependency tree is saved as
    ``./htmls/dependency_plot_<file>_<line>_<sentence>.html``, and the sentence
    and its tokens are passed to ``PatternExtractor.extract_relations``.

    Progress is logged on every 100th line. Processing stops, without
    handling that line, once line 100000 is reached.

    Args:
        text_file: path of the input file. NOTE(review): the first two
            characters of the file name are dropped for output naming, which
            presumes a ``./`` prefix — confirm with callers.
        mongo: not used by this function.
    """
    segmenter = Segmenter()
    embedding = NewsEmbedding()
    morph_tagger = NewsMorphTagger(embedding)
    syntax_parser = NewsSyntaxParser(embedding)

    with open(text_file, 'r', encoding='utf-8') as handle:
        file_name = handle.name[2:]
        for line_number, line in enumerate(handle, start=1):
            if line_number % 100 == 0:
                logging.info(f'Processed line {line_number}')
                if line_number >= 100000:
                    return
            line_sentences = [sent.text for sent in sentenize(line)]
            for sentence_number, sentence in enumerate(line_sentences, start=1):
                doc = Doc(sentence)
                doc.segment(segmenter)
                doc.tag_morph(morph_tagger)
                doc.parse_syntax(syntax_parser)
                sentence_tokens = doc.tokens

                html = show_dep_markup(markup_words(doc.syntax),
                                       token_deps(doc.syntax.tokens))
                save_html(
                    html,
                    f'./htmls/dependency_plot_{file_name}_{line_number}_{sentence_number}.html'
                )
                PatternExtractor.extract_relations(
                    file_name,
                    line_number,
                    sentence_number,
                    sentence,
                    sentence_tokens,
                )
예제 #8
0
 def get_prediction_lex_rank(self, text, summary_size=1, threshold=None):
     """Return an extractive summary of ``text`` as a single string.

     ``text`` is split into sentences with razdel, ``self.get_summary``
     selects the summary sentences, and they are joined with single spaces.
     NOTE(review): any exception is swallowed and ``''`` is returned, which
     can hide real errors.
     """
     try:
         summary_sentences = self.get_summary(
             [piece.text for piece in razdel.sentenize(text)],
             summary_size=summary_size,
             threshold=threshold)
         return " ".join(summary_sentences)
     except Exception:
         return ''
예제 #9
0
def prepare_text(text):
    """Split a news message into its title and a list of sentences.

    The first line is the title and the rest is the body. If the text has no
    newline (a single paragraph without a title), an empty title is returned
    so the caller can take the title from the email subject one level up.
    Newlines in the body are turned into spaces before ``sentenize``.

    Returns:
        tuple: ``(title, sentences)`` where ``sentences`` is a list of str.
    """
    head, separator, tail = text.partition('\n')
    if separator:
        title, news_text = head, tail
    else:
        title, news_text = "", text
    body = news_text.replace('\n', ' ')
    return title, [chunk.text for chunk in sentenize(body)]
예제 #10
0
 def save_syntax_analysis_by_text(self,
                                  text,
                                  file,
                                  is_many_sentences=False):
     """Append a dependency-tree rendering of ``text`` to ``file``.

     Everything printed during the call (separator lines and ``show_markup``
     output) is redirected into ``file``, which is opened in append mode as
     UTF-8. Standard output is restored and the file closed on exit.

     Args:
         text: text to analyse. The literal string ``'None'`` only writes
             ``None`` between the separators.
         file: path of the output file.
         is_many_sentences: if False, all razdel sentences of ``text`` form
             one chunk and only the first markup yielded by
             ``self.syntax.map`` is shown. If True, ``text`` is split on
             ``'.'`` and every piece with more than five whitespace-separated
             words is shown separately.
     """
     from contextlib import redirect_stdout

     def show_syntax(fragment):
         # Tokenize each razdel sentence of ``fragment`` and print the
         # dependency arcs of the first markup returned by the parser.
         chunk = [[tok.text for tok in tokenize(sent.text)]
                  for sent in sentenize(fragment)]
         markup = next(self.syntax.map(chunk))
         words, deps = [], []
         for token in markup.tokens:
             words.append(token.text)
             source = int(token.head_id) - 1
             target = int(token.id) - 1
             # source == -1 is the root. source == 0 is a real arc headed by
             # the first word; the former ``source > 0`` check dropped it.
             if source >= 0 and source != target:
                 deps.append([source, target, token.rel])
         show_markup(words, deps)

     # The original reassigned sys.stdout to the file and never restored it
     # (nor closed the file), redirecting all later output of the process.
     with open(file, 'a', encoding='utf-8') as out, redirect_stdout(out):
         print('-' * 100)
         if text == 'None':
             print('None')
         elif not is_many_sentences:
             show_syntax(text)
         else:
             for sentence in text.split('.'):
                 if len(sentence.split()) > 5:
                     show_syntax(sentence)
         print('-' * 100)
예제 #11
0
def extract_sentnces_with_names(text):
    """Map capitalized syntactic objects in ``text`` to the words linked to them.

    Parsing steps:
    - whitespace is collapsed and trimmed;
    - the text is split into razdel sentences and tokenized with NLTK;
    - ``syntax.map`` parses the chunk, and only its first markup is used.

    For each ``obj`` arc whose dependent word starts with an upper-case
    letter, the arc's head index is recorded under the object's lemma
    (``mystem.lemmatize``). That index list is then extended by heads of other
    arcs accepted by ``is_values_lower``.

    Args:
        text (str): input text.

    Returns:
        dict: object lemma -> the words at the collected token indices, in
        sentence order, each followed by a space. Empty dict for empty or
        whitespace-only text.
    """
    # .strip(): whitespace-only input used to become " " and pass the check.
    text = re.sub(r"\s+", " ", text).strip()
    if not text:
        return {}

    # syntax extraction:
    chunk = [list(nltk.word_tokenize(sent.text)) for sent in sentenize(text)]
    markup = next(syntax.map(chunk))

    words = [token.text for token in markup.tokens]
    deps = []
    for token in markup.tokens:
        source = int(token.head_id) - 1
        target = int(token.id) - 1
        if source >= 0 and source != target:  # skip root, loops
            deps.append([source, target, token.rel])

    # keep only objects that start with an upper-case letter:
    obj_to_connections = {}
    for source, target, rel in deps:
        obj = words[target]
        if rel == 'obj' and obj[0].isupper():
            obj_to_connections[mystem.lemmatize(obj)[0]] = [source]

    # grow each object's index list with connected heads:
    for value in obj_to_connections.values():
        for source, _target, _rel in deps:
            head = [source]
            shared = set(value).intersection(set(head))
            if shared and is_values_lower(value, shared) and source == value[0]:
                value.extend(head)

    # distinct, in sentence order (a bare set() has no defined order):
    for key, value in obj_to_connections.items():
        obj_to_connections[key] = sorted(set(value))

    # Indices refer to the parser's tokens; ``text.split(" ")`` misaligned
    # (or went out of range) whenever NLTK split off punctuation.
    for key, value in obj_to_connections.items():
        obj_to_connections[key] = ''.join(words[v] + ' ' for v in value)

    return obj_to_connections
예제 #12
0
def main(args):
    """Run DeepPavlov morphology and syntax models over files and pickle results.

    For every path matching ``args.inglob`` (recursive glob):
    - the file is read as UTF-8 and passed through ``clean_text``;
    - it is split into sentences (``sentenize``, then ``split_long_sentence``
      with ``args.max_sent_len``);
    - sentences are tagged and parsed in batches of ``args.batch_size``;
    - a list of dicts (``span``, ``text``, ``pos``, ``syntax``) is pickled to
      ``<args.outdir>/<basename>.pickle``.

    Existing outputs are skipped unless ``args.f`` is set. A failure in one
    file is printed with its traceback and does not stop the run.
    """
    os.makedirs(args.outdir, exist_ok=True)

    pos_model = build_model(
        configs.morpho_tagger.BERT.morpho_ru_syntagrus_bert, download=True)
    syntax_model = build_model(configs.syntax.syntax_ru_syntagrus_bert,
                               download=True)

    for in_path in glob.glob(args.inglob, recursive=True):
        try:
            print(in_path)

            docname = os.path.splitext(os.path.basename(in_path))[0]
            out_path = os.path.join(args.outdir, docname + '.pickle')

            if os.path.exists(out_path) and not args.f:
                print('Already processed')
                continue

            # Explicit encoding: the platform default may not decode Cyrillic.
            with open(in_path, 'r', encoding='utf-8') as f:
                full_text = clean_text(f.read())

            sentences_spans = [
                split_sent for sent in sentenize(full_text) for split_sent in
                split_long_sentence(sent, max_len=args.max_sent_len)
            ]
            sentences_texts = [s.text for s in sentences_spans]
            sentences_pos = pos_model.batched_call(sentences_texts,
                                                   batch_size=args.batch_size)
            sentences_syntax = syntax_model.batched_call(
                sentences_texts, batch_size=args.batch_size)
            # Explicit check instead of ``assert`` (stripped under -O); zip()
            # below would otherwise silently drop unmatched sentences.
            if not (len(sentences_spans) == len(sentences_pos) ==
                    len(sentences_syntax)):
                raise ValueError(
                    f'Model outputs misaligned: {len(sentences_spans)} spans, '
                    f'{len(sentences_pos)} pos, {len(sentences_syntax)} syntax')

            doc_sentences = [
                dict(span=(span.start, span.stop),
                     text=span.text,
                     pos=pos,
                     syntax=synt) for span, pos, synt in zip(
                         sentences_spans, sentences_pos, sentences_syntax)
            ]
            with open(out_path, 'wb') as f:
                pickle.dump(doc_sentences, f)
        except Exception as ex:
            print(
                f'Failed to process {in_path} due to {ex}\n{traceback.format_exc()}'
            )
예제 #13
0
def get_lemmatized_sentences(texts):
    """Split texts into razdel sentences and lemmatize each with Mystem.

    Args:
        texts: iterable of strings (wrapped in ``tqdm`` for progress).

    Returns:
        tuple: ``(lemmas, sentences)``.
        - ``lemmas[i][j]`` is ``Mystem.lemmatize`` of the j-th sentence of
          the i-th text.
        - ``sentences[i][j]`` is that sentence's text.
    """
    mystem = Mystem(entire_input=False)
    lemmatized_texts = []
    sentence_texts = []
    for text in tqdm(texts):
        sentences = [sent.text for sent in razdel.sentenize(text)]
        lemmatized_texts.append([mystem.lemmatize(s) for s in sentences])
        sentence_texts.append(sentences)
    return lemmatized_texts, sentence_texts
예제 #14
0
def get_stead_sent_pairs(text: str) -> List[str]:
    """Split text into pairs of consecutive sentences (1+2, 2+3, ...).

    Args:
        text (str): Input text

    Returns:
        List[str]: each adjacent pair joined by a single space
    """
    sents = [sent.text for sent in sentenize(text)]
    return [left + ' ' + right for left, right in zip(sents, sents[1:])]
예제 #15
0
    def get_texts(dataset):
        """Collect lowercased token lists from a dataset's texts and titles.

        One list is produced for every razdel sentence of every
        ``dataset["text"]`` entry, followed by one list per
        ``dataset["title"]`` entry. Tokens whose text is ``in punctuation``
        are dropped.
        """
        def clean_tokens(fragment):
            return [tok.text.lower() for tok in tokenize(fragment)
                    if tok.text not in punctuation]

        sentence_lists = [clean_tokens(sentence.text)
                          for document in dataset["text"]
                          for sentence in sentenize(document)]
        title_lists = [clean_tokens(title) for title in dataset["title"]]
        return sentence_lists + title_lists
예제 #16
0
    def describe(self, fileids=None, categories=None):
        """
        Walk all documents and return a dict of statistics about the corpus.

        Counts documents ("rows"), razdel sentences and razdel tokens, updating
        the private lowercase token counter along the way. It then derives
        per-file/per-document ratios, date bounds and duplicate/empty counts.

        Ratios with a zero denominator (e.g. an empty selection) are reported
        as 0.0, and date bounds as ``None`` when there are no dates. Both
        previously raised.

        Args:
            fileids: passed to ``docs``, ``resolve`` and ``listcolumns``.
            categories: passed to the same methods.

        Returns:
            dict keyed by human-readable (Russian) statistic names.
        """
        started = time.time()

        def ratio(numerator, denominator):
            return float(numerator) / float(denominator) if denominator else 0.0

        # Counters
        counts = Counter()

        # Walk every news item from the csv, split into sentences and words,
        # and count them
        for doc in self.docs(fileids, categories):
            counts['rows'] += 1
            for sent in [piece.text for piece in sentenize(doc['text'])]:
                counts['sents'] += 1
                if len(sent) == 0:
                    continue
                for word in [tok.text for tok in tokenize(sent)]:
                    counts['words'] += 1
                    self.__tokens[word.lower()] += 1

        # Number of files and categories in the corpus
        n_fileids = len(self.resolve(fileids, categories) or self.fileids())
        n_topics = len(self.categories(self.resolve(fileids, categories)))

        # List of news items
        list_news = list(self.listcolumns(fileids, categories))
        n_tokens = len(self.gettokens)
        # Build the result
        return {
            'Количество файлов': n_fileids,
            'Количество источников новостей': n_topics,
            'Количество обработанных новостей': counts['rows'],
            'Количество предложений': counts['sents'],
            'Количество слов': counts['words'],
            'Количество токенов (словарь)': n_tokens,
            'Коэффициент лексического разнообразия (lexical diversity)': ratio(counts['words'], n_tokens),
            'Среднее количество новостей по отношению к файлам': ratio(counts['rows'], n_fileids),
            'Среднее количество предложений в новостях': ratio(counts['sents'], counts['rows']),
            'Начальная дата в обработке': min(self.listcolumns(fileids, categories, 'date'), default=None),
            'Конечная дата в обработке': max(self.listcolumns(fileids, categories, 'date'), default=None),
            'Количество повторяющихся новостей': len(list_news) - len(set(list_news)),
            'Количество пустых новостных элементов': len([item for item in list_news if len(item) == 0]),
            'Время обработки в секундах': time.time() - started,
        }
예제 #17
0
    def _parse_text(cls, text):
        """Split ``text`` into sentences and extract the trailing author.

        The markers ``(З)``, ``(ЗЗ)`` and ``(ЗО)``, written with Cyrillic
        letters, are first rewritten to digits.
        - If ``_SENTENCE_START_PATTERN`` does not match, the text is split
          with ``sentenize``.
        - Otherwise each sentence is the stripped text between the end of one
          match and the start of the next (the last runs to the end of the
          text); text before the first match is not kept.

        The last sentence is then passed through ``cls._extract_author``.

        Returns:
            tuple: ``(sentences, preposition, author)``.
        """
        for wrong, right in (('(З)', '(3)'), ('(ЗЗ)', '(33)'), ('(ЗО)', '(30)')):
            text = text.replace(wrong, right)

        if not _SENTENCE_START_PATTERN.search(text):
            sentences = [sent.text for sent in sentenize(text)]
        else:
            matches = list(_SENTENCE_START_PATTERN.finditer(text))
            sentences = [
                text[current.end(): following.start()].strip()
                for current, following in zip(matches, matches[1:])
            ]
            sentences.append(text[matches[-1].end():].strip())

        sentences[-1], preposition, author = cls._extract_author(sentences[-1])

        return sentences, preposition, author
예제 #18
0
def get_diff_sent_pairs(text: str) -> List[str]:
    """Split text into sentence pairs: every sentence joined with each later one.

    Pairs whose two sentences have identical text are skipped.

    Args:
        text (str): Input text

    Returns:
        List[str]: pairs in (earlier, later) order, joined by a single space
    """
    sents = [sent.text for sent in sentenize(text)]
    pairs = []
    for first_pos, first in enumerate(sents):
        for second in sents[first_pos + 1:]:
            if first != second:
                pairs.append(first + ' ' + second)
    return pairs
def nlp(text):
    """Morphologically tag ``text`` sentence by sentence.

    Each razdel sentence is tokenized and tokens equal to ``' '`` are
    dropped. The remaining tokens are tagged with ``morph.map``, one call per
    sentence. Every tagged token is lemmatized with ``m.lemmatize``, keeping
    its first result.

    Returns:
        list of ``Tokens(text, lemma, pos)`` in text order.
    """
    sentence_tokens = [[piece.text for piece in tokenize(sent.text)]
                       for sent in sentenize(text)]

    tagged = []
    for chunk in sentence_tokens:
        kept = [tok for tok in chunk if tok != ' ']
        markup = next(morph.map([kept]))
        for token in markup.tokens:
            word = token.text
            tagged.append(Tokens(word, m.lemmatize(word)[0], token.pos))
    return tagged
예제 #20
0
    def split(self, snt=''):
        """
        Split the text into sentences, or a sentence into tokens (words, punctuation).

        :param snt: empty string to split ``self.text`` into sentences (stored
            in ``self.sentences``); otherwise the sentence to tokenize.
        :return: ``None`` when splitting ``self.text``; list of token strings
            when tokenizing ``snt``.
        :raises KeyError: when splitting sentences and ``self.text`` is empty.
        """
        if snt:
            return [piece.text for piece in tokenize(snt)]
        if not self.text:
            raise KeyError
        self.sentences = [piece.text for piece in sentenize(self.text)]
예제 #21
0
    def calc_context(self, window_size=5):
        """
        Calculate the neighbours of each word within a fixed window.

        For every razdel sentence of every ``text_dict["text_data"]`` in
        ``self._corpora``, alphabetic tokens are mapped to indices via
        ``self._word2ind``; unknown words map to ``"<UNK>"``. Neighbour
        indices are then appended to ``self._context_dict[center_index]``:
        - the first and last ``window_size // 2`` tokens get one-sided context;
        - in sentences longer than ``window_size - 1`` tokens, every interior
          token gets ``window_size // 2`` neighbours on each side.

        Finally the entries for indices 0 and 1 are deleted.

        Args:
            window_size: window width; ``window_size // 2`` neighbours per side.
        """
        self._window_size = window_size
        half = self._window_size // 2

        for text_dict in self._corpora:
            sentences = razdel.sentenize(text_dict["text_data"])

            for sent in sentences:
                tokens = [
                    token for token in self._tokenizer.tokenize(sent.text)
                    if token.isalpha()
                ]
                idx_tokens = [
                    self._word2ind[token]
                    if token in self._word2ind else self._word2ind["<UNK>"]
                    for token in tokens
                ]

                # Edge tokens get one-sided context. The range is bounded by
                # the sentence length. Sentences shorter than ``half`` tokens,
                # including ones without alphabetic tokens, used to raise
                # IndexError here.
                for side_idx in range(min(half, len(idx_tokens))):
                    following = idx_tokens[side_idx + 1:half + side_idx + 1]
                    preceding = idx_tokens[-half - side_idx - 1:-side_idx - 1]
                    self._context_dict[idx_tokens[side_idx]].extend(following)
                    self._context_dict[idx_tokens[-side_idx - 1]].extend(preceding)

                # Interior tokens: ``half`` neighbours on each side.
                if len(idx_tokens) > self._window_size - 1:
                    for i in range(half, len(idx_tokens) - half):
                        for step in range(1, half + 1):
                            self._context_dict[idx_tokens[i]].append(
                                idx_tokens[i + step])
                            self._context_dict[idx_tokens[i]].append(
                                idx_tokens[i - step])

        del self._context_dict[0]  # element for <UNK> token
        del self._context_dict[1]  # element for <PAD> token
예제 #22
0
 def preprocess_text(self, text):
     """Turn ``text`` into a list of lemmatized, lowercased content words.

     The text is split into sentences with razdel, and each sentence is
     passed to ``self.lemmatizer.lemmatize``. Results are lowercased and
     stripped. Words that are empty (e.g. whitespace-only lemmas after
     ``strip()``), are in ``self.stopwords`` or contain ``':'`` are dropped.
     If processing raises an ``Exception``, an empty list is returned.
     """
     try:
         # split on sentences
         sents = [
             sent for (_start, _stop, sent) in razdel.sentenize(text)
         ]
         # split on words, transform to lower case and lemmatize
         words = [
             word.lower().strip() for sent in sents
             for word in self.lemmatizer.lemmatize(sent)
         ]
         # remove empty strings, stop words and tokens containing ':'
         words = [
             word for word in words
             if word and word not in self.stopwords and ":" not in word
         ]
     except Exception:
         # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
         # still propagate.
         words = []
     return words
예제 #23
0
 def __iter_row(self, from_cursor, dt_now, size):
     """Fetch up to ``size`` rows from ``from_cursor`` and yield output tuples.

     Column 5 of each row is stringified, lowercased and split with
     ``sentenize``. Each sentence goes through ``clean_text`` and then
     ``lemmatize``, and the list of results is stored as its ``str()`` form.
     Columns 0-1 are stringified and columns 2-4 passed through. A hard-coded
     UUID (twice), ``dt_now`` (twice) and a fixed timestamp fill the
     remaining fields.
     """
     fixed_uuid = "4847c8c7-a14f-4d59-8f62-a1c622db4aab"
     for row in from_cursor.fetchmany(size):
         cleaned = [clean_text(piece.text)
                    for piece in sentenize(str(row[5]).lower())]
         lemmatized = [lemmatize(sentence) for sentence in cleaned]
         yield (
             str(row[0]),
             str(row[1]),
             fixed_uuid,
             fixed_uuid,
             row[2],
             row[3],
             row[4],
             str(lemmatized),
             dt_now,
             dt_now,
             "1900-01-01 00:00:00",
         )
예제 #24
0
def disambiguate_span(spans: list, paragraphs: list, questions: list,
                      answers: list, m):
    """Return one unambiguous answer span per question.

    For a question with several candidate spans:
    - each span is located in the razdel sentence of its paragraph that
      contains it;
    - the sentence text left and right of the span is lemmatized with
      ``lemmatize(..., m)``;
    - the context is compared with the lemmatized question by Jaccard
      similarity.

    The span with the highest positive similarity wins; if none scores above
    0 the result is ``()``. A question with a single candidate keeps it, and
    one with no candidates gets ``()``.

    Args:
        spans: per question, a list of ``(start, stop)`` character offsets
            into the corresponding paragraph.
        paragraphs: paragraph texts aligned with ``spans``.
        questions: question texts aligned with ``spans``.
        answers: unused.
        m: handle forwarded to ``lemmatize``.

    Returns:
        list: one span (or ``()``) per entry of ``spans``.
    """
    result_spans = []

    for idx, candidates in tqdm(enumerate(spans)):
        if len(candidates) == 0:
            # ``entry[0]`` used to raise IndexError for an empty candidate list.
            result_spans.append(())
            continue
        if len(candidates) == 1:
            result_spans.append(candidates[0])
            continue

        # A list, not a dict keyed by sentence text: repeated sentences used
        # to overwrite each other, so spans in earlier copies never matched.
        sentence_bounds = [(s.text, s.start, s.stop)
                           for s in sentenize(paragraphs[idx])]
        question = lemmatize(questions[idx], m)
        max_overlap = 0
        optimal_span = ()

        for span in candidates:
            for sent, sent_start, sent_stop in sentence_bounds:
                if not (span[0] >= sent_start and span[1] <= sent_stop):
                    continue
                left_context = lemmatize(sent[:span[0] - sent_start], m)
                right_context = lemmatize(sent[span[1] - sent_start:], m)
                context = left_context + right_context
                overlap = len(intersection(context, question))
                union_size = len(context) + len(question) - overlap
                if union_size == 0:
                    # Nothing to compare; this used to divide by zero.
                    continue
                jaccard_score = overlap / union_size
                if jaccard_score > max_overlap:
                    max_overlap = jaccard_score
                    optimal_span = span

        result_spans.append(optimal_span)

    return result_spans
def process_record(rec):
    """Normalize a record and return its sentences as space-joined token strings.

    ``text_normalizer(rec)`` is split into sentences with ``sentenize``. Each
    sentence is tokenized with ``re_tokenizer``, wrapped in ``BOS_TOKEN`` /
    ``EOS_TOKEN`` and mapped through ``isnum``. Sentences containing a token
    for which ``islatin`` is true are dropped.

    NOTE(review): a former comment mentioned a minimum sentence length of 3
    tokens, but no length check is performed — confirm the intended filter.
    """
    kept_lines = []
    for sentence in sentenize(text_normalizer(rec)):
        tokens = re_tokenizer(sentence.text)
        tokens.insert(0, BOS_TOKEN)
        tokens.append(EOS_TOKEN)
        tokens = [isnum(token) for token in tokens]

        # keep single-language sentences only
        if any([islatin(token) for token in tokens]):
            continue
        kept_lines.append(' '.join(tokens))

    return kept_lines
예제 #26
0
def get_title(text: str) -> str:
    """Generate a title for a post.

    Newlines are turned into ``'.'`` sentence breaks, the text goes through
    ``preprocess_text`` and is split with ``sentenize``. The first sentence
    becomes the title if non-empty, otherwise the second one if it exists,
    otherwise ``'(без заголовка)'`` ("(no title)"). The title is cut at its
    first comma unless that leading part is a single word, and dots are
    removed.

    Args:
        text (str): Post text

    Returns:
        str: Title
    """
    text = text.replace('\n', '.').replace('..', '.')
    text = preprocess_text(text)
    sents = list(sentenize(text))
    # Guard the empty case (``sents[0]`` used to raise IndexError) and use
    # the second sentence whenever it exists (the check used to be ``> 2``).
    if sents and sents[0].text:
        title = sents[0].text
    elif len(sents) > 1:
        title = sents[1].text
    else:
        title = '(без заголовка)'
    head = title.split(',')[0].replace('.', '')
    if len(head.split()) == 1:
        return title.replace('.', '')
    return head
예제 #27
0
    def _extract_author(last_sentence):
        """Split an author attribution off the last sentence of a text.

        - If ``_AUTHOR_PATTERN`` matches, the sentence is cut at the match;
          group 3 is the author and group 1 the preposition.
        - Otherwise the sentence is re-split with ``sentenize``: the first
          part is kept as the sentence, and ``_AUTHOR_FALLBACK_PATTERN``
          group 1 found in the remaining parts is the author; the
          preposition is ``None``.

        Authors shorter than 4 characters or without a ``'.'`` are discarded.

        Returns:
            tuple: ``(last_sentence, preposition, author)``; ``author`` may be None.
        """
        match = _AUTHOR_PATTERN.search(last_sentence)
        if match:
            last_sentence = last_sentence[:match.start()].strip()
            author = match.group(3).strip()
            preposition = match.group(1)
        else:
            parts = list(sentenize(last_sentence))
            # An empty/whitespace-only sentence (e.g. produced by the regex
            # branch of _parse_text) yields no parts; ``parts[0]`` used to
            # raise IndexError.
            last_sentence = parts[0].text.strip() if parts else last_sentence.strip()
            appended_info = ' '.join(part.text for part in parts[1:])
            author_match = _AUTHOR_FALLBACK_PATTERN.search(appended_info)
            author = author_match.group(1).strip() if author_match else None
            preposition = None

        if author and (len(author) < 4 or '.' not in author):
            author = None

        return last_sentence, preposition, author
예제 #28
0
파일: main.py 프로젝트: natasha/nerus
def merge(docs, ners, morphs, syntaxes):
    """Combine per-document NER with per-sentence morphology and syntax markups.

    ``ners`` is advanced once per document. ``morphs`` and ``syntaxes`` are
    advanced once per razdel sentence, in lockstep. Sentences whose morph and
    syntax markups have different token counts are skipped; the original
    notes cite long sentences, empty sentences and misaligned syntax masks,
    about 250 sentences per 100 000 texts.

    Yields:
        MergeRecord(doc_id, sent_index, sentence_text, spans, tokens)
    """
    for doc_id, doc in enumerate(docs):
        ner = next(ners)
        for sent_index, sent in enumerate(sentenize(doc.text)):
            morph = next(morphs)
            syntax = next(syntaxes)
            if len(morph.tokens) != len(syntax.tokens):
                continue

            spans = list(sent_spans(sent, ner.spans))
            found_tokens = find_tokens(sent.text, markup_words(morph))
            tags = spans_bio(found_tokens, spans)
            merged_tokens = list(merge_tokens(morph.tokens, syntax.tokens, tags))
            yield MergeRecord(doc_id, sent_index, sent.text, spans, merged_tokens)
def build_oracle_summary_greedy(text,
                                gold_summary,
                                calc_score,
                                lower=True,
                                max_sentences=30):
    """Greedily pick the sentences of ``text`` that maximise ``calc_score``.

    Sentences are split with razdel and limited to the first
    ``max_sentences``. They, and the gold summary, are lowercased when
    ``lower`` is set. Each round adds the single unused sentence whose
    inclusion gives the best ``calc_score(summary, gold_summary)``. The
    search stops when a round does not strictly improve on the previous best.

    Returns:
        dict with:
        - ``text`` and ``summary``: the inputs as given;
        - ``sentences``: the processed sentences;
        - ``oracle``: a 0/1 flag per sentence.
    """
    output = {"text": text, "summary": gold_summary}
    if lower:
        gold_summary = gold_summary.lower()
    raw_sentences = [s.text for s in razdel.sentenize(text)]
    sentences = [(s.lower() if lower else s) for s in raw_sentences][:max_sentences]

    def join_selected(selected):
        return " ".join(sentences[idx] for idx in sorted(selected))

    total = len(sentences)
    best_so_far = -1.0
    chosen = set()
    for _ in range(total):
        candidates = []
        for idx in range(total):
            if idx in chosen:
                continue
            trial = chosen | {idx}
            candidates.append((calc_score(join_selected(trial), gold_summary), trial))
        # If the metric didn't improve in this round, stop
        top_score, top_set = max(candidates)
        if top_score <= best_so_far:
            break
        best_so_far, chosen = top_score, top_set

    output["sentences"] = sentences
    output["oracle"] = [int(idx in chosen) for idx in range(total)]
    return output
예제 #30
0
    def predict_from_model(self, task):
        """Predict the target word variant for ``task["text"]`` with ``self.rubert``.

        Preparation:
        - the module-level ``replaces`` substitutions are applied;
        - a space is added after every ``'.'``;
        - the text is split with razdel.

        Each lowercased sentence matching ``pattern`` is rewritten with
        ``re.sub(pattern, r" \\1_ ", ...)``, then encoded, and every target
        position is scored by the model. Variants the model classifies
        positively are ranked by their maximum output value.

        Returns:
            One of three values:
            - the highest-scoring positive variant;
            - the word carried by a ``Together`` exception raised during
              encoding;
            - a random choice among all variants seen, if none is positive.

        Raises:
            IndexError: from ``random.choice`` when no sentence yields a variant.
        """
        candidates = {}
        text = task["text"]
        for k, v in replaces.items():
            text = text.replace(k, v)
        text = text.replace(".", ". ")
        all_variants = []

        for _, _, sentence in razdel.sentenize(text):
            sentence = sentence.lower()
            if not re.search(pattern, sentence):
                continue

            sentence = re.sub(pattern, r" \1_ ", sentence)
            try:
                encoded, is_targets = self.encode_data(sentence)
            except Together as e:
                word = e.args[0]
                return word
            targets, attentions = self.insert_target(encoded)
            # ``encoded`` stays in the zip: it bounds the iteration length.
            for _encode, X, attention, is_target in zip(encoded, targets,
                                                        attentions, is_targets):
                if not is_target:
                    continue
                all_variants.append(is_target)
                X = torch.from_numpy(np.array([X]))
                attention = torch.from_numpy(np.array([attention]))
                # Inference only: skip building an autograd graph.
                with torch.no_grad():
                    output = self.rubert(X, attention)
                y_pred = output.argmax(dim=1).cpu().data.numpy().flatten()[0]
                if y_pred:
                    candidates[is_target] = (output.max(
                        dim=1).values.cpu().data.numpy().flatten()[0])

        if candidates:
            # max() returns the first of equal scores, like sorted(reverse=True)[0].
            return max(candidates.items(), key=lambda item: item[1])[0]
        return random.choice(all_variants)