コード例 #1
0
ファイル: datasets.py プロジェクト: felix-schneider/xnmtorch
    def __init__(
            self,
            question_path,
            paragraph_path,
            ratio,
            batch_size,
            vocab: Vocab = Ref("model.vocab"),
            batch_first=Ref("model.batch_first", True),
    ):
        """Read questions and their candidate paragraphs into examples.

        Every line of ``question_path`` becomes one example. For each
        question, the next ``ratio`` lines are consumed from
        ``paragraph_path`` and attached as that question's paragraphs.
        All examples receive the constant target ``0``.

        Args:
            question_path: Path of the text file holding one question per
                line.
            paragraph_path: Path of the text file holding the paragraphs,
                ``ratio`` consecutive lines per question.
            ratio: Number of paragraph lines read for each question.
            batch_size: Forwarded unchanged to ``BaseIRDataset.__init__``.
            vocab: Vocabulary assigned to every text field; its
                ``pad_token`` is used as the padding token. The default is
                a ``Ref`` placeholder -- presumably resolved by the
                surrounding configuration framework; confirm with caller.
            batch_first: Passed to the text fields and to
                ``BaseIRDataset.__init__``.
        """
        self.vocab = vocab
        pad = vocab.pad_token

        # Question side: a single padded sequence, with lengths returned.
        question_field = Field(
            include_lengths=True, batch_first=batch_first, pad_token=pad)
        question_field.vocab = vocab

        # Paragraph side: a nested field wrapping one field per paragraph.
        single_paragraph_field = Field(batch_first=batch_first, pad_token=pad)
        single_paragraph_field.vocab = vocab
        paragraphs_field = NestedField(
            single_paragraph_field, include_lengths=True)
        paragraphs_field.vocab = vocab

        # The target is a raw value, not looked up in any vocabulary.
        target_field = Field(sequential=False, use_vocab=False, is_target=True)

        fields = [
            ("question", question_field),
            ("paragraphs", paragraphs_field),
            ("target", target_field),
        ]

        with open(paragraph_path) as paragraph_file:
            with open(question_path) as question_file:
                # The paragraph file is advanced by `ratio` lines for each
                # question; readline() yields "" once that file is
                # exhausted, so missing paragraphs become empty strings.
                examples = [
                    Example.fromlist(
                        [
                            question_line.strip(),
                            [paragraph_file.readline().strip()
                             for _ in range(ratio)],
                            0,
                        ],
                        fields,
                    )
                    for question_line in question_file
                ]

        BaseIRDataset.__init__(self, ratio, batch_size, batch_first)
        TorchTextDataset.__init__(self, examples, fields)
コード例 #2
0
def make_fields(vocab_count, binary=True):
    """Build the torchtext fields and return them keyed by column name.

    Args:
        vocab_count: Mapping with the keys ``'text'``, ``'chars'``,
            ``'pos1_rel'``, ``'pos2_rel'``, ``'relation'`` and
            ``'rel_type'``. Each value is handed to ``Vocab(...)``
            (presumably a token ``Counter`` -- confirm against caller).
        binary: Currently has no effect. The label field is built the same
            way for either value; the parameter is kept so existing
            callers keep working.

    Returns:
        A dict mapping each column key to a ``(name, field)`` tuple. The
        ``'text'`` key maps to a list of two such tuples, so the same
        column feeds both the word-level ``text`` field and the
        character-level ``chars`` field.
    """
    # Word-level text: split on single spaces, lengths returned with batches.
    text_field = Field(batch_first=True,
                       include_lengths=True,
                       tokenize=lambda x: x.split(' '))
    text_field.vocab = Vocab(vocab_count['text'])

    # Character-level text: the outer field splits on spaces, the inner
    # field splits each resulting token into characters via `list`.
    char_nesting_field = Field(batch_first=True, tokenize=list)
    char_field = NestedField(char_nesting_field,
                             tokenize=lambda x: x.split(' '))
    char_nesting_field.vocab = Vocab(vocab_count['chars'])
    char_field.vocab = Vocab(vocab_count['chars'])

    # Position columns are used as raw values, without a vocabulary.
    pos1_field = Field(batch_first=True, sequential=False, use_vocab=False)
    pos2_field = Field(batch_first=True, sequential=False, use_vocab=False)

    # Relative-position columns are sequences with their own vocabularies.
    pos1_rel_field = Field(sequential=True, batch_first=True)
    pos1_rel_field.vocab = Vocab(vocab_count['pos1_rel'])
    pos2_rel_field = Field(sequential=True, batch_first=True)
    pos2_rel_field.vocab = Vocab(vocab_count['pos2_rel'])

    # The original code branched on `binary` here, but both branches were
    # identical, so a single construction is equivalent.
    label_field = Field(sequential=False, batch_first=True)
    # The label vocabulary is built with an empty `specials` list.
    label_field.vocab = Vocab(vocab_count['relation'], specials=[])

    reltype_field = Field(batch_first=True, sequential=False)
    reltype_field.vocab = Vocab(vocab_count['rel_type'])

    return {
        'text': [('text', text_field), ('chars', char_field)],
        'pos1': ('pos1', pos1_field),
        'pos2': ('pos2', pos2_field),
        'pos1_rel': ('pos1_rel', pos1_rel_field),
        'pos2_rel': ('pos2_rel', pos2_rel_field),
        'relation': ('relation', label_field),
        'rel_type': ('rel_type', reltype_field),
    }