Example No. 1
def load_dataset(batch_size):
    '''
    Load the datasets.
    '''

    Lang1 = Field(include_lengths=True, init_token='<sos>', eos_token='<eos>')
    Lang2 = Field(include_lengths=True, init_token='<sos>', eos_token='<eos>')

    train = TranslationDataset(path='data/40w/train',
                               exts=('.ch', '.en'),
                               fields=(Lang1, Lang2))
    val = TranslationDataset(path='data/40w/valid',
                             exts=('.ch', '.en'),
                             fields=(Lang1, Lang2))
    test = TranslationDataset(path='data/40w/test',
                              exts=('.ch', '.en'),
                              fields=(Lang1, Lang2))

    Lang1.build_vocab(train.src, max_size=30000)
    Lang2.build_vocab(train.trg, max_size=30000)

    train_iter, val_iter, test_iter = BucketIterator.splits(
        (train, val, test), batch_size=batch_size, repeat=False)

    return train_iter, val_iter, test_iter, Lang1, Lang2
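A hedged usage sketch for load_dataset above (the batch size and the loop are illustrative, not part of the original snippet). Because both Fields were created with include_lengths=True, each batch attribute is a (data, lengths) pair.

# Hypothetical usage of load_dataset above.
train_iter, val_iter, test_iter, Lang1, Lang2 = load_dataset(batch_size=64)

for batch in train_iter:
    src, src_len = batch.src      # src: (src_seq_len, batch) LongTensor of token ids
    trg, trg_len = batch.trg      # include_lengths=True yields (data, lengths) pairs
    print(src.shape, src_len[:3])
    break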
Example No. 2
def build_scan(split, batch_size, device):
    # Get paths and filenames of each partition of split
    if split == 'simple':
        path = 'data/scan/simple/'
    elif split == 'addjump':
        path = 'data/scan/addjump/'
    else:
        raise ValueError(f"Unknown split: {split}")
    train_path = os.path.join(path,'train')
    dev_path = os.path.join(path,'dev')
    test_path = os.path.join(path,'test')
    exts = ('.src','.trg')

    # Fields for source (SRC) and target (TRG) sequences
    SRC = Field(init_token='<sos>',eos_token='<eos>')
    TRG = Field(init_token='<sos>',eos_token='<eos>')
    fields = (SRC,TRG)

    # Build datasets
    train_ = TranslationDataset(train_path,exts,fields)
    dev_ = TranslationDataset(dev_path,exts,fields)
    test_ = TranslationDataset(test_path,exts,fields)

    # Build vocabs: fields ensure same vocab used for each partition
    SRC.build_vocab(train_)
    TRG.build_vocab(train_)

    # BucketIterator ensures similar sequence lengths to minimize padding
    train, dev, test = BucketIterator.splits((train_, dev_, test_),
        batch_size = batch_size, device = device)

    return SRC, TRG, train, dev, test
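A hedged usage sketch for build_scan (the split name, batch size, and device are illustrative): the Fields keep the default batch_first=False, so batch tensors come out shaped (seq_len, batch_size).

# Hypothetical usage of build_scan above.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SRC, TRG, train_iter, dev_iter, test_iter = build_scan('addjump', 32, device)

batch = next(iter(train_iter))
print(batch.src.shape)                 # (src_seq_len, 32)
print(len(SRC.vocab), len(TRG.vocab))  # vocabs built from the training split only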
Example No. 3
def prepare_dataloaders_from_bpe_files(opt, device):
    batch_size = opt.batch_size
    MIN_FREQ = 2
    if not opt.embs_share_weight:
        raise ValueError('This BPE-file pipeline assumes a shared source/target vocabulary (-embs_share_weight).')

    data = pickle.load(open(opt.data_pkl, 'rb'))
    MAX_LEN = data['settings'].max_len
    field = data['vocab']
    fields = (field, field)

    def filter_examples_with_length(x):
        return len(vars(x)['src']) <= MAX_LEN and len(
            vars(x)['trg']) <= MAX_LEN

    train = TranslationDataset(fields=fields,
                               path=opt.train_path,
                               exts=('.src', '.trg'),
                               filter_pred=filter_examples_with_length)
    val = TranslationDataset(fields=fields,
                             path=opt.val_path,
                             exts=('.src', '.trg'),
                             filter_pred=filter_examples_with_length)

    opt.max_token_seq_len = MAX_LEN + 2
    opt.src_pad_idx = opt.trg_pad_idx = field.vocab.stoi[Constants.PAD_WORD]
    opt.src_vocab_size = opt.trg_vocab_size = len(field.vocab)

    train_iterator = BucketIterator(train,
                                    batch_size=batch_size,
                                    device=device,
                                    train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)
    return train_iterator, val_iterator
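For orientation, opt.data_pkl is expected to hold the dictionary pickled by the BPE preprocessing script shown further down this page (the main() that dumps {'settings': opt, 'vocab': field}). A minimal sketch of that contract, with a hypothetical file name:

# Hypothetical inspection of the pickle consumed by prepare_dataloaders_from_bpe_files.
import pickle

data = pickle.load(open('bpe_data.pkl', 'rb'))     # placeholder file name
field = data['vocab']                              # one shared torchtext Field (BPE vocab)
print(data['settings'].max_len)                    # argparse Namespace carrying max_len
print(len(field.vocab), field.vocab.stoi[field.pad_token])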
Example No. 4
    def make_dataset(self, exts):
        return TranslationDataset(
            path=self.paths['path'],
            exts=exts,
            fields=[('src', self.field), ('tgt', self.field)],
            filter_pred=lambda x: len(x.src) <= self.max_seq_len and len(
                x.tgt) <= self.max_seq_len)
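Note that this example passes fields as named (name, Field) pairs, so the resulting examples expose x.src and x.tgt; most other snippets on this page pass a bare (src_field, trg_field) tuple, in which case TranslationDataset falls back to the default names src and trg. A small hedged illustration (builder stands for a hypothetical instance of the surrounding class):

# Hypothetical: attribute names follow the (name, Field) pairs given above.
ds = builder.make_dataset(exts=('.src', '.tgt'))
ex = ds.examples[0]
print(len(ex.src), len(ex.tgt))   # .tgt here, not the default .trg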
Example No. 5
def load_dataset(batch_size):
    spacy_de = spacy.load('de')
    spacy_en = spacy.load('en')
    url = re.compile('(<url>.*</url>)')

    def tokenize_de(text):
        return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]

    def tokenize_en(text):
        return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]

    DE = Field(tokenize=tokenize_de, include_lengths=True,
               init_token='<sos>', eos_token='<eos>')
    EN = Field(tokenize=tokenize_en, include_lengths=True,
               init_token='<sos>', eos_token='<eos>')
    #train, val, test = Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))
    
    train, val, test = TranslationDataset.splits(
        path='.data/multi30k',
        exts=('.de', '.en'),
        fields=[('src', DE), ('trg', EN)],
        train='train',
        validation='val',
        test='test2016')
    DE.build_vocab(train.src, min_freq=2)
    EN.build_vocab(train.trg, max_size=10000)
    train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=batch_size, repeat=False)
    return train_iter, val_iter, test_iter, DE, EN
Example No. 6
def load_datasets(dataset_path, dataset_names, translate_pair, extentions, fields):
    final_datasets = []
    exts = [".%s"%x for x in extentions]
    for dataset_name in dataset_names:
        final_datasets.append(TranslationDataset(path=os.path.join(dataset_path, translate_pair, dataset_name), exts=exts, fields=[fields[0], fields[1]]))
    
    return final_datasets
Example No. 7
def get_data(path, batch_size, load):
    field_src, field_dst = initialize_field(path + 'field.src',
                                            path + 'field.dst', load)

    print("Loading Training Set... ")
    train_set = TranslationDataset(path=path + ('.train.'),
                                   exts=('src', 'dst'),
                                   fields=(field_src, field_dst))

    print("Loading Validation Set... ")
    valid_set = TranslationDataset(path=path + ('.valid.'),
                                   exts=('src', 'dst'),
                                   fields=(field_src, field_dst))

    print("Loading Test Set... ")
    test_set = TranslationDataset(path=path + ('.test.'),
                                  exts=('src', 'dst'),
                                  fields=(field_src, field_dst))

    # Build the vocabulary. The train, validation and test sets share the same vocabulary.
    if load == False:
        print("Build vocabulary... ")
        field_src.build_vocab(valid_set)
        field_dst.build_vocab(valid_set)
        save_data(path + '.field.src', field_src)
        save_data(path + '.field.dst', field_dst)

    # Initialize dataloaders
    print("Creating Iterators... ")
    # TranslationDataset always names its example attributes 'src' and 'trg'
    train_iter = BucketIterator(
        dataset=train_set,
        batch_size=batch_size,
        sort_key=lambda x: torchtext.data.interleave_keys(
            len(x.src), len(x.trg)))
    valid_iter = BucketIterator(
        dataset=valid_set,
        batch_size=batch_size,
        sort_key=lambda x: torchtext.data.interleave_keys(
            len(x.src), len(x.trg)))
    test_iter = BucketIterator(
        dataset=test_set,
        batch_size=batch_size,
        sort_key=lambda x: torchtext.data.interleave_keys(
            len(x.src), len(x.trg)))
    return train_iter, valid_iter, test_iter, field_src, field_dst
Example No. 8
def get_data(args):
    # batch
    batch_size = args.batch
    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"

    # set up fields
    src = Field(
        sequential=True,
        tokenize=str.split,
        use_vocab=True,
        lower=True,
        include_lengths=False,
        fix_length=args.max_length,  # fix max length
        batch_first=True)
    trg = Field(
        sequential=True,
        tokenize=str.split,
        use_vocab=True,
        init_token='<s>',
        eos_token='</s>',
        lower=True,
        fix_length=args.max_length,  # fix max length
        batch_first=True)

    print('set up fields ... done')

    if args.data_type == "koen":

        train, valid, test = TranslationDataset.splits(('.ko', '.en'),
                                                       (src, trg),
                                                       train='train',
                                                       validation='valid',
                                                       test='test',
                                                       path=args.root_dir)

        # build the vocabulary
        src.build_vocab(train.src, min_freq=args.min_freq)
        trg.build_vocab(train.trg, min_freq=args.min_freq)

        # save the vocabulary
        src_vocabs = src.vocab.stoi
        trg_vocabs = trg.vocab.stoi

        with open('./src_vocabs.pkl', 'wb') as f:
            pickle.dump(src_vocabs, f, pickle.HIGHEST_PROTOCOL)
        with open('./trg_vocabs.pkl', 'wb') as f:
            pickle.dump(trg_vocabs, f, pickle.HIGHEST_PROTOCOL)

    else:
        assert False, "Please Insert data_type"

    train_iter, valid_iter, test_iter = BucketIterator.splits(
        (train, valid, test), batch_sizes=([batch_size] * 3), device=device)

    return (src, trg), (train, valid, test), (train_iter, valid_iter,
                                              test_iter)
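The stoi dictionaries pickled above can be reloaded later without rebuilding the Fields; a minimal sketch, assuming the two .pkl files written by get_data:

# Hypothetical reload of the vocab dictionaries saved above.
import pickle

with open('./src_vocabs.pkl', 'rb') as f:
    src_stoi = pickle.load(f)              # dict-like: token -> index
with open('./trg_vocabs.pkl', 'rb') as f:
    trg_stoi = pickle.load(f)

trg_itos = {i: tok for tok, i in trg_stoi.items()}   # inverted for decoding model output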
Example No. 9
    def create(cls, config):

        src_field = Field(init_token='<sos>',
                          eos_token='<eos>',
                          pad_token='<pad>',
                          include_lengths=True)

        trg_field = Field(init_token='<sos>',
                          eos_token='<eos>',
                          pad_token='<pad>',
                          lower=True,
                          include_lengths=True)

        train = TranslationDataset(path=config.train_prefix,
                                   exts=config.exts,
                                   fields=(src_field, trg_field))
        valid = TranslationDataset(path=config.valid_prefix,
                                   exts=config.exts,
                                   fields=(src_field, trg_field))

        test = TranslationDataset(path=config.test_prefix,
                                  exts=config.exts,
                                  fields=(src_field, trg_field))

        train_it, valid_it, test_it = BucketIterator.splits(
            [train, valid, test],
            batch_sizes=config.batch_sizes,
            sort_key=TranslationDataset.sort_key,
            device=-1)

        src_field.build_vocab(train, min_freq=10)
        trg_field.build_vocab(train, min_freq=10)

        src_voc = src_field.vocab
        trg_voc = trg_field.vocab

        model = Seq2Seq.create(src_voc, trg_voc, config)

        if config.use_cuda:
            model = model.cuda()

        return Trainer(model, train_it, valid_it, test_it, config.valid_step,
                       config.checkpoint_path, config.pool_size)
Example No. 10
def load_data(path_train, path_test, in_ext, out_ext, model_dir, batch_size=1):
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	tokenizer = lambda x: x.split()
	lowercase = True
	# TODO: Whether we need init_token=BOS_TOKEN is still questionable. With POS tags
	#  it was needed; for the LSTM it didn't matter; for the GRU attention model it was not needed at all.
	src = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
						   pad_token=PAD_TOKEN, tokenize=tokenizer,
						   batch_first=False, lower=lowercase,
						   unk_token=UNK_TOKEN,
						   include_lengths=False)
	trg = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
						   pad_token=PAD_TOKEN, tokenize=tokenizer,
						   unk_token=UNK_TOKEN,
						   batch_first=False, lower=lowercase,
						   include_lengths=False)
	train_data = TranslationDataset(path=path_train,
										exts=("." + in_ext, "." + out_ext),
										fields=(src, trg))
	test_data = TranslationDataset(path=path_test,
										exts=("." + in_ext, "." + out_ext),
										fields=(src, trg))
	# build the vocabulary
	src.build_vocab(train_data)
	trg.build_vocab(train_data)
	print_vocab(src, trg)
	with open(os.path.join(model_dir, "src.Field"), "wb") as f:
		dill.dump(src, f)
	with open(os.path.join(model_dir, "trg.Field"), "wb") as f:
		dill.dump(trg, f)
	# Make sure we use sort=False, else accuracy drops
	train_iter = data.BucketIterator(
			repeat=False, sort=False, dataset = train_data,
			batch_size=batch_size, sort_within_batch=True,
			sort_key=lambda x: len(x.src), shuffle=True, train=True, device=device)
	test_iter = data.BucketIterator(
			repeat=False, sort=False, dataset = test_data,
			batch_size=1, sort_within_batch=False,
			sort_key=lambda x: len(x.src), shuffle=False, train=False, device=device)
	#pdb.set_trace()
	return train_iter, test_iter, src, trg
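Because the Fields are serialized with dill, they can be restored together with their vocabularies for inference; a minimal sketch, assuming the same model_dir used above:

# Hypothetical reload of the Fields saved by load_data above (requires dill).
import os
import dill

with open(os.path.join(model_dir, "src.Field"), "rb") as f:
    src = dill.load(f)
with open(os.path.join(model_dir, "trg.Field"), "rb") as f:
    trg = dill.load(f)

print(len(src.vocab), len(trg.vocab))   # the vocabs travel with the pickled Fields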
Example No. 11
    def _make_test_set(test_path: str, src_lang: str, trg_lang: str) -> Optional[Dataset]:
        if test_path is not None:
            if os.path.isfile(test_path + "." + trg_lang):
                return TranslationDataset(path=test_path,
                                          exts=("." + src_lang, "." + trg_lang),
                                          fields=(fields['src'][src_lang],
                                                  fields['trg'][trg_lang]))
            else:
                return MonoDataset(path=test_path,
                                   ext="." + src_lang,
                                   field=fields['src'][src_lang])
        else:
            return None
Example No. 12
    def __init__(self, module_name, train_bs, eval_bs, device, log):
        self.module_name = module_name

        # split_chars = lambda x: list("".join(x.split()))
        split_chars = lambda x: list(x)  # keeps whitespaces

        source = Field(tokenize=split_chars,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        target = Field(tokenize=split_chars,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        log("Loading FULL datasets ...")
        folder = os.path.join(DATASET_TARGET_DIR, module_name)
        train_dataset, eval_dataset, _ = TranslationDataset.splits(
            path=folder,
            root=folder,
            exts=(INPUTS_FILE_ENDING, TARGETS_FILE_ENDING),
            fields=(source, target),
            train=TRAIN_FILE_NAME,
            validation=EVAL_FILE_NAME,
            test=EVAL_FILE_NAME)

        log("Building vocab ...")
        source.build_vocab(train_dataset)
        target.vocab = source.vocab

        log("Creating iterators ...")
        train_iterator = Iterator(dataset=train_dataset,
                                  batch_size=train_bs,
                                  train=True,
                                  repeat=True,
                                  shuffle=True,
                                  device=device)

        eval_iterator = Iterator(dataset=eval_dataset,
                                 batch_size=eval_bs,
                                 train=False,
                                 repeat=False,
                                 shuffle=False,
                                 device=device)

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.train_iterator = train_iterator
        self.eval_iterator = eval_iterator
        self.source = source
        self.target = target
Example No. 13
File: pre.py  Project: amlox2019/AML
def create_data(data, lang):
    source_text = Field(tokenize=MosesTokenizer('en'),
                        init_token='<sos>',
                        eos_token='<eos>',
                        lower=True,
                        pad_token='<pad>',
                        unk_token='<unk>')
    target_text = Field(tokenize=MosesTokenizer(lang),
                        init_token='<sos>',
                        eos_token='<eos>',
                        lower=True,
                        pad_token='<pad>',
                        unk_token='<unk>')

    train = TranslationDataset(path=data,
                               exts=('.en', '.' + lang),
                               fields=(source_text, target_text))

    # Load the word vectors from the embedding directory
    print('Loading en word vectors')
    en_vectors = Vectors(name='cc.en.300.vec', cache=emb_dir)
    print('Loaded.')
    print('Loading {} word vectors'.format(lang))
    if lang == 'fr':
        target_vectors = Vectors(name='cc.fr.300.vec', cache=emb_dir)
    elif lang == 'de':
        target_vectors = Vectors(name='embed_tweets_de_100D_fasttext',
                                 cache=emb_dir)
    else:
        raise NotImplementedError
    print('Loaded.')

    # Build vocabulary
    print('Building en vocab')
    source_text.build_vocab(train,
                            max_size=15000,
                            min_freq=1,
                            vectors=en_vectors)
    print('Building {} vocab'.format(lang))
    target_text.build_vocab(train,
                            max_size=15000,
                            min_freq=1,
                            vectors=target_vectors)
    #source_text.build_vocab(train, min_freq = 30000, vectors="glove.6B.200d")
    #target_text.build_vocab(train, min_freq = 30000, vectors="glove.6B.200d")

    pad_idx = target_text.vocab.stoi['<pad>']
    print('pad_idx', pad_idx)
    eos_idx = target_text.vocab.stoi['<eos>']
    print('eos_idx', eos_idx)

    return train, source_text, target_text
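The vectors attached by build_vocab are typically copied into the model's embedding layers afterwards; a hedged sketch of that step (the path and language are illustrative, and the embedding width is read off the loaded vectors, since the German tweet embeddings are not 300-dimensional):

# Hypothetical: initialize embeddings from the vectors attached by build_vocab above.
import torch.nn as nn

train, source_text, target_text = create_data('data/europarl', 'fr')   # placeholder path

enc_emb = nn.Embedding(len(source_text.vocab), source_text.vocab.vectors.size(1))
enc_emb.weight.data.copy_(source_text.vocab.vectors)

dec_emb = nn.Embedding(len(target_text.vocab), target_text.vocab.vectors.size(1))
dec_emb.weight.data.copy_(target_text.vocab.vectors)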
Example No. 14
def prepare_dataloaders(opt):
    en = textdata.Field(tokenize='spacy',
                        tokenizer_language='en',
                        init_token=Constants.BOS_WORD,
                        eos_token=Constants.EOS_WORD,
                        pad_token=Constants.PAD_WORD,
                        unk_token=Constants.UNK_WORD)
    sql_tokenizer = lambda x: x.split(Constants.SQL_SEPARATOR)
    sql = textdata.Field(tokenize=sql_tokenizer,
                         init_token=Constants.BOS_WORD,
                         eos_token=Constants.EOS_WORD,
                         pad_token=Constants.PAD_WORD,
                         unk_token=Constants.UNK_WORD)
    fields = [('en', en), ('sql', sql)]

    ds_train = TranslationDataset(path=opt.train_data,
                                  exts=('.en', '.sql'),
                                  fields=fields)
    ds_validation = TranslationDataset(path=opt.validation_data,
                                       exts=('.en', '.sql'),
                                       fields=fields)

    en.build_vocab(ds_train, max_size=80000)
    sql.build_vocab(ds_train, max_size=40000)

    train_iter = textdata.BucketIterator(dataset=ds_train,
                                         batch_size=opt.batch_size,
                                         sort_key=lambda x: len(x.en),
                                         device=opt.device)
    validation_iter = textdata.BucketIterator(dataset=ds_validation,
                                              batch_size=opt.batch_size,
                                              sort_key=lambda x: len(x.en),
                                              device=opt.device)


    return BatchWrapper(train_iter, fields, device=opt.device),\
           BatchWrapper(validation_iter, fields, device=opt.device),\
           en, sql
Example No. 15
def load_dataset(batch_size, device):
    """
    Load the dataset from the files into iterator and initialize the vocabulary
    :param batch_size
    :param device
    :return: source and data iterators
    """
    source = Field(tokenize=tokenize_en,
                   init_token='<sos>',
                   eos_token='<eos>',
                   lower=True)

    train_data, valid_data, test_data = TranslationDataset.splits(
        path=DATA_FOLDER,
        exts=(POSITIVE_FILE_EXTENSION, NEGATIVE_FILE_EXTENSION),
        fields=(source, source))
    source.build_vocab(train_data, min_freq=5)
    return source, BucketIterator.splits((train_data, valid_data, test_data),
                                         shuffle=True,
                                         batch_size=batch_size,
                                         device=device)
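A hedged usage sketch: because the same source Field is used for both file extensions, one shared vocabulary covers both sides, and each batch exposes src and trg tensors indexed with it (the batch size and device below are illustrative).

# Hypothetical usage of load_dataset above.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
source, (train_iter, valid_iter, test_iter) = load_dataset(batch_size=64, device=device)

batch = next(iter(train_iter))
print(batch.src.shape, batch.trg.shape)   # both sides share source.vocab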
Example No. 16
def get_data(path='data/'):
    SRC = Field(tokenize=tokenize_cn,
                init_token='<sos>',
                eos_token='<eos>',
                pad_token='<pad>',
                unk_token='<unk>',
                lower=True)
    TRG = Field(tokenize=tokenize_en,
                init_token='<sos>',
                eos_token='<eos>',
                pad_token='<pad>',
                unk_token='<unk>',
                lower=True)

    train_data, valid_data, test_data = TranslationDataset.splits(
        path=path,
        train='train',
        validation='val',
        test='test',
        exts=('.cn', '.en'),
        fields=(SRC, TRG))

    print("train: {}".format(len(train_data.examples)))
    print("valid: {}".format(len(valid_data.examples)))
    print("test: {}".format(len(test_data.examples)))

    SRC.build_vocab(train_data, min_freq=params.MIN_FREQ)
    TRG.build_vocab(train_data, min_freq=params.MIN_FREQ)

    print("源语言词表大小: {}".format(len(SRC.vocab)))
    print("目标语言词表大小: {}".format(len(TRG.vocab)))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=params.BATCH_SIZE,
        device=device)

    return train_iterator, valid_iterator, test_iterator, SRC, TRG
Example No. 17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-raw_dir', required=True)
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-codes', required=True)
    parser.add_argument('-save_data', required=True)
    parser.add_argument('-prefix', required=True)
    parser.add_argument('-max_len', type=int, default=100)
    parser.add_argument('--symbols', '-s', type=int, default=32000, help="Vocabulary size")
    parser.add_argument(
        '--min-frequency', type=int, default=6, metavar='FREQ',
        help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
    parser.add_argument('--dict-input', action="store_true",
                        help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
    parser.add_argument(
        '--separator', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s'))")
    parser.add_argument('--total-symbols', '-t', action="store_true")
    opt = parser.parse_args()

    # Create folder if needed.
    mkdir_if_needed(opt.raw_dir)
    mkdir_if_needed(opt.data_dir)

    # Download and extract raw data.
    raw_train = get_raw_files(opt.raw_dir, _TRAIN_DATA_SOURCES)
    raw_val = get_raw_files(opt.raw_dir, _VAL_DATA_SOURCES)
    raw_test = get_raw_files(opt.raw_dir, _TEST_DATA_SOURCES)

    # Merge files into one.
    train_src, train_trg = compile_files(opt.raw_dir, raw_train, opt.prefix + '-train')
    val_src, val_trg = compile_files(opt.raw_dir, raw_val, opt.prefix + '-val')
    test_src, test_trg = compile_files(opt.raw_dir, raw_test, opt.prefix + '-test')

    # Build up the BPE codes from the training files if they do not already exist
    opt.codes = os.path.join(opt.data_dir, opt.codes)
    if not os.path.isfile(opt.codes):
        sys.stderr.write(f"Collect codes from training data and save to {opt.codes}.\n")
        learn_bpe(raw_train['src'] + raw_train['trg'], opt.codes, opt.symbols, opt.min_frequency, True)
    sys.stderr.write(f"BPE codes prepared.\n")

    sys.stderr.write(f"Build up the tokenizer.\n")
    with codecs.open(opt.codes, encoding='utf-8') as codes:
        bpe = BPE(codes, separator=opt.separator)

    sys.stderr.write(f"Encoding ...\n")
    encode_files(bpe, train_src, train_trg, opt.data_dir, opt.prefix + '-train')
    encode_files(bpe, val_src, val_trg, opt.data_dir, opt.prefix + '-val')
    encode_files(bpe, test_src, test_trg, opt.data_dir, opt.prefix + '-test')
    sys.stderr.write(f"Done.\n")

    field = torchtext.data.Field(
        tokenize=str.split,
        lower=True,
        pad_token=Constants.PAD_WORD,
        init_token=Constants.BOS_WORD,
        eos_token=Constants.EOS_WORD)

    fields = (field, field)

    MAX_LEN = opt.max_len

    def filter_examples_with_length(x):
        return len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN

    enc_train_files_prefix = opt.prefix + '-train'
    train = TranslationDataset(
        fields=fields,
        path=os.path.join(opt.data_dir, enc_train_files_prefix),
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)

    from itertools import chain
    field.build_vocab(chain(train.src, train.trg), min_freq=2)

    data = {'settings': opt, 'vocab': field, }
    opt.save_data = os.path.join(opt.data_dir, opt.save_data)

    print('[Info] Dumping the processed data to pickle file', opt.save_data)
    pickle.dump(data, open(opt.save_data, 'wb'))
Example No. 18
from torchtext.data import Field, BucketIterator
from torchtext.datasets import TranslationDataset
from torch import nn
import torch
import torch.nn.functional as F

device = torch.device("cuda")

Lang1 = Field(eos_token='<eos>')
Lang2 = Field(init_token='<sos>', eos_token='<eos>')

train = TranslationDataset(path='../Datasets/MT_data/',
                           exts=('eng-fra.train.fr', 'eng-fra.train.en'),
                           fields=[('Lang1', Lang1), ('Lang2', Lang2)])

train_iter, val_iter, test_iter = BucketIterator.splits((train, train, train),
                                                        batch_size=256,
                                                        repeat=False)
Lang1.build_vocab(train)
Lang2.build_vocab(train)

# for i, train_batch in enumerate(train_iter):
#     print('Lang1  : \n', [Lang1.vocab.itos[x] for x in train_batch.Lang1.data[:,0]])
#     print('Lang2 : \n', [Lang2.vocab.itos[x] for x in train_batch.Lang2.data[:,0]])
print(Lang1.vocab.stoi)
print(Lang2.vocab.itos)

import json
with open('encoder_vocab.json', 'w', encoding='utf8') as f:
    json.dump(Lang1.vocab.stoi, f, ensure_ascii=False)
with open('decoder_vocab.json', 'w', encoding='utf8') as f:
    json.dump(Lang2.vocab.stoi, f, ensure_ascii=False)
Example No. 19
def load_dataset(args):
    def tokenzie_zhcha(text):
        #return [tok for tok in re.sub('\s','',text).strip()]
        return [tok for tok in text.strip()]

    def tokenzie_zhword(text):
        return [tok for tok in text.strip().split()]

    def tokenzie_ticha(text):
        return [tok for tok in text.strip().split()]

    def tokenzie_tiword(text):
        return [tok for tok in text.strip().split()]

    ZH_CHA = Field(tokenize=tokenzie_zhcha,
                   include_lengths=True,
                   init_token='<sos>',
                   eos_token='<eos>')

    ZH_WORD = Field(tokenize=tokenzie_zhword,
                    include_lengths=True,
                    init_token='<sos>',
                    eos_token='<eos>')

    Ti_CHA = Field(tokenize=tokenzie_ticha,
                   include_lengths=True,
                   init_token='<sos>',
                   eos_token='<eos>')

    Ti_WORD = Field(tokenize=tokenzie_tiword,
                    include_lengths=True,
                    init_token='<sos>',
                    eos_token='<eos>')

    #pdb.set_trace()

    # According to the training mode, load the data
    if args.mode == 'ctc':
        exts = (args.extension.split()[0], args.extension.split()[1])
        train, val, test = Trans.splits(path=args.path,
                                        exts=exts,
                                        fields=(Ti_CHA, Ti_WORD),
                                        train=args.train,
                                        validation=args.valid,
                                        test=args.test)

        Ti_CHA.build_vocab(train.src)
        Ti_WORD.build_vocab(train.trg)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)
        return train_iter, val_iter, test_iter, Ti_CHA, Ti_WORD

    elif args.mode == 'nmt':
        exts = (args.extension.split()[0], args.extension.split()[1])
        train, val, test = Trans.splits(path=args.path,
                                        exts=exts,
                                        fields=(Ti_WORD, ZH_WORD),
                                        train=args.train,
                                        validation=args.valid,
                                        test=args.test)

        Ti_WORD.build_vocab(train.src, max_size=50000)
        ZH_WORD.build_vocab(train.trg, max_size=50000)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)
        return train_iter, val_iter, test_iter, Ti_WORD, ZH_WORD

    elif args.mode == 'nmt_char':
        exts = (args.extension.split()[0], args.extension.split()[1])
        train, val, test = Trans.splits(path=args.path,
                                        exts=exts,
                                        fields=(Ti_CHA, ZH_CHA),
                                        train=args.train,
                                        validation=args.valid,
                                        test=args.test)

        Ti_CHA.build_vocab(train.src)
        ZH_CHA.build_vocab(train.trg)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)
        return train_iter, val_iter, test_iter, Ti_CHA, ZH_CHA

    elif args.mode == 'combine':
        exts = (args.extension.split()[0], args.extension.split()[1])
        train, val, test = Trans.splits(path=args.path,
                                        exts=exts,
                                        fields=(Ti_CHA, ZH_WORD),
                                        train=args.train,
                                        validation=args.valid,
                                        test=args.test)

        Ti_CHA.build_vocab(train.src)
        ZH_WORD.build_vocab(train.trg, max_size=50000)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)
        return train_iter, val_iter, test_iter, Ti_CHA, ZH_WORD

    elif args.mode == 'refine_ctc':
        exts = (args.extension.split()[0], args.extension.split()[1])
        train, val, test = Trans.splits(path=args.path,
                                        exts=exts,
                                        fields=(Ti_CHA, Ti_WORD),
                                        train=args.train,
                                        validation=args.valid,
                                        test=args.test)

        Ti_CHA.build_vocab(train.src)
        Ti_WORD.build_vocab(train.trg, max_size=50000)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)
        return train_iter, val_iter, test_iter, Ti_CHA, Ti_WORD

    elif args.mode == 'update_twoLoss':
        exts = (args.extension.split()[0], args.extension.split()[1],
                args.extension.split()[2])
        train, val, test = mydataset.splits(path=args.path,
                                             exts=exts,
                                             fields=(Ti_CHA, ZH_WORD, Ti_WORD),
                                             train=args.train,
                                             validation=args.valid,
                                             test=args.test)
        Ti_CHA.build_vocab(train.src)
        ZH_WORD.build_vocab(train.trg, max_size=50000)
        Ti_WORD.build_vocab(train.ctc, max_size=50000)

        train_iter, val_iter, test_iter = BucketIterator.splits(
            (train, val, test), batch_size=args.batch_size, repeat=False)

        return train_iter, val_iter, test_iter, Ti_CHA, ZH_WORD, Ti_WORD
Example No. 20
import math
import time

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

SRC = Field(tokenize=None, init_token='<sos>', eos_token='<eos>', lower=True)

TRG = Field(tokenize=None, init_token='<sos>', eos_token='<eos>', lower=True)

myData = TranslationDataset('./E_V/train', ('.en', '.vi'), (SRC, TRG))

train_data, test_data = myData.splits(exts=('.en', '.vi'),
                                      fields=(SRC, TRG),
                                      path="./E_V/",
                                      train='train',
                                      validation=None,
                                      test='tst2012')
vocabData = TranslationDataset('./E_V/vocab', ('.en', '.vi'), (SRC, TRG))
print(f"Number of training examples: {len(train_data.examples)}")
# # print(f"Number of validation examples: {len(valid_data.examples)}")
print(f"Number of testing examples: {len(test_data.examples)}")

SRC.build_vocab(train_data, min_freq=3)
TRG.build_vocab(train_data, min_freq=3)
Example No. 21
def load_data(data_cfg: dict, datasets: list = None)\
        -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    If you set ``random_train_subset``, a random selection of this size is used
    from the training set instead of the full training set.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :param datasets: list of dataset names to load
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    if datasets is None:
        datasets = ["train", "dev", "test"]

    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg.get("train", None)
    dev_path = data_cfg.get("dev", None)
    test_path = data_cfg.get("test", None)

    if train_path is None and dev_path is None and test_path is None:
        raise ValueError('Please specify at least one data source path.')

    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = data.Field(init_token=None,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           batch_first=True,
                           lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True,
                           lower=lowercase,
                           include_lengths=True)

    train_data = None
    if "train" in datasets and train_path is not None:
        logger.info("Loading training data...")
        train_data = TranslationDataset(
            path=train_path,
            exts=("." + src_lang, "." + trg_lang),
            fields=(src_field, trg_field),
            filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and
            len(vars(x)['trg']) <= max_sent_length)

        random_train_subset = data_cfg.get("random_train_subset", -1)
        if random_train_subset > -1:
            # select this many training examples randomly and discard the rest
            keep_ratio = random_train_subset / len(train_data)
            keep, _ = train_data.split(
                split_ratio=[keep_ratio, 1 - keep_ratio],
                random_state=random.getstate())
            train_data = keep

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    assert (train_data is not None) or (src_vocab_file is not None)
    assert (train_data is not None) or (trg_vocab_file is not None)

    logger.info("Building vocabulary...")
    src_vocab = build_vocab(field="src",
                            min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg",
                            min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=train_data,
                            vocab_file=trg_vocab_file)

    dev_data = None
    if "dev" in datasets and dev_path is not None:
        logger.info("Loading dev data...")
        dev_data = TranslationDataset(path=dev_path,
                                      exts=("." + src_lang, "." + trg_lang),
                                      fields=(src_field, trg_field))

    test_data = None
    if "test" in datasets and test_path is not None:
        logger.info("Loading test data...")
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=src_field)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    logger.info("Data loaded.")
    return train_data, dev_data, test_data, src_vocab, trg_vocab
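A minimal sketch of the data_cfg dictionary this loader reads (paths and limits are illustrative; in practice it is the "data" section of a configuration file):

# Hypothetical data_cfg for load_data above; file prefixes are placeholders.
data_cfg = {
    "src": "de",
    "trg": "en",
    "train": "data/train",        # expects data/train.de and data/train.en
    "dev": "data/dev",
    "test": "data/test",
    "level": "bpe",               # "word", "bpe" or "char"
    "lowercase": False,
    "max_sent_length": 100,
    "src_voc_limit": 32000,
    "trg_voc_limit": 32000,
    "src_voc_min_freq": 1,
    "trg_voc_min_freq": 1,
}

train_data, dev_data, test_data, src_vocab, trg_vocab = load_data(data_cfg)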
Example No. 22
    def __init__(self,
                 data_dir: str,
                 packed: bool,
                 vocab_max_sizes: Tuple[int, int],
                 vocab_min_freqs: Tuple[int, int],
                 batch_sizes: Tuple[int, int, int],
                 test: bool = False):
        print(f"Creating DataLoader for {'testing' if test else 'training'}")

        # Rebuild the vocabs during testing, as the saved ones may have been built from a different config
        if test:
            vocab_exists = False
        else:
            vocab_exists = has_vocabs(data_dir, vocab_max_sizes,
                                      vocab_min_freqs)

        # Define torch text fields for processing text
        if vocab_exists:
            print("Loading fields and vocabs...")
            SRC, TRG = load_vocabs(data_dir, vocab_max_sizes, vocab_min_freqs)
        else:
            print("Building fields...")

            # Include the sentence length for source
            SRC = Field(tokenize=tokenize_diff,
                        init_token='<sos>',
                        eos_token='<eos>',
                        include_lengths=packed,
                        lower=True)

            TRG = Field(tokenize=tokenize_msg,
                        init_token='<sos>',
                        eos_token='<eos>',
                        lower=True)

        print("Loading commit data...")
        train_data, valid_data, test_data = TranslationDataset.splits(
            exts=('.diff', '.msg'),
            train='TrainingSet/train.26208',
            validation='TrainingSet/valid.3000',
            test='TestSet/test.3000',
            fields=(SRC, TRG),
            path=data_dir)

        if not vocab_exists:
            # Build vocabs
            print("Building vocabulary...")
            specials = ['<unk>', '<pad>', '<sos>', '<eos>']
            SRC.build_vocab(train_data,
                            min_freq=vocab_min_freqs[0],
                            max_size=vocab_max_sizes[0],
                            specials=specials)
            TRG.build_vocab(train_data,
                            min_freq=vocab_min_freqs[1],
                            max_size=vocab_max_sizes[1],
                            specials=specials)

            if not test:
                save_vocabs(data_dir, SRC, TRG, vocab_max_sizes,
                            vocab_min_freqs)

        print(f"Number of training examples: {len(train_data.examples)}")
        print(f"Number of validation examples: {len(valid_data.examples)}")
        print(f"Number of testing examples: {len(test_data.examples)}")
        print(
            f"Unique tokens in source (diff) training vocabulary: {len(SRC.vocab)}"
        )
        print(
            f"Unique tokens in target (msg) training vocabulary: {len(TRG.vocab)}"
        )

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Bucketing (minimizes the amount of padding by grouping similar length sentences)
        # Sort the sequences based on their non-padded length
        train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
            (train_data, valid_data, test_data),
            batch_sizes=batch_sizes,
            sort_within_batch=packed,
            sort_key=(lambda x: len(x.src)) if packed else None,
            device=device)

        super().__init__(train_iterator, valid_iterator, test_iterator, SRC,
                         TRG, tokenize_diff, tokenize_msg)
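The packed flag ties include_lengths, sort_within_batch and the sort key together so that batches can be fed to pack_padded_sequence; a hedged sketch of how an encoder would consume such a batch when packed=True (the embedding and GRU sizes are placeholders, and SRC / train_iterator refer to the objects built above):

# Hypothetical consumption of a packed batch (only valid when packed=True).
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

emb = nn.Embedding(len(SRC.vocab), 256)
rnn = nn.GRU(256, 512)

batch = next(iter(train_iterator))
src, src_len = batch.src                                  # (data, lengths) pair
packed = pack_padded_sequence(emb(src), src_len.cpu())    # needs sort_within_batch=True
outputs, hidden = rnn(packed)
outputs, _ = pad_packed_sequence(outputs)                 # back to (seq_len, batch, 512)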
Example No. 23
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

SRC = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True,
            include_lengths = True)

TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)


train_data = TranslationDataset(path = './data/functions.train', exts = ('.src', '.tgt'), fields = (SRC, TRG))
test_data = TranslationDataset(path = './data/functions.test', exts = ('.src', '.tgt'), fields = (SRC, TRG))
valid_data = TranslationDataset(path = './data/functions.valid', exts = ('.src', '.tgt'), fields = (SRC, TRG))

SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

print('Built the Vocab of SRC and TRG')


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




BATCH_SIZE = 128
Example No. 24
    tensor = torch.LongTensor(numericalized).unsqueeze(1).to(
        device)  #convert to tensor and add batch dimension
    translation_tensor_probs = model(tensor, None, 0).squeeze(
        1)  #pass through model to get translation probabilities
    translation_tensor = torch.argmax(
        translation_tensor_probs,
        1)  #get translation from highest probabilities
    translation = [
        en.vocab.itos[t] for t in translation_tensor
    ][1:]  #we ignore the first token, just like we do in the training loop
    return translation


if __name__ == "__main__":
    train_data = TranslationDataset(
        path="data/",
        exts=["asl_train_processed.txt", "en_train.txt"],
        fields=[asl, en])
    valid_data = TranslationDataset(
        path="data/",
        exts=["asl_val_processed.txt", "en_val.txt"],
        fields=[asl, en])
    test_data = TranslationDataset(
        path="data/",
        exts=["asl_test_processed.txt", "en_test.txt"],
        fields=[asl, en])

    print(f"Number of training examples: {len(train_data.examples)}")
    print(f"Number of validation examples: {len(valid_data.examples)}")
    print(f"Number of testing examples: {len(test_data.examples)}")

    print(vars(train_data.examples[0]))
Example No. 25
### (get correct tokenization for each language, append <sos> & <eos>,
### convert all words to lowercase)
SRC_TEXT = Field(tokenize=tokenizer,
                 init_token='<sos>',
                 eos_token='<eos>',
                 lower=False,
                 batch_first=True)
TRG_TEXT = Field(tokenize=tokenizer,
                 init_token='<sos>',
                 eos_token='<eos>',
                 lower=False,
                 batch_first=True)

### Convert text files into TranslationDataset type of torchtext --- this is full data set
train_data = TranslationDataset(path=os.path.join(data_folder, 'train'),
                                exts=('.de', '.en'),
                                fields=(SRC_TEXT, TRG_TEXT))
dev_data = TranslationDataset(path=os.path.join(data_folder, 'dev'),
                              exts=('.de', '.en'),
                              fields=(SRC_TEXT, TRG_TEXT))

print(f"Number of training examples: {len(train_data.examples)}")
print(f"Number of validation examples: {len(dev_data.examples)}")
print(vars(train_data.examples[1]))

### Remove examples whose target sentences exceed 40 words.
### Deleting items from the list while enumerating it would skip entries, so filter instead.
train_data.examples = [ex for ex in train_data.examples if len(ex.trg) <= 40]

dev_data.examples = [ex for ex in dev_data.examples if len(ex.trg) <= 40]
Example No. 26
def load_data(cfg):
    """
    Load train, dev and test data as specified in the configuration.

    :param cfg:
    :return:
    """
    # load data from files
    data_cfg = cfg["data"]
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    #pylint: disable=unnecessary-lambda
    if level == "char":
        tok_fun = lambda s: list(s)
    else:  # bpe or word, pre-tokenized
        tok_fun = lambda s: s.split()

    src_field = data.Field(init_token=None,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           batch_first=True,
                           lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True,
                           lower=lowercase,
                           include_lengths=True)

    train_data = TranslationDataset(
        path=train_path,
        exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and len(
            vars(x)['trg']) <= max_sent_length)

    max_size = data_cfg.get("voc_limit", sys.maxsize)
    min_freq = data_cfg.get("voc_min_freq", 1)
    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    src_vocab = build_vocab(field="src",
                            min_freq=min_freq,
                            max_size=max_size,
                            dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg",
                            min_freq=min_freq,
                            max_size=max_size,
                            dataset=train_data,
                            vocab_file=trg_vocab_file)
    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))
    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only

            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=(src_field))
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    return train_data, dev_data, test_data, src_vocab, trg_vocab
Example No. 27
    spacy_de = nl_core_news_sm.load()
    spacy_en = en_core_web_sm.load()

    SRC = Field(tokenize=tokenize_de,
                init_token='<sos>',
                eos_token='<eos>',
                lower=True)
    TRG = Field(tokenize=tokenize_en,
                init_token='<sos>',
                eos_token='<eos>',
                lower=True)

    train, valid, test = TranslationDataset.splits(path='./data/multi30k/',
                                                   exts=['.de', '.en'],
                                                   fields=[('src', SRC),
                                                           ('trg', TRG)],
                                                   train='train',
                                                   validation='val',
                                                   test='test2016')
    print(vars(train.examples[0]))
    SRC.build_vocab(train, min_freq=2)
    TRG.build_vocab(train, min_freq=2)

    BATCH_SIZE = 128

    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (train, valid, test), batch_size=BATCH_SIZE, repeat=False)

    INPUT_DIM = len(SRC.vocab)
    OUTPUT_DIM = len(TRG.vocab)
    ENC_EMB_DIM = 256
Example No. 28
MAX_LEN = 100

from torchtext.datasets import TranslationDataset, Multi30k
ROOT = './'
Multi30k.download(ROOT)

SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)
TGT = data.Field(tokenize=tokenize_en,
                 init_token=BOS_WORD,
                 eos_token=EOS_WORD,
                 pad_token=BLANK_WORD)

(trnset, valset,
 testset) = TranslationDataset.splits(path='./Multi30k/multi30k',
                                      exts=['.en', '.de'],
                                      fields=[('src', SRC), ('trg', TGT)],
                                      test='test2016')

#list(enumerate(testset))

import pandas as pd

df = pd.read_csv("./SQuAD_csv/train_SQuAD.csv", sep=';', header=None)

df = df.iloc[1:, :]
df = df.iloc[:, [1, 2]]

from sklearn.model_selection import train_test_split
train, val = train_test_split(df, test_size=0.1)
train.to_csv("train.csv", index=False)
val.to_csv("val.csv", index=False)
Example No. 29
def load_data(
    data_cfg: dict
) -> (Dataset, Dataset, Optional[Dataset], Vocabulary, Vocabulary):
    """
    Load train, dev and optionally test data as specified in configuration.
    Vocabularies are created from the training set with a limit of `voc_limit`
    tokens and a minimum token frequency of `voc_min_freq`
    (specified in the configuration dictionary).

    The training data is filtered to include sentences up to `max_sent_length`
    on source and target side.

    :param data_cfg: configuration dictionary for data
        ("data" part of configuation file)
    :return:
        - train_data: training dataset
        - dev_data: development dataset
        - test_data: test dataset if given, otherwise None
        - src_vocab: source vocabulary extracted from training data
        - trg_vocab: target vocabulary extracted from training data
    """
    # load data from files
    src_lang = data_cfg["src"]
    trg_lang = data_cfg["trg"]
    train_path = data_cfg["train"]
    dev_path = data_cfg["dev"]
    test_path = data_cfg.get("test", None)
    level = data_cfg["level"]
    lowercase = data_cfg["lowercase"]
    max_sent_length = data_cfg["max_sent_length"]

    tok_fun = lambda s: list(s) if level == "char" else s.split()

    src_field = data.Field(init_token=None,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           batch_first=True,
                           lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)

    trg_field = data.Field(init_token=BOS_TOKEN,
                           eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN,
                           tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True,
                           lower=lowercase,
                           include_lengths=True)

    train_data = TranslationDataset(
        path=train_path,
        exts=("." + src_lang, "." + trg_lang),
        fields=(src_field, trg_field),
        filter_pred=lambda x: len(vars(x)['src']) <= max_sent_length and len(
            vars(x)['trg']) <= max_sent_length)

    src_max_size = data_cfg.get("src_voc_limit", sys.maxsize)
    src_min_freq = data_cfg.get("src_voc_min_freq", 1)
    trg_max_size = data_cfg.get("trg_voc_limit", sys.maxsize)
    trg_min_freq = data_cfg.get("trg_voc_min_freq", 1)

    src_vocab_file = data_cfg.get("src_vocab", None)
    trg_vocab_file = data_cfg.get("trg_vocab", None)

    src_vocab = build_vocab(field="src",
                            min_freq=src_min_freq,
                            max_size=src_max_size,
                            dataset=train_data,
                            vocab_file=src_vocab_file)
    trg_vocab = build_vocab(field="trg",
                            min_freq=trg_min_freq,
                            max_size=trg_max_size,
                            dataset=train_data,
                            vocab_file=trg_vocab_file)
    dev_data = TranslationDataset(path=dev_path,
                                  exts=("." + src_lang, "." + trg_lang),
                                  fields=(src_field, trg_field))
    test_data = None
    if test_path is not None:
        # check if target exists
        if os.path.isfile(test_path + "." + trg_lang):
            test_data = TranslationDataset(path=test_path,
                                           exts=("." + src_lang,
                                                 "." + trg_lang),
                                           fields=(src_field, trg_field))
        else:
            # no target is given -> create dataset from src only
            test_data = MonoDataset(path=test_path,
                                    ext="." + src_lang,
                                    field=src_field)
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    return train_data, dev_data, test_data, src_vocab, trg_vocab
Example No. 30
    def __init__(self,
                 module_name,
                 train_bs,
                 eval_bs,
                 device,
                 vocab=None,
                 base_folder=None,
                 train_name=None,
                 eval_name=None,
                 x_ext=None,
                 y_ext=None,
                 tokens=None,
                 specials=None,
                 tokenizer=None,
                 sort_within_batch=None,
                 shuffle=None):

        self.module_name = module_name

        # split_chars = lambda x: list("".join(x.split()))
        split_chars = lambda x: list(x)  # keeps whitespaces

        if not tokenizer:
            tokenizer = split_chars

        # NOTE: on Jul-20-2020, removed fix_length=200 since it forces
        # all batches to be of size (batch_size, 200) which
        # really wastes GPU memory
        source = Field(tokenize=tokenizer,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        target = Field(tokenize=tokenizer,
                       init_token='<sos>',
                       eos_token='<eos>',
                       batch_first=True)

        base_folder = os.path.expanduser(base_folder)

        folder = os.path.join(base_folder, module_name)

        # fix slashes
        folder = os.path.abspath(folder)

        print("loading FULL datasets from folder={}".format(folder))

        train_dataset, eval_dataset, _ = TranslationDataset.splits(
            path=folder,
            root=folder,
            exts=(x_ext, y_ext),
            fields=(source, target),
            train=train_name,
            validation=eval_name,
            test=eval_name)

        if vocab:
            print("Setting vocab to prebuilt file...")
            source.vocab = vocab
            target.vocab = vocab
        elif tokens:
            print("Building vocab from tokens...")
            #source.build_vocab(tokens, specials)
            counter = Counter(tokens)
            source.vocab = source.vocab_cls(counter, specials=specials)
            target.vocab = source.vocab
        else:
            print("Building vocab from TRAIN and EVAL datasets...")
            source.build_vocab(train_dataset, eval_dataset)
            target.vocab = source.vocab

        print("Creating iterators ...")
        do_shuffle = True if shuffle is None else shuffle
        train_iterator = Iterator(dataset=train_dataset,
                                  batch_size=train_bs,
                                  train=True,
                                  repeat=True,
                                  shuffle=do_shuffle,
                                  sort_within_batch=sort_within_batch,
                                  device=device)

        eval_iterator = Iterator(dataset=eval_dataset,
                                 batch_size=eval_bs,
                                 train=False,
                                 repeat=False,
                                 shuffle=False,
                                 sort_within_batch=sort_within_batch,
                                 device=device)

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        self.train_iterator = train_iterator
        self.eval_iterator = eval_iterator

        self.source = source
        self.target = target