Пример #1
0
    def load_dataset(unigram_field, bigram_field, label_field, batch_size, data_dir,
                     dev_batch_size=32):
        """Build the train/dev tagging datasets, their vocabularies and iterators.

        ``train.tsv`` and ``dev.tsv`` under ``data_dir`` are read with their
        columns mapped, in order, to ``unigram``, ``label``, ``fwd_bigram`` and
        ``back_bigram``; both bigram columns share ``bigram_field``.

        :param unigram_field: field for the unigram column (vocab min_freq=1)
        :param bigram_field: field shared by both bigram columns (vocab min_freq=5)
        :param label_field: field for the tag column
        :param batch_size: batch size of the training iterator
        :param data_dir: directory containing ``train.tsv`` and ``dev.tsv``
        :param dev_batch_size: batch size of the dev iterator (defaults to the
            previously hard-coded 32)
        :return: ``(train_iter, dev_iter)``

        Both iterators use the module-level ``device``.
        """
        fields = [('unigram', unigram_field), ('label', label_field),
                  ('fwd_bigram', bigram_field), ('back_bigram', bigram_field)]
        train = datasets.SequenceTaggingDataset(path=os.path.join(data_dir, 'train.tsv'),
                                                fields=fields)
        dev = datasets.SequenceTaggingDataset(path=os.path.join(data_dir, 'dev.tsv'),
                                              fields=fields)

        # Vocabularies are built over train and dev together.
        unigram_field.build_vocab(train, dev, min_freq=1)
        bigram_field.build_vocab(train, dev, min_freq=5)
        label_field.build_vocab(train, dev)

        train_iter = data.BucketIterator(train,
                                         # `train` is a boolean flag; passing the
                                         # Dataset object only worked while it was
                                         # non-empty (truthy).
                                         train=True,
                                         batch_size=batch_size,
                                         sort_key=lambda x: len(x.unigram),
                                         device=device,
                                         sort_within_batch=True,
                                         repeat=False)

        dev_iter = data.BucketIterator(dev,
                                       batch_size=dev_batch_size,
                                       device=device,
                                       sort=False,
                                       shuffle=False,
                                       repeat=False)

        return train_iter, dev_iter
    def __init__(self, model_path: str, train_path: str, wordemb_path: str,
                 charemb_path: str, hidden_size: int):
        """Rebuild the vocabularies from the training data and load a trained BLSTM-CRF.

        :param model_path: path to the trained model weights (.pth)
        :param train_path: tab-separated file the model was trained on; its
            columns are read as char, word and label
        :param wordemb_path: word embedding file that was used for training
        :param charemb_path: character embedding file that was used for training
        :param hidden_size: hidden layer size the model was trained with
        """
        self.mecab = MeCab.Tagger('-Owakati')

        # One batch-first field per column (created in WORD, CHAR, LABEL order).
        self.WORD, self.CHAR, self.LABEL = (data.Field(batch_first=True)
                                            for _ in range(3))
        self.fields = [('char', self.CHAR), ('word', self.WORD),
                       ('label', self.LABEL)]
        self.dataset = datasets.SequenceTaggingDataset(
            path=train_path, fields=self.fields, separator='\t')

        # Char vocab first, then word vocab, each with its pretrained vectors.
        for field, emb_path in ((self.CHAR, charemb_path),
                                (self.WORD, wordemb_path)):
            field.build_vocab(self.dataset, vectors=Vectors(emb_path))
        self.LABEL.build_vocab(self.dataset)

        num_labels = len(self.LABEL.vocab.itos)
        word_dim = self.WORD.vocab.vectors.size()[1]
        char_dim = self.CHAR.vocab.vectors.size()[1]
        # NOTE(review): the 0.0 positional argument is presumably dropout — confirm
        # against BLSTMCRF's signature.
        self.model = BLSTMCRF(num_labels, hidden_size, 0.0, word_dim, char_dim)
        self.model.load(model_path)
Пример #3
0
    def __init__(self, text_path: str, wordemb_path: str, charemb_path: str,
                 device: str):
        """Load a tab-separated tagging corpus and build its vocabularies.

        Expected dataset format: one character per line, with columns
        character, the word containing it, and the tag. Example for the
        sentence 私は白い恋人を食べました:

        私  私  O
        は  は  O
        白  白い    B-PRO
        い  白い    I-PRO
        恋  恋人    I-PRO
        人  恋人    I-PRO
        を  を  O
        食  食べ    O
        べ  食べ    O
        ま  まし    O
        し  まし    O
        た  た  O

        :param text_path: path to the tab-separated corpus
        :param wordemb_path: embedding file attached to the word vocabulary
        :param charemb_path: embedding file attached to the char vocabulary
        :param device: device identifier stored on the instance as-is
        """
        # Fields are created in WORD, CHAR, LABEL order.
        self.WORD, self.CHAR, self.LABEL = (data.Field(batch_first=True)
                                            for _ in range(3))
        # Column order in the file: char, word, label.
        self.fields = list(zip(('char', 'word', 'label'),
                               (self.CHAR, self.WORD, self.LABEL)))
        self.dataset = datasets.SequenceTaggingDataset(
            path=text_path, fields=self.fields, separator='\t')

        for field, emb_path in ((self.CHAR, charemb_path),
                                (self.WORD, wordemb_path)):
            field.build_vocab(self.dataset, vectors=Vectors(emb_path))
        self.LABEL.build_vocab(self.dataset)

        self.device = device
Пример #4
0
def create_dataset(data_path):
    """Read a four-column tagging file and collect its label list.

    Columns are read, in order, as label, start, end and text.

    :param data_path: path to the tagging file
    :return: ``(dataset, label_list)``; ``label_list`` holds the keys of
        ``LABEL.vocab.freqs`` followed by the extra label ``"X"``
    """
    # Fields are created in TEXT, START, END, LABEL order.
    TEXT, START, END, LABEL = (data.Field() for _ in range(4))
    columns = [('label', LABEL), ('start', START), ('end', END), ('text', TEXT)]
    dataset = datasets.SequenceTaggingDataset(path=data_path, fields=columns)

    # Only the label vocabulary is needed here.
    LABEL.build_vocab(dataset)
    return dataset, [*LABEL.vocab.freqs, "X"]
Пример #5
0
    def load_testset(unigram_field, bigram_field, label_field, test_path, batch_size=32):
        """Read a test TSV and wrap it in a non-training iterator.

        The columns are mapped, in order, to ``unigram``, ``label``,
        ``fwd_bigram`` and ``back_bigram`` (both bigram columns share
        ``bigram_field``). No vocabulary is built here, so the fields'
        vocabularies are expected to exist already.

        :param unigram_field: field for the unigram column
        :param bigram_field: field shared by both bigram columns
        :param label_field: field for the tag column
        :param test_path: path to the test TSV file
        :param batch_size: iterator batch size (defaults to the previously
            hard-coded 32)
        :return: the test ``BucketIterator`` (uses the module-level ``device``)
        """
        fields = [('unigram', unigram_field), ('label', label_field),
                  ('fwd_bigram', bigram_field), ('back_bigram', bigram_field)]
        test = datasets.SequenceTaggingDataset(path=test_path, fields=fields)

        # Evaluation iterator: not in training mode, no shuffling, no sorting.
        test_iter = data.BucketIterator(test, batch_size=batch_size, train=False,
                                        shuffle=False, sort=False, device=device)
        return test_iter
Пример #6
0
    # Merge `config` (a mapping) and `opt` into one Namespace; a key present in
    # both raises TypeError (duplicate keyword argument).
    config = Namespace(**config, **vars(opt))
    logger = init_logger("torch", logging_path='')
    logger.info(config.__dict__)

    # NOTE(review): devices_id is not used in the visible code.
    device, devices_id = misc_utils.set_cuda(config)
    config.device = device

    # Both columns are numericalized by `to_int` (use_vocab=False), and batches
    # carry lengths (include_lengths=True).
    # NOTE(review): with use_vocab=False, utils.UNK / utils.PAD must already be
    # numeric for padding to tensorize — confirm.
    TEXT = data.Field(sequential=True, use_vocab=False, batch_first=True, unk_token=utils.UNK,
                      include_lengths=True, pad_token=utils.PAD, preprocessing=to_int, )
    # init_token=utils.BOS, eos_token=utils.EOS)
    LABEL = data.Field(sequential=True, use_vocab=False, batch_first=True, unk_token=utils.UNK,
                       include_lengths=True, pad_token=utils.PAD, preprocessing=to_int, )
    # init_token=utils.BOS, eos_token=utils.EOS)

    fields = [("text", TEXT), ("label", LABEL)]
    validDataset = datasets.SequenceTaggingDataset(path=os.path.join(config.data, 'valid.txt'),
                                                   fields=fields)
    # NOTE(review): no device= is passed, although config.device is set above;
    # batches may be created on the default device — confirm this is intended.
    valid_iter = data.Iterator(validDataset,
                               batch_size=config.batch_size,
                               sort_key=lambda x: len(x.text),  # sort examples by text length
                               sort=True,
                               sort_within_batch=True,
                               repeat=False
                               )

    # Vocabulary sizes for the model come from these files, not from the fields.
    src_vocab = utils.Dict()
    src_vocab.loadFile(os.path.join(config.data, "src.vocab"))
    tgt_vocab = utils.Dict()
    tgt_vocab.loadFile(os.path.join(config.data, "tgt.vocab"))

    if config.model == 'bilstm_crf':
        model = BiLSTM_CRF(src_vocab.size(), tgt_vocab.size(), config)
Пример #7
0
# The vocabulary is built from the SST training text (train_sst.text), capped
# with max_size=10000, not from the WMT16 data itself.
# NOTE(review): presumably so this dataset shares the SST vocabulary — confirm.
TEXT_wmt16.build_vocab(train_sst.text, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_wmt16.vocab))

train_iter_wmt16 = data.BucketIterator(wmt16_data,
                                       batch_size=args.batch_size,
                                       repeat=False)
# ============================ WMT16 ============================ #

# ============================ English Web Treebank (Answers) ============================ #

# Text field for the Answers split: pad at the start and lowercase tokens.
TEXT_answers = data.Field(pad_first=True, lower=True)

treebank_path = './.data/eng_web_tbk/answers/conll/answers_penntrees.dev.conll'

# The (None, None) entry discards the first column; the second column is
# loaded as 'text'.
train_answers = datasets.SequenceTaggingDataset(
    path=treebank_path,
    fields=((None, None), ('text', TEXT_answers)),
)

# As for WMT16, the vocabulary comes from the SST training text rather than
# from this treebank (presumably to share the SST vocabulary — confirm).
TEXT_answers.build_vocab(train_sst.text, max_size=10000)
print('vocab length (including special tokens):', len(TEXT_answers.vocab))

# make iterators
train_iter_answers = data.BucketIterator(train_answers,
                                         batch_size=args.batch_size,
                                         repeat=False)

# ============================ English Web Treebank (Answers) ============================ #

# ============================ English Web Treebank (Email) ============================ #

# Text field for the Email split, configured like TEXT_answers (pad at the
# start, lowercase tokens).
TEXT_email = data.Field(pad_first=True, lower=True)