Code Example #1
File: test_vocab.py Project: alwc/iSeqL
    def test_vocab_creation(self) -> None:
        '''
        Make sure vocab creation is able to create vocabulary objects
        with the correct words.
        '''
        vocab = build_vocab(SAMPLE_WORDS)
        self._has_words(vocab, UNIQUE_WORDS)
        for token in constants.SPECIAL_TOKENS:
            if token == constants.UNKNOWN_TOKEN:
                continue
            self.contains(vocab, token)
Code Example #2
File: test_vocab.py Project: alwc/iSeqL
    def test_unk(self) -> None:
        '''
        Test to make sure the vocab object supports unking
        '''
        vocab = build_vocab(SAMPLE_WORDS)

        unk_index = vocab(constants.UNKNOWN_TOKEN)

        # encode
        assert vocab('random_word_here') == unk_index

        # decode
        assert vocab.get_word(unk_index) == constants.UNKNOWN_TOKEN
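
Together, these two tests pin down the Vocab contract: build_vocab deduplicates its input and reserves the special tokens, calling the vocab maps a word to its index (falling back to the unknown index for out-of-vocabulary words), and get_word decodes an index back to a word. Below is a minimal sketch of a class satisfying that contract; it is purely illustrative, since the real iSeqL implementation is not shown here, and UNKNOWN_TOKEN / SPECIAL_TOKENS are stand-ins for the project's constants module.

    # Illustrative sketch only; not the actual iSeqL implementation.
    # UNKNOWN_TOKEN and SPECIAL_TOKENS stand in for the project's constants.
    from typing import Dict, List

    UNKNOWN_TOKEN = '<UNK>'
    SPECIAL_TOKENS = ['<PAD>', UNKNOWN_TOKEN]


    class Vocab:
        def __init__(self, words: List[str]) -> None:
            # special tokens first, so their indices are stable across builds
            self._words: List[str] = list(SPECIAL_TOKENS)
            seen = set(self._words)
            for word in words:
                if word not in seen:
                    seen.add(word)
                    self._words.append(word)
            self._index: Dict[str, int] = {
                word: i for i, word in enumerate(self._words)}

        def __call__(self, word: str) -> int:
            # out-of-vocabulary words "unk" to the unknown token's index
            return self._index.get(word, self._index[UNKNOWN_TOKEN])

        def get_word(self, index: int) -> str:
            return self._words[index]


    def build_vocab(words: List[str]) -> Vocab:
        return Vocab(words)
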
Code Example #3
File: utils.py Project: alwc/iSeqL
    def _construct_vocab() -> Vocab:
        dl = construct_dataloader(SAMPLE_DATASET)
        return build_vocab(dl.word_list)
Code Example #4
File: utils.py Project: alwc/iSeqL
    def _construct_tag_vocab() -> Vocab:
        dl = construct_dataloader(SAMPLE_DATASET)
        return build_vocab(dl.categories)
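
The two helpers differ only in what they index: _construct_vocab builds a word vocabulary from the dataloader's word_list, while _construct_tag_vocab indexes its label categories. A hypothetical usage sketch, assuming the test-utility names above:

    # Usage sketch (hypothetical): SAMPLE_DATASET, construct_dataloader, and
    # the two helpers come from the project's test utilities.
    word_vocab = _construct_vocab()      # indexes every word in the sample data
    tag_vocab = _construct_tag_vocab()   # indexes the label categories

    sentence = ['the', 'drug', 'caused', 'nausea']
    encoded = [word_vocab(token) for token in sentence]
    # decoding round-trips only for in-vocabulary words; anything else
    # comes back as the UNKNOWN_TOKEN marker
    decoded = [word_vocab.get_word(i) for i in encoded]
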
Code Example #5
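    # Context (not shown in the snippet): this method belongs to a database
    # manager class, and NLTK_SIA is presumably a module-level nltk
    # SentimentIntensityAnalyzer instance; its polarity_scores() returns the
    # 'neg'/'neu'/'pos'/'compound' keys used below.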
    def prepare_csv(self):
        if self.load():
            print("load database manager state")
            return

        # necessary fields from data
        row_info = self.configuration.get_key('data_schema/rows')
        row_types = self.configuration.get_key('data_schema/row_types')
        text_field_name = self.configuration.get_key('data_schema/text_field')
        id_field_name = self.configuration.get_key('data_schema/id_field')
        label_field_name = self.configuration.get_key(
            'data_schema/label_field')

        has_header = self.configuration.get_key('data_schema/includes_header')

        sentence_counter = 0
        word_list = []
        with open(self.raw_data_csv) as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            for row in csv_reader:
                if has_header:
                    # skip over headers in the data when included
                    has_header = False
                    continue
                row_data = {}

                # parse the row into the proper row names
                # extract the necessary fields from the configuration
                for row_name, row_value, row_type in zip(
                        row_info, row, row_types):
                    if row_type == 'int':
                        row_data[row_name] = int(row_value)
                    elif row_type == 'float':
                        row_data[row_name] = float(row_value)
                    else:
                        row_data[row_name] = row_value

                entry_id = int(row_data[id_field_name])
                text_data = row_data[text_field_name]
                label = (row_data[label_field_name]
                         if label_field_name is not None else None)

                # include computed metrics of data
                sent_scores: dict = NLTK_SIA.polarity_scores(text_data)
                row_data['sent_negative_score'] = sent_scores['neg']
                row_data['sent_neutral_score'] = sent_scores['neu']
                row_data['sent_positive_score'] = sent_scores['pos']
                row_data['sent_compound_score'] = sent_scores['compound']

                self.entire_database[entry_id] = row_data

                self.entry_to_sentences[entry_id] = []
                # per-sentence splits, spans, and sentiment scores; distinct
                # from the document-level polarity dict computed above
                context_split, spans, sentence_sent_scores = self.context_func(
                    text_data)
                for i, (model_entry, sent_span, sent_sent_score) in enumerate(
                        zip(context_split, spans, sentence_sent_scores)):
                    if len(model_entry) == 0:
                        print(
                            f'Error processing id: {(entry_id, i)} with sentence: {model_entry}'
                        )
                        continue
                    self.entry_to_sentences[entry_id].append(sentence_counter)
                    self.database[sentence_counter] = (model_entry, None)
                    self.ground_truth_database[sentence_counter] = (
                        model_entry, self._filter_label(label, 'ADR'))
                    self.sentence_data[sentence_counter] = (
                        # sentence ranges
                        sent_span,
                        # sentence sentiments
                        sent_sent_score,
                    )

                    sentence_counter += 1
                    word_list.extend(model_entry)

        self.vocab = vocab.build_vocab(word_list)
        self.save()
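
prepare_csv is driven entirely by the configuration's data_schema block, which names the CSV columns, their types, and which columns hold the text, the entry id, and the optional ground-truth label. A hypothetical configuration shape consistent with the get_key() lookups above (the actual iSeqL Configuration format and its '/'-separated key syntax are not shown in the snippet):

    # Hypothetical data_schema sketch matching the get_key() calls above.
    EXAMPLE_CONFIG = {
        'data_schema': {
            'rows': ['id', 'text', 'label'],     # column names, in CSV order
            'row_types': ['int', 'str', 'str'],  # per-column parse type
            'text_field': 'text',                # column holding the raw text
            'id_field': 'id',                    # column holding the entry id
            'label_field': 'label',              # optional; may be None
            'includes_header': True,             # skip the first CSV row
        },
    }

With includes_header set, the first CSV row is skipped; every other row is typed per row_types, scored for sentiment, split into sentences by context_func, and indexed into the sentence-level stores before build_vocab is called on the accumulated word list.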