Example #1
def build_field_vocab(cls,
                      field: Field,
                      counter: Counter,
                      size_multiple: int = 1,
                      **kwargs):
    # PN: original name was _build_field_vocab
    # this is basically copy-pasted from torchtext.
    # specials: the usual unk/pad/bos/eos tokens plus the single-token
    # digit strings "0" through "127"
    all_specials = [
        field.unk_token, field.pad_token, field.init_token, field.eos_token
    ] + [str(i) for i in range(128)]
    specials = [tok for tok in all_specials if tok is not None]
    field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)
    if size_multiple > 1:
        cls.pad_vocab_to_multiple(field.vocab, size_multiple)
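
For context, a minimal stand-alone sketch of the call this method wraps, assuming the legacy torchtext API (torchtext <= 0.8, or torchtext.legacy.data on newer releases); the toy corpus and the min_freq keyword are illustrative, not taken from the original snippet:

from collections import Counter
from torchtext.data import Field  # torchtext.legacy.data on newer versions

field = Field(unk_token='<unk>', pad_token='<pad>',
              init_token='<s>', eos_token='</s>')
counter = Counter("the cat sat on the mat".split())
specials = [tok for tok in (field.unk_token, field.pad_token,
                            field.init_token, field.eos_token)
            if tok is not None]
# field.vocab_cls is torchtext's Vocab; extra keyword arguments such as
# min_freq are forwarded to it, just like **kwargs in the method above
field.vocab = field.vocab_cls(counter, specials=specials, min_freq=1)
print(len(field.vocab))  # 4 specials + 5 distinct corpus tokens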
Example #2
def load_data_dict(experiment_name,
                   langs,
                   corpora_type,
                   args,
                   device,
                   src_field=None,
                   trg_field=None):
    if src_field is None or trg_field is None:
        src_field = Field(tokenize=str.split,
                          unk_token=UNK_WORD,
                          pad_token=PAD_WORD,
                          init_token=BOS_WORD,
                          eos_token=EOS_WORD)
        trg_field = Field(tokenize=str.split,
                          unk_token=UNK_WORD,
                          pad_token=PAD_WORD,
                          init_token=BOS_WORD,
                          eos_token=EOS_WORD)
        print('Loading src vocab')
        src_vocab = load_vocab(get_vocab_path(experiment_name, langs[0]))
        src_field.vocab = src_field.vocab_cls(
            src_vocab, specials=[UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD])
        print('Loading trg vocab')
        trg_vocab = load_vocab(get_vocab_path(experiment_name, langs[1]))
        trg_field.vocab = trg_field.vocab_cls(
            trg_vocab, specials=[UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD])
        args.src_pad_idx = src_field.vocab.stoi[PAD_WORD]
        args.trg_pad_idx = trg_field.vocab.stoi[PAD_WORD]
        args.trg_bos_idx = trg_field.vocab.stoi[BOS_WORD]
        args.trg_eos_idx = trg_field.vocab.stoi[EOS_WORD]
        args.src_vocab_size = len(src_field.vocab)
        args.trg_vocab_size = len(trg_field.vocab)

    # build the fields tuple outside the conditional so it also exists
    # when pre-built src_field/trg_field are passed in
    fields = (src_field, trg_field)
    print('Loading data')
    data, total_tokens = load_data(experiment_name=experiment_name,
                                   langs=langs,
                                   fields=fields,
                                   batch_size=args.batch_size,
                                   device=device,
                                   corpora_type=corpora_type,
                                   reduce_size=args.data_reduce_size)
    return data, total_tokens, src_field, trg_field
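
A hedged example of how this loader might be invoked; the experiment name, language pair, corpora_type values, and the attributes prefilled on args are assumptions about the surrounding project. The first call builds the fields and vocabularies and records the pad/bos/eos indices and vocab sizes on args as a side effect, so the returned fields can be reused for other splits:

from argparse import Namespace
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = Namespace(batch_size=64, data_reduce_size=None)

train_data, train_tokens, src_field, trg_field = load_data_dict(
    'baseline_run', ('de', 'en'), 'train', args, device)
valid_data, valid_tokens, _, _ = load_data_dict(
    'baseline_run', ('de', 'en'), 'valid', args, device,
    src_field=src_field, trg_field=trg_field)
print(args.src_vocab_size, args.trg_vocab_size, args.trg_pad_idx)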
Example #3
def build_field_vocab(cls,
                      field: Field,
                      counter: Counter,
                      size_multiple: int = 1,
                      **kwargs) -> None:
    # PN: original name was _build_field_vocab
    # this is basically copy-pasted from torchtext.
    all_specials = [
        field.unk_token, field.pad_token, field.init_token, field.eos_token
    ]
    specials = [tok for tok in all_specials if tok is not None]
    field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)
    if size_multiple > 1:
        cls.pad_vocab_to_multiple(field.vocab, size_multiple)
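
Because **kwargs is forwarded to field.vocab_cls (torchtext's Vocab in the legacy API), frequency and size limits can be applied when the vocab is built; a small sketch with made-up counts:

from collections import Counter
from torchtext.data import Field

field = Field()  # defaults: unk_token='<unk>', pad_token='<pad>'
counter = Counter({'the': 10, 'cat': 3, 'sat': 1})
specials = [tok for tok in (field.unk_token, field.pad_token,
                            field.init_token, field.eos_token)
            if tok is not None]
# min_freq=2 maps rare tokens to <unk>; max_size caps how many corpus
# tokens are kept on top of the specials
field.vocab = field.vocab_cls(counter, specials=specials,
                              max_size=10, min_freq=2)
print(field.vocab.itos)  # ['<unk>', '<pad>', 'the', 'cat']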
Example #4
    def __init__(self,
                 module_name,
                 train_bs,
                 eval_bs,
                 device,
                 vocab=None,
                 base_folder=None,
                 train_name=None,
                 eval_name=None,
                 x_ext=None,
                 y_ext=None,
                 tokens=None,
                 specials=None,
                 tokenizer=None,
                 sort_within_batch=None,
                 shuffle=None):

        self.module_name = module_name

        # split_chars = lambda x: list("".join(x.split()))
        split_chars = lambda x: list(x)  # keeps whitespaces

        if not tokenizer:
            tokenizer = split_chars

        # NOTE: on Jul-20-2020, removed fix_length=200 since it forces
        # all batches to be of size (batch_size, 200) which
        # really wastes GPU memory
        source = Field(tokenize=tokenizer,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        target = Field(tokenize=tokenizer,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        base_folder = os.path.expanduser(base_folder)

        folder = os.path.join(base_folder, module_name)

        # fix slashes
        folder = os.path.abspath(folder)

        print("loading FULL datasets from folder={}".format(folder))

        train_dataset, eval_dataset, _ = TranslationDataset.splits(
            path=folder,
            root=folder,
            exts=(x_ext, y_ext),
            fields=(source, target),
            train=train_name,
            validation=eval_name,
            test=eval_name)

        if vocab:
            print("Setting vocab to prebuilt file...")
            source.vocab = vocab
            target.vocab = vocab
        elif tokens:
            print("Building vocab from tokens...")
            #source.build_vocab(tokens, specials)
            counter = Counter(tokens)
            source.vocab = source.vocab_cls(counter, specials=specials)
            target.vocab = source.vocab
        else:
            print("Building vocab from TRAIN and EVAL datasets...")
            source.build_vocab(train_dataset, eval_dataset)
            target.vocab = source.vocab

        print("Creating iterators ...")
        do_shuffle = True if shuffle is None else shuffle
        train_iterator = Iterator(dataset=train_dataset,
                                  batch_size=train_bs,
                                  train=True,
                                  repeat=True,
                                  shuffle=do_shuffle,
                                  sort_within_batch=sort_within_batch,
                                  device=device)

        eval_iterator = Iterator(dataset=eval_dataset,
                                 batch_size=eval_bs,
                                 train=False,
                                 repeat=False,
                                 shuffle=False,
                                 sort_within_batch=sort_within_batch,
                                 device=device)

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        self.train_iterator = train_iterator
        self.eval_iterator = eval_iterator

        self.source = source
        self.target = target
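
A hypothetical construction of this module (the owning class name CharSeqData, the folder layout, and the file extensions are assumptions; TranslationDataset expects parallel files such as train.x / train.y under base_folder/module_name):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = CharSeqData(module_name='copy_task',
                   train_bs=32,
                   eval_bs=64,
                   device=device,
                   base_folder='~/datasets',
                   train_name='train',
                   eval_name='valid',
                   x_ext='.x',
                   y_ext='.y')

# TranslationDataset names its fields 'src' and 'trg'; batch_first=True
# gives (batch, seq_len) tensors
batch = next(iter(data.train_iterator))
print(batch.src.shape, batch.trg.shape)
print(len(data.source.vocab), len(data.target.vocab))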