def __init__(
    self,
    question_path,
    paragraph_path,
    ratio,
    batch_size,
    vocab: Vocab = Ref("model.vocab"),
    batch_first=Ref("model.batch_first", True),
):
    """Load question/paragraph examples from two line-oriented text files.

    Every line of ``question_path`` is one question. For each question,
    the next ``ratio`` lines of ``paragraph_path`` are read as its
    paragraphs, so the two files are consumed in lockstep. Each example
    is created with a target value of ``0``.

    Args:
        question_path: Path of the file holding one question per line.
        paragraph_path: Path of the file holding ``ratio`` paragraph
            lines per question.
        ratio: Number of paragraph lines read for each question.
        batch_size: Forwarded to ``BaseIRDataset.__init__``.
        vocab: Vocabulary attached to the question and paragraph fields;
            its ``pad_token`` is used as the fields' padding token.
        batch_first: Forwarded to the text fields and to
            ``BaseIRDataset.__init__``.
    """
    self.vocab = vocab
    padding = vocab.pad_token

    # Question text: a single sequence, batched together with its lengths.
    question = Field(
        pad_token=padding,
        batch_first=batch_first,
        include_lengths=True,
    )
    question.vocab = vocab

    # One paragraph is a plain sequence; the nested field groups the
    # `ratio` paragraphs that belong to a single question.
    paragraph = Field(pad_token=padding, batch_first=batch_first)
    paragraph.vocab = vocab
    paragraphs = NestedField(paragraph, include_lengths=True)
    paragraphs.vocab = vocab

    # The target is a raw value and is not looked up in any vocabulary.
    target = Field(is_target=True, use_vocab=False, sequential=False)

    fields = [
        ("question", question),
        ("paragraphs", paragraphs),
        ("target", target),
    ]

    examples = []
    with open(paragraph_path) as paragraph_file, open(question_path) as question_file:
        for raw_question in question_file:
            question_text = raw_question.strip()
            # readline() yields '' once the paragraph file is exhausted,
            # so a short file produces empty paragraphs instead of raising.
            paragraph_group = []
            for _ in range(ratio):
                paragraph_group.append(paragraph_file.readline().strip())
            examples.append(
                Example.fromlist([question_text, paragraph_group, 0], fields)
            )

    BaseIRDataset.__init__(self, ratio, batch_size, batch_first)
    TorchTextDataset.__init__(self, examples, fields)
def make_fields(vocab_count, binary=True):
    """Create the fields for text, characters, positions and labels.

    Args:
        vocab_count: Mapping indexed with the keys ``'text'``, ``'chars'``,
            ``'pos1_rel'``, ``'pos2_rel'``, ``'relation'`` and
            ``'rel_type'``. Each value is passed to ``Vocab`` as its first
            argument (presumably a token counter -- confirm with caller).
        binary: Accepted for interface compatibility.
            NOTE(review): this flag does not currently change how the
            label field is configured -- confirm whether a distinct
            binary setup was intended.

    Returns:
        dict: Maps each column name to a ``(name, field)`` pair. The
        ``'text'`` entry is a list of two pairs, because the same column
        feeds both the word-level and the character-level field.
    """

    def split_on_space(sentence):
        # Split on single spaces only; consecutive spaces yield '' tokens.
        return sentence.split(' ')

    # Word-level view of the sentence.
    text_field = Field(
        tokenize=split_on_space,
        include_lengths=True,
        batch_first=True,
    )
    text_field.vocab = Vocab(vocab_count['text'])

    # Character-level view: the outer field splits into words, the inner
    # field splits every word into its characters.
    char_nesting_field = Field(tokenize=list, batch_first=True)
    char_field = NestedField(char_nesting_field, tokenize=split_on_space)
    char_nesting_field.vocab = Vocab(vocab_count['chars'])
    char_field.vocab = Vocab(vocab_count['chars'])

    # Raw position values, used as-is without a vocabulary lookup.
    pos1_field = Field(use_vocab=False, sequential=False, batch_first=True)
    pos2_field = Field(use_vocab=False, sequential=False, batch_first=True)

    # Position sequences that are mapped through their own vocabularies.
    pos1_rel_field = Field(batch_first=True, sequential=True)
    pos1_rel_field.vocab = Vocab(vocab_count['pos1_rel'])
    pos2_rel_field = Field(batch_first=True, sequential=True)
    pos2_rel_field.vocab = Vocab(vocab_count['pos2_rel'])

    # Relation label; its vocabulary is built with no special tokens.
    label_field = Field(batch_first=True, sequential=False)
    label_field.vocab = Vocab(vocab_count['relation'], specials=[])

    reltype_field = Field(sequential=False, batch_first=True)
    reltype_field.vocab = Vocab(vocab_count['rel_type'])

    return {
        'text': [('text', text_field), ('chars', char_field)],
        'pos1': ('pos1', pos1_field),
        'pos2': ('pos2', pos2_field),
        'pos1_rel': ('pos1_rel', pos1_rel_field),
        'pos2_rel': ('pos2_rel', pos2_rel_field),
        'relation': ('relation', label_field),
        'rel_type': ('rel_type', reltype_field),
    }