Example #1
    def setup_data_loader(self):
        attr_indexes = common.get_attribute_values(self.args.attr_indexes)

        if self.task == constants.TASK_SEG or self.task == constants.TASK_SEGTAG:
            self.data_loader = segmentation_data_loader.SegmentationDataLoader(
                token_index=self.args.token_index,
                attr_indexes=attr_indexes,
                attr_depths=common.get_attribute_values(self.args.attr_depths, len(attr_indexes)),
                attr_target_labelsets=common.get_attribute_labelsets(
                    self.args.attr_target_labelsets, len(attr_indexes)),
                attr_delim=self.args.attr_delim,
                use_bigram=(self.hparams['bigram_embed_dim'] > 0),
                use_chunk_trie=bool(self.hparams['feature_template']),
                bigram_max_vocab_size=self.hparams['bigram_max_vocab_size'],
                bigram_freq_threshold=self.hparams['bigram_freq_threshold'],
                unigram_vocab=(self.unigram_embed_model.wv if self.unigram_embed_model else set()),
                bigram_vocab=(self.bigram_embed_model.wv if self.bigram_embed_model else set()),
            )
        else:
            self.data_loader = tagging_data_loader.TaggingDataLoader(
                token_index=self.args.token_index,
                attr_indexes=attr_indexes,
                attr_depths=common.get_attribute_values(self.args.attr_depths, len(attr_indexes)),
                attr_chunking_flags=common.get_attribute_boolvalues(
                    self.args.attr_chunking_flags, len(attr_indexes)),
                attr_target_labelsets=common.get_attribute_labelsets(
                    self.args.attr_target_labelsets, len(attr_indexes)),
                attr_delim=self.args.attr_delim,
                lowercasing=self.hparams['lowercasing'],
                normalize_digits=self.hparams['normalize_digits'],
                token_freq_threshold=self.hparams['token_freq_threshold'],
                token_max_vocab_size=self.hparams['token_max_vocab_size'],
                unigram_vocab=(self.unigram_embed_model.wv if self.unigram_embed_model else set()),
            )
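The common helpers these loaders call are not shown in the examples. The sketch below is inferred from the call sites alone; the argument padding, the 't'/'f' boolean encoding, and the '|' labelset delimiter are all assumptions, not the toolkit's actual code:

    # Hedged sketch of the common.* parsing helpers assumed above.
    # Signatures and padding behavior are inferred from the call sites.

    def get_attribute_values(arg_str, num=-1):
        # '1,2' -> [1, 2]; pad with 0 up to num entries when num is given
        values = [int(v) for v in arg_str.split(',')] if arg_str else []
        return (values + [0] * num)[:num] if num >= 0 else values

    def get_attribute_boolvalues(arg_str, num=-1):
        # 't,f' -> [True, False]; pad with False up to num entries
        values = [v.lower() == 't' for v in arg_str.split(',')] if arg_str else []
        return (values + [False] * num)[:num] if num >= 0 else values

    def get_attribute_labelsets(arg_str, num=-1):
        # 'A|B,C' -> [{'A', 'B'}, {'C'}]; '|' as the inner delimiter is a guess
        sets = [set(s.split('|')) for s in arg_str.split(',')] if arg_str else []
        return (sets + [set()] * num)[:num] if num >= 0 else sets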
Example #2
    def show_training_data(self):
        train = self.train
        dev = self.dev
        self.log('### Loaded data')
        self.log('# train: {} ... {}\n'.format(train.inputs[0][0], train.inputs[0][-1]))
        self.log('# train_gold: {} ... {}\n'.format(train.outputs[0][0], train.outputs[0][-1]))
        t2i_tmp = list(self.dic.tables[constants.UNIGRAM].str2id.items())
        self.log('# token2id: {} ... {}\n'.format(t2i_tmp[:10], t2i_tmp[-10:]))
        if self.dic.has_table(constants.BIGRAM):
            b2i_tmp = list(self.dic.tables[constants.BIGRAM].str2id.items())
            self.log('# bigram2id: {} ... {}\n'.format(b2i_tmp[:10], b2i_tmp[-10:]))
        if self.dic.has_trie(constants.CHUNK):
            id2chunk = self.dic.tries[constants.CHUNK].id2chunk
            n_chunks = len(self.dic.tries[constants.CHUNK])
            c2i_head = [(id2chunk[i], i) for i in range(0, min(10, n_chunks))]
            c2i_tail = [(id2chunk[i], i) for i in range(max(0, n_chunks-10), n_chunks)]
            self.log('# chunk2id: {} ... {}\n'.format(c2i_head, c2i_tail))
        if self.dic.has_table(constants.SEG_LABEL):
            id2seg = {v: k for k, v in self.dic.tables[constants.SEG_LABEL].str2id.items()}
            self.log('# label_set: {}\n'.format(id2seg))

        attr_indexes = common.get_attribute_values(self.args.attr_indexes)
        for i in range(len(attr_indexes)):
            if self.dic.has_table(constants.ATTR_LABEL(i)):
                id2attr = {v: k for k, v in self.dic.tables[constants.ATTR_LABEL(i)].str2id.items()}
                self.log('# {}-th attribute labels: {}\n'.format(i, id2attr))
        
        self.report('[INFO] vocab: {}'.format(len(self.dic.tables[constants.UNIGRAM])))
        self.report('[INFO] data length: train={} devel={}'.format(
            len(train.inputs[0]), len(dev.inputs[0]) if dev else 0))
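The head/tail preview used for token2id, bigram2id, and chunk2id is the same pattern each time. A standalone illustration (preview_items is our name, not the toolkit's):

    def preview_items(str2id, n=10):
        # First n and last n (string, id) pairs, for compact logging.
        items = list(str2id.items())
        return items[:n], items[-n:]

    vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    head, tail = preview_items(vocab, n=2)
    print('# token2id: {} ... {}'.format(head, tail))
    # -> # token2id: [('a', 0), ('b', 1)] ... [('c', 2), ('d', 3)]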
Example #3
    def show_training_data(self):
        train = self.train
        dev = self.dev
        self.log('### Loaded data')
        self.log('# train: {} ... {}\n'.format(train.inputs[0][0],
                                               train.inputs[0][-1]))
        self.log('# train_gold_attr: {} ... {}\n'.format(
            train.outputs[0][0], train.outputs[0][-1]))
        t2i_tmp = list(self.dic.tables[constants.UNIGRAM].str2id.items())
        self.log('# token2id: {} ... {}\n'.format(t2i_tmp[:10],
                                                  t2i_tmp[-10:]))

        attr_indexes = common.get_attribute_values(self.args.attr_indexes)
        for i in range(len(attr_indexes)):
            if self.dic.has_table(constants.ATTR_LABEL(i)):
                id2attr = {v: k for k, v in
                           self.dic.tables[constants.ATTR_LABEL(i)].str2id.items()}
                self.log('# {}-th attribute labels: {}\n'.format(i, id2attr))

        self.report('[INFO] vocab: {}'.format(
            len(self.dic.tables[constants.UNIGRAM])))
        self.report('[INFO] data length: train={} devel={}'.format(
            len(train.inputs[0]),
            len(dev.inputs[0]) if dev else 0))
Example #4
    def setup_data_loader(self):
        attr_indexes = common.get_attribute_values(self.args.attr_indexes)
        self.data_loader = attribute_annotation_data_loader.AttributeAnnotationDataLoader(
            token_index=self.args.token_index,
            label_index=self.args.label_index,
            attr_indexes=attr_indexes,
            attr_depths=common.get_attribute_values(self.args.attr_depths, len(attr_indexes)),
            attr_chunking_flags=common.get_attribute_boolvalues(
                self.args.attr_chunking_flags, len(attr_indexes)),
            attr_target_labelsets=common.get_attribute_labelsets(
                self.args.attr_target_labelsets, len(attr_indexes)),
            attr_delim=self.args.attr_delim,
            lowercasing=self.hparams['lowercasing'],
            normalize_digits=self.hparams['normalize_digits'],
        )
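Every setup_data_loader variant reads the same family of self.args attributes. For reference, a hypothetical set of values built with argparse.Namespace (the field names mirror the attributes consumed above; the values themselves are made up):

    from argparse import Namespace

    # Illustrative argument values only; field names mirror the
    # attributes read by the setup_data_loader examples.
    args = Namespace(
        token_index=0,             # column holding the token
        label_index=1,             # column holding the annotation label
        attr_indexes='2,3',        # comma-separated attribute columns
        attr_depths='1,1',
        attr_chunking_flags='t,f',
        attr_target_labelsets='',
        attr_delim='_',
    )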
Example #5
    def setup_data_loader(self):
        attr_indexes = common.get_attribute_values(self.args.attr_indexes)
        self.data_loader = parsing_data_loader.ParsingDataLoader(
            token_index=self.args.token_index,
            head_index=self.args.head_index,
            arc_index=self.args.arc_index,
            attr_indexes=attr_indexes,
            attr_depths=common.get_attribute_values(self.args.attr_depths, len(attr_indexes)),
            attr_chunking_flags=common.get_attribute_boolvalues(
                self.args.attr_chunking_flags, len(attr_indexes)),
            attr_target_labelsets=common.get_attribute_labelsets(
                self.args.attr_target_labelsets, len(attr_indexes)),
            attr_delim=self.args.attr_delim,
            use_arc_label=(self.task == constants.TASK_TDEP),
            lowercasing=self.hparams['lowercasing'],
            normalize_digits=self.hparams['normalize_digits'],
            token_freq_threshold=self.hparams['token_freq_threshold'],
            token_max_vocab_size=self.hparams['token_max_vocab_size'],
            unigram_vocab=(self.unigram_embed_model.wv if self.unigram_embed_model else set()),
        )
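Examples #1 and #5 fall back to an empty set when no pretrained embedding model is available. Assuming unigram_embed_model is a gensim Word2Vec model (the .wv attribute suggests as much, though that is an inference), the pattern in isolation:

    import os
    from gensim.models import Word2Vec

    path = 'unigram_embed.model'  # illustrative path
    model = Word2Vec.load(path) if os.path.exists(path) else None

    # Same fallback as the call sites above: a KeyedVectors object when a
    # model is loaded, otherwise an empty set. Both support membership
    # tests, which is presumably what the loaders rely on.
    unigram_vocab = model.wv if model else set()
    print('the' in unigram_vocab)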