Example #1
def create_tokenizer_imbd(data_path, file_name, vocab_size):
    #df = pd.read_csv(os.path.join(data_path, file_name))
    tokenizer = CharBPETokenizer()
    tokenizer.train(
        os.path.join(data_path, file_name),
        vocab_size=vocab_size,
        min_frequency=2,
        show_progress=True,
        special_tokens=["[CLS]", "[PAD]", "[MASK]", "[UNK]", "[SEP]"])

    print("[CLS]: {}, [PAD]: {}, [MASK]: {}, [UNK]: {}, [SEP]: {}".format(
        str(tokenizer.token_to_id("[CLS]")),
        str(tokenizer.token_to_id("[PAD]")),
        str(tokenizer.token_to_id("[MASK]")),
        str(tokenizer.token_to_id("[UNK]")),
        str(tokenizer.token_to_id("[SEP]"))))

    tokenizer.save(data_path, "tokenizer")
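The two-argument tokenizer.save(data_path, "tokenizer") call above belongs to the older tokenizers API: it writes tokenizer-vocab.json and tokenizer-merges.txt into data_path, which are exactly the files Example #4 below loads. A minimal reload-and-encode sketch, assuming those two files already exist under a hypothetical ./data folder:

import os

from tokenizers import CharBPETokenizer

data_path = "./data"  # hypothetical output folder of create_tokenizer_imbd

# Rebuild the tokenizer from the saved vocab/merges files.
tokenizer = CharBPETokenizer(
    os.path.join(data_path, "tokenizer-vocab.json"),
    os.path.join(data_path, "tokenizer-merges.txt"),
    unk_token="[UNK]")

# Encode a sentence and round-trip it back to text.
encoding = tokenizer.encode("This movie was surprisingly good.")
print(encoding.tokens)
print(encoding.ids)
print(tokenizer.decode(encoding.ids))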
Example #2
def create_tokenizer(data_path, vocab_size):

    tokenizer = CharBPETokenizer()
    # Train on the first 20 "uncased_chunk" files found in data_path.
    chunk_files = [f for f in os.listdir(data_path)
                   if "uncased_chunk" in f][:20]
    tokenizer.train(
        [os.path.join(data_path, f) for f in chunk_files],
        vocab_size=vocab_size,
        min_frequency=2,
        show_progress=True,
        special_tokens=["[CLS]", "[PAD]", "[MASK]", "[UNK]", "[SEP]"])

    print("[CLS]: {}, [PAD]: {}, [MASK]: {}, [UNK]: {}, [SEP]: {}".format(
        str(tokenizer.token_to_id("[CLS]")),
        str(tokenizer.token_to_id("[PAD]")),
        str(tokenizer.token_to_id("[MASK]")),
        str(tokenizer.token_to_id("[UNK]")),
        str(tokenizer.token_to_id("[SEP]"))))

    tokenizer.save(data_path, "tokenizer")
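The two-argument save used in Examples #1 and #2 was split up in later tokenizers releases: newer versions expose save_model(directory) for the vocab/merges files (the call Examples #5 and #6 below use) and save(path) for a single JSON serialization. A minimal sketch under that newer API, assuming a hypothetical corpus.txt and a writable ./bpe_out directory:

import os

from tokenizers import CharBPETokenizer

tokenizer = CharBPETokenizer()
tokenizer.train(["corpus.txt"], vocab_size=5000)  # hypothetical training file

os.makedirs("bpe_out", exist_ok=True)
tokenizer.save_model("bpe_out")           # writes vocab.json / merges.txt
tokenizer.save("bpe_out/tokenizer.json")  # single-file serialization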
Example #3
class BPETokenizer:
    def __init__(self, text_list, vocab_size, lazy=False):
        if not lazy:
            self.tokenizer = CharBPETokenizer()
            self.tokenizer.train(text_list,
                                 vocab_size=vocab_size,
                                 special_tokens=[PAD, BOS, EOS, "<unk>"])
            self.tokenizer.add_special_tokens([PAD, BOS, EOS])
        else:
            self.tokenizer = None

    def tokens_to_ids(self, tokens):
        return [self.tokenizer.token_to_id(t) for t in tokens]

    def ids_to_tokens(self, ids):
        return [self.tokenizer.id_to_token(i) for i in ids]

    def encode(self, text):
        encodes = self.tokenizer.encode(text)
        return encodes.ids

    def decode(self, ids, skip_special=True):
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special)

    def save(self, path, file_name):
        self.tokenizer.save(path, file_name)

    @classmethod
    def load(cls, vocab, merges):
        tkz = cls(None, None, lazy=True)
        tkz.tokenizer = CharBPETokenizer(vocab, merges)
        tkz.tokenizer.add_special_tokens([PAD, BOS, EOS])
        return tkz

    def __len__(self):
        return self.tokenizer.get_vocab_size()
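A minimal usage sketch for the wrapper above, assuming the BPETokenizer class is in scope, that corpus.txt is a hypothetical training file, and that PAD/BOS/EOS are module-level constants defined elsewhere in the original project (the values below are assumptions):

# These constants live elsewhere in the original project; the values here
# are assumptions made so the example is self-contained.
PAD, BOS, EOS = "<pad>", "<bos>", "<eos>"

# Train on a hypothetical text file and round-trip a sentence.
bpe = BPETokenizer(["corpus.txt"], vocab_size=8000)
ids = bpe.encode("hello world")
print(ids)                     # token ids
print(bpe.ids_to_tokens(ids))  # subword strings
print(bpe.decode(ids))         # reconstructed text
print(len(bpe))                # vocabulary size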
Example #4
File: run_glue.py Project: berbuf/asct
def launch(task_params, env_params, model_params,
           optim_params, data_params, trainer_params):
    main_params = {}

    # print params
    if not env_params['distributed'] or env_params['rank'] == 0:
        print('env_params:\t', env_params)
        print('model_params:\t', model_params)
        print('optim_params:\t', optim_params)
        print('data_params:\t', data_params)
        print('trainer_params:\t', trainer_params)

    # computation env
    set_up_env(env_params)
    device = env_params['device']

    logger = Logger()

    for task in task_params:
        print(task)

        task_config = task_params[task]
        model_params["block_size"] = task_config["block_size"]
        trainer_params["batch_size"] = task_config["batch_size"]

        print('task_params:\t', task_config)

        # data
        data_path = data_params["data_path"]
        tokenizer = CharBPETokenizer(join(data_path, "tokenizer-vocab.json"),
                                     join(data_path, "tokenizer-merges.txt"),
                                     unk_token="[UNK]")
        train_data, val_data, num_labels = load_glue(tokenizer, task, task_config)

        # model
        pad_idx = tokenizer.token_to_id("[PAD]")
        model = GenDisc(vocab_size=data_params['vocab_size'],
                        batch_size=trainer_params["batch_size"],
                        model_params=model_params, pad_idx=pad_idx)
        model = model.to(device)

        # optimizer, scheduler, logger and resume from checkpoint
        optim_params = task_config["optim_params"]
        optimizer, scheduler = get_optimizer_and_scheduler(
            model=model, optim_params=optim_params)

        # reload checkpoint
        main_params["iter_init"] = load_checkpoint(
            trainer_params['checkpoint_path'], trainer_params['last_iter'],
            model, optimizer, scheduler,
            logger, parallel=False)

        asct = AsctSequenceClassification(task_config, model_params,
                                          model, num_labels)
        asct = asct.to(device)

        # store main params
        main_params["model"] = asct
        main_params["device"] = device
        main_params["optimizer"] = optimizer
        main_params["scheduler"] = scheduler
        main_params["logger"] = logger

        train_glue(train_data, val_data, main_params, trainer_params,
                   env_params, task_config, task)
        return  # NOTE: exits after processing the first task
Example #5
class EngGerNewstest(Dataset):
    """
    The newstest 2014 dataset used for testing
    """
    def __init__(self,
                 data_folder,
                 rank=0,
                 val_set=False,
                 world_size=1,
                 seed=0,
                 eng_to_ger=True,
                 vocab_size=37000,
                 MASK="<MASK>",
                 START="<START>",
                 STOP="<STOP>",
                 exp_name="",
                 max_context=None,
                 batch_size=128,
                 val_size=30000,
                 **kwargs):
        """
        rank: int
            the rank in the distributed training
        val_set: bool
            if true, this dataset is created as the validation set
        world_size: int
            the number of processes if using distributed training
        seed: int
            random seed
        data_folder: str
            the path to the folder that should contain a `newstest2014.en`
            and a `newstest2014.de` file.
        eng_to_ger: bool
            if true, the x values are returned as English ids and the
            y values are German ids. If false, then vice versa
        vocab_size: int
            the number of encodings for the byte-pair encoding scheme
        MASK: str
            the mask token
        START: str
            the start token
        STOP: str
            the stop token
        exp_name: str
            name of the experiment
        max_context: int
            the maximum sequence length
        val_size: int
            the number of samples to be set aside for validation
        """
        self.rank = rank
        print("rank:", self.rank)
        self.world_size = world_size
        self.val_set = val_set
        self.val_size = val_size
        self.batch_size = batch_size
        self.data_folder = os.path.expanduser(data_folder)
        self.en_path = os.path.join(self.data_folder, "newstest2014.en")
        self.de_path = os.path.join(self.data_folder, "newstest2014.de")
        self.eng_to_ger = eng_to_ger
        self.vocab_size = vocab_size
        self.MASK = MASK
        self.START = START
        self.STOP = STOP
        self.max_context = max_context
        self.en_tok_path = os.path.join(self.data_folder, "en_tokenizer")
        self.de_tok_path = os.path.join(self.data_folder, "de_tokenizer")
        self.en_arr_path = os.path.join(self.data_folder, "en_bcolz")
        self.de_arr_path = os.path.join(self.data_folder, "de_bcolz")
        self.en_lens_path = os.path.join(self.data_folder, "en_bcolz_lens")
        self.de_lens_path = os.path.join(self.data_folder, "de_bcolz_lens")

        # Train tokenizers
        if rank == 0: print("Tokenizing english..")
        self.en_tokenizer = CharBPETokenizer()
        if os.path.exists(self.en_tok_path):  # Load trained tokenizer
            if rank == 0:
                print("loading from pretrained tokenizer", self.en_tok_path)
            self.en_tokenizer = ml_utils.datas.load_tokenizer(
                self.en_tokenizer, self.en_tok_path)
        else:
            self.en_tokenizer.train([self.en_path], vocab_size=vocab_size)
            os.mkdir(self.en_tok_path)
            self.en_tokenizer.save_model(self.en_tok_path)
        self.en_tokenizer.add_special_tokens([self.MASK])
        self.en_tokenizer.add_tokens([self.START])
        self.en_tokenizer.add_tokens([self.STOP])
        self.en_mask_idx = self.en_tokenizer.token_to_id(self.MASK)
        self.en_start_idx = self.en_tokenizer.token_to_id(self.START)
        self.en_stop_idx = self.en_tokenizer.token_to_id(self.STOP)

        if rank == 0: print("Tokenizing german..")
        self.de_tokenizer = CharBPETokenizer()
        if os.path.exists(self.de_tok_path):  # Load trained tokenizer
            if rank == 0:
                print("loading from pretrained tokenizer", self.de_tok_path)
            self.de_tokenizer = ml_utils.datas.load_tokenizer(
                self.de_tokenizer, self.de_tok_path)
        else:
            self.de_tokenizer.train([self.de_path], vocab_size=vocab_size)
            os.mkdir(self.de_tok_path)
            self.de_tokenizer.save_model(self.de_tok_path)
        self.de_tokenizer.add_special_tokens([self.MASK])
        self.de_tokenizer.add_tokens([self.START])
        self.de_tokenizer.add_tokens([self.STOP])
        self.de_mask_idx = self.de_tokenizer.token_to_id(self.MASK)
        self.de_start_idx = self.de_tokenizer.token_to_id(self.START)
        self.de_stop_idx = self.de_tokenizer.token_to_id(self.STOP)

        # Get English sentence lists
        if rank == 0: print("Making english idxs")
        self.en_max_len = 0
        self.en_idxs = []
        self.en_lens = []
        with open(self.en_path, 'r') as f:
            for i, l in tqdm(enumerate(f.readlines())):
                l = l.strip()
                if len(l) > 0:
                    output = self.en_tokenizer.encode(l)
                    ids = [self.en_start_idx]+list(output.ids)\
                                             +[self.en_stop_idx]
                    self.en_idxs.append(ids)
                    self.en_lens.append(len(ids))
                    if len(ids) > self.en_max_len:
                        self.en_max_len = len(ids)
                if exp_name == "test" and i > 100:
                    break
        mask = [self.en_mask_idx for i in range(self.en_max_len)]
        l = 0
        if rank == 0: print("Padding english idxs")
        for i in tqdm(range(len(self.en_idxs))):
            diff = self.en_max_len - len(self.en_idxs[i])
            self.en_idxs[i] = self.en_idxs[i] + mask[:diff]

        # Get German Sentence Lists
        if rank == 0: print("Making german idxs")
        self.de_max_len = 0
        self.de_idxs = []
        self.de_lens = []
        with open(self.de_path, 'r') as f:
            for i, l in tqdm(enumerate(f.readlines())):
                l = l.strip()
                if len(l) > 0:
                    output = self.de_tokenizer.encode(l)
                    ids = [self.de_start_idx]+list(output.ids)\
                                             +[self.de_stop_idx]
                    self.de_idxs.append(ids)
                    self.de_lens.append(len(ids))
                    if len(ids) > self.de_max_len:
                        self.de_max_len = len(ids)
                if exp_name == "test" and i > 100:
                    break
        mask = [self.de_mask_idx for i in range(self.de_max_len)]
        if rank == 0: print("Padding german idxs")
        for i in tqdm(range(len(self.de_idxs))):
            diff = self.de_max_len - len(self.de_idxs[i])
            self.de_idxs[i] = self.de_idxs[i] + mask[:diff]

        if rank == 0: print("Converting to numpy arrays")
        if self.eng_to_ger:
            self.X = np.asarray(self.en_idxs)
            self.X_lens = np.asarray(self.en_lens)
            self.X_tokenizer = self.en_tokenizer
            self.X_mask_idx = self.en_mask_idx
            self.X_start_idx = self.en_start_idx
            self.X_stop_idx = self.en_stop_idx
            self.X_max_len = self.en_max_len

            self.Y = np.asarray(self.de_idxs)
            self.Y_lens = np.asarray(self.de_lens)
            self.Y_tokenizer = self.de_tokenizer
            self.Y_mask_idx = self.de_mask_idx
            self.Y_start_idx = self.de_start_idx
            self.Y_stop_idx = self.de_stop_idx
            self.Y_max_len = self.de_max_len
        else:
            self.X = np.asarray(self.de_idxs)
            self.X_lens = np.asarray(self.de_lens)
            self.X_tokenizer = self.de_tokenizer
            self.X_mask_idx = self.de_mask_idx
            self.X_start_idx = self.de_start_idx
            self.X_stop_idx = self.de_stop_idx
            self.X_max_len = self.de_max_len

            self.Y = np.asarray(self.en_idxs)
            self.Y_lens = np.asarray(self.en_lens)
            self.Y_tokenizer = self.en_tokenizer
            self.Y_mask_idx = self.en_mask_idx
            self.Y_start_idx = self.en_start_idx
            self.Y_stop_idx = self.en_stop_idx
            self.Y_max_len = self.en_max_len

    def __len__(self):
        return len(self.en_idxs)

    #def __getitem__(self,i,l=None):
    #    if l is None:
    #        l = self.X_lens[int(i)]
    #    idxs = np.zeros(1)
    #    margin = 5
    #    while idxs.sum()<25 and margin < 400:
    #        min_l = l-margin
    #        max_l = l+margin
    #        idxs = (self.X_lens>min_l)&(self.X_lens<max_l)
    #        margin += 5
    #    max_l = min(np.max(self.X_lens[idxs]),self.max_context)
    #    if max_l <   50 : batch_size = self.batch_size
    #    elif max_l < 70: batch_size = self.batch_size//2
    #    elif max_l < 100: batch_size = self.batch_size//4
    #    elif max_l < 120: batch_size = self.batch_size//8
    #    elif max_l < 140: batch_size = self.batch_size//16
    #    elif max_l < 160: batch_size = self.batch_size//32
    #    else: batch_size = self.batch_size//64
    #    batch_size = max(16,batch_size)
    #    perm = np.random.permutation(idxs.sum())[:batch_size]
    #    max_l = np.max(self.X_lens[idxs][perm])
    #    x = np.asarray(self.X[idxs][perm,:max_l])
    #    max_l = np.max(self.Y_lens[idxs][perm])
    #    y = np.asarray(self.Y[idxs][perm,:max_l])
    #    return torch.LongTensor(x), torch.LongTensor(y)

    def __getitem__(self, idx):
        return torch.LongTensor(self.X[idx]), torch.LongTensor(self.Y[idx])

    def get_largest_batch(self, size_num):
        l = 10
        if size_num == 1:
            l = 25
        elif size_num == 2:
            l = 400
        elif size_num == 3:
            l = 130
        elif size_num == 4:
            l = 75
        elif size_num == 5:
            l = 44
        elif size_num == 6:
            l = 94
        elif size_num == 7:
            l = 200
        elif size_num == 8:
            l = 300
        # NOTE: relies on the bucketed __getitem__(i, l) commented out above;
        # the active __getitem__ only accepts an index.
        return self.__getitem__(0, l)

    def X_idxs2tokens(self, idxs):
        """
        idxs: LongTensor (N,)
            converts an array of token ids back into a sentence
        """
        return self.X_tokenizer.decode(idxs)

    def Y_idxs2tokens(self, idxs):
        """
        idxs: LongTensor (N,)
            converts an array of token ids back into a sentence
        """
        return self.Y_tokenizer.decode(idxs)
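Because every sequence is padded to the dataset-wide maximum length, __getitem__ returns fixed-size tensors and the class can be fed to a standard DataLoader. A minimal sketch, assuming newstest2014.en / newstest2014.de live under a hypothetical ./data/wmt14 folder and the project's dependencies (torch, numpy, tqdm, tokenizers, ml_utils) are installed:

from torch.utils.data import DataLoader

# Build the test dataset (trains or loads the BPE tokenizers on first use).
dataset = EngGerNewstest("./data/wmt14", eng_to_ger=True, vocab_size=37000)

loader = DataLoader(dataset, batch_size=32, shuffle=False)
for x, y in loader:
    # x: English ids padded to en_max_len, y: German ids padded to de_max_len
    print(x.shape, y.shape)
    break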
Example #6
class EngGerDataset(Dataset):
    """
    Can be English to German or German to English.
    """
    def __init__(self,
                 data_folder,
                 rank=0,
                 val_set=False,
                 world_size=1,
                 seed=0,
                 eng_to_ger=True,
                 vocab_size=37000,
                 MASK="<MASK>",
                 START="<START>",
                 STOP="<STOP>",
                 exp_name="",
                 max_context=None,
                 batch_size=128,
                 val_size=30000,
                 **kwargs):
        """
        rank: int
            the rank in the distributed training
        val_set: bool
            if true, this dataset is created as the validation set
        world_size: int
            the number of processes if using distributed training
        seed: int
            random seed
        data_folder: str
            the path to the folder that should contain a `train.en` and
            a `train.de` file.
        eng_to_ger: bool
            if true, the x values are returned as English ids and the
            y values are German ids. If false, then vice versa
        vocab_size: int
            the number of encodings for the byte-pair encoding scheme
        MASK: str
            the mask token
        START: str
            the start token
        STOP: str
            the stop token
        exp_name: str
            name of the experiment
        max_context: int
            the maximum sequence length
        val_size: int
            the number of samples to be set aside for validation
        """
        self.rank = rank
        print("rank:", self.rank)
        self.world_size = world_size
        self.val_set = val_set
        self.val_size = val_size
        self.batch_size = batch_size
        self.data_folder = os.path.expanduser(data_folder)
        self.en_path = os.path.join(self.data_folder, "train.en")
        self.de_path = os.path.join(self.data_folder, "train.de")
        self.eng_to_ger = eng_to_ger
        self.vocab_size = vocab_size
        self.MASK = MASK
        self.START = START
        self.STOP = STOP
        self.max_context = max_context
        self.en_tok_path = os.path.join(self.data_folder, "en_tokenizer")
        self.de_tok_path = os.path.join(self.data_folder, "de_tokenizer")
        self.en_arr_path = os.path.join(self.data_folder, "en_bcolz")
        self.de_arr_path = os.path.join(self.data_folder, "de_bcolz")
        self.en_lens_path = os.path.join(self.data_folder, "en_bcolz_lens")
        self.de_lens_path = os.path.join(self.data_folder, "de_bcolz_lens")

        # Train tokenizers
        if rank == 0: print("Tokenizing english..")
        self.en_tokenizer = CharBPETokenizer()
        if os.path.exists(self.en_tok_path):  # Load trained tokenizer
            if rank == 0:
                print("loading from pretrained tokenizer", self.en_tok_path)
            self.en_tokenizer = ml_utils.datas.load_tokenizer(
                self.en_tokenizer, self.en_tok_path)
        else:
            self.en_tokenizer.train([self.en_path], vocab_size=vocab_size)
            os.mkdir(self.en_tok_path)
            self.en_tokenizer.save_model(self.en_tok_path)
        self.en_tokenizer.add_special_tokens([self.MASK])
        self.en_tokenizer.add_tokens([self.START])
        self.en_tokenizer.add_tokens([self.STOP])
        self.en_mask_idx = self.en_tokenizer.token_to_id(self.MASK)
        self.en_start_idx = self.en_tokenizer.token_to_id(self.START)
        self.en_stop_idx = self.en_tokenizer.token_to_id(self.STOP)

        if rank == 0: print("Tokenizing german..")
        self.de_tokenizer = CharBPETokenizer()
        if os.path.exists(self.de_tok_path):  # Load trained tokenizer
            if rank == 0:
                print("loading from pretrained tokenizer", self.de_tok_path)
            self.de_tokenizer = ml_utils.datas.load_tokenizer(
                self.de_tokenizer, self.de_tok_path)
        else:
            self.de_tokenizer.train([self.de_path], vocab_size=vocab_size)
            os.mkdir(self.de_tok_path)
            self.de_tokenizer.save_model(self.de_tok_path)
        self.de_tokenizer.add_special_tokens([self.MASK])
        self.de_tokenizer.add_tokens([self.START])
        self.de_tokenizer.add_tokens([self.STOP])
        self.de_mask_idx = self.de_tokenizer.token_to_id(self.MASK)
        self.de_start_idx = self.de_tokenizer.token_to_id(self.START)
        self.de_stop_idx = self.de_tokenizer.token_to_id(self.STOP)

        # Get English sentence lists
        if rank == 0: print("Making english idxs")
        if os.path.exists(self.en_arr_path):
            if rank == 0: print("loading from bcolz", self.en_arr_path)
            self.en_idxs = bcolz.carray(rootdir=self.en_arr_path)
            self.en_lens = bcolz.carray(rootdir=self.en_lens_path)
            self.en_max_len = self.en_idxs.shape[-1]
            if exp_name == "test":
                self.val_size = 250
                self.en_idxs = self.en_idxs[:1000]
                self.en_lens = self.en_lens[:1000]
            if self.world_size > 1:
                with temp_seed(seed - rank):
                    sample_perm = np.random.permutation(len(self.en_idxs))
                if not self.val_set:
                    n_samps = (len(self.en_idxs) - self.val_size)
                    n_samps = n_samps // self.world_size
                    indices = sample_perm[rank * n_samps:(rank + 1) * n_samps]
                else:
                    indices = sample_perm[-self.val_size:]
                try:
                    if rank == 0:
                        print("splitting dataset.. ", end="")
                        starttime = time.time()
                    self.en_idxs = self.en_idxs[indices]
                    self.en_lens = self.en_lens[indices]
                    if rank == 0: print("duration:", time.time() - starttime)
                except:
                    temp_idxs = []
                    temp_lens = []
                    if rank == 0:
                        print("Collecting data")
                        rnge = tqdm(indices)
                    else:
                        rnge = indices
                    for i in rnge:
                        temp_idxs.append(self.en_idxs[i])
                        temp_lens.append(self.en_lens[i])
                    self.en_idxs = np.asarray(temp_idxs)
                    self.en_lens = np.asarray(temp_lens)
                    if rank == 0: print("duration:", time.time() - starttime)
        elif world_size == 1:
            self.en_max_len = 0
            self.en_idxs = []
            self.en_lens = []
            with open(self.en_path, 'r') as f:
                for i, l in tqdm(enumerate(f.readlines())):
                    l = l.strip()
                    if len(l) > 0:
                        output = self.en_tokenizer.encode(l)
                        ids = [self.en_start_idx]+list(output.ids)\
                                                 +[self.en_stop_idx]
                        self.en_idxs.append(ids)
                        self.en_lens.append(len(ids))
                        if len(ids) > self.en_max_len:
                            self.en_max_len = len(ids)
                    if exp_name == "test" and i > 100:
                        break
            mask = [self.en_mask_idx for i in range(self.en_max_len)]
            l = 0
            if rank == 0: print("Padding english idxs")
            for i in tqdm(range(len(self.en_idxs))):
                diff = self.en_max_len - len(self.en_idxs[i])
                self.en_idxs[i] = self.en_idxs[i] + mask[:diff]
            if rank == 0: print("Saving to bcolz")
            self.en_idxs = bcolz.carray(self.en_idxs,
                                        rootdir=self.en_arr_path,
                                        dtype="int32")
            self.en_idxs.flush()
            self.en_lens = bcolz.carray(self.en_lens,
                                        rootdir=self.en_lens_path,
                                        dtype="int32")
            self.en_lens.flush()
        else:
            print("Make dataset without using multi-processing!!")
            assert False
        if max_context is not None and self.en_max_len > max_context:
            if rank == 0:
                print("Truncating context from", self.en_max_len, "to",
                      self.max_context)
            self.en_max_len = self.max_context

        # Get German Sentence Lists
        if rank == 0: print("Making german idxs")
        if os.path.exists(self.de_arr_path):
            if rank == 0: print("loading from bcolz", self.de_arr_path)
            self.de_idxs = bcolz.carray(rootdir=self.de_arr_path)
            self.de_lens = bcolz.carray(rootdir=self.de_lens_path)
            self.de_max_len = self.de_idxs.shape[-1]
            if exp_name == "test":
                self.val_size = 250
                self.de_idxs = self.de_idxs[:1000]
                self.de_lens = self.de_lens[:1000]
            if self.world_size > 1:
                try:
                    if rank == 0:
                        print("splitting dataset.. ", end="")
                        starttime = time.time()
                    self.de_idxs = self.de_idxs[indices]
                    self.de_lens = self.de_lens[indices]
                    if rank == 0: print("duration:", time.time() - starttime)
                except:
                    temp_idxs = []
                    temp_lens = []
                    try:
                        if rank == 0: print("Collecting data")
                        for i in rnge:
                            temp_idxs.append(self.de_idxs[i])
                            temp_lens.append(self.de_lens[i])
                    except Exception as e:
                        print("Likely error caused by bcolz existing "+\
                                               "for en but not de data")
                        print(e)
                        assert False
                    self.de_idxs = np.asarray(temp_idxs)
                    self.de_lens = np.asarray(temp_lens)
                    if rank == 0: print("duration:", time.time() - starttime)
        else:
            self.de_max_len = 0
            self.de_idxs = []
            self.de_lens = []
            with open(self.de_path, 'r') as f:
                for i, l in tqdm(enumerate(f.readlines())):
                    l = l.strip()
                    if len(l) > 0:
                        output = self.de_tokenizer.encode(l)
                        ids = [self.de_start_idx]+list(output.ids)\
                                                 +[self.de_stop_idx]
                        self.de_idxs.append(ids)
                        self.de_lens.append(len(ids))
                        if len(ids) > self.de_max_len:
                            self.de_max_len = len(ids)
                    if exp_name == "test" and i > 100:
                        break
            mask = [self.de_mask_idx for i in range(self.de_max_len)]
            if rank == 0: print("Padding german idxs")
            for i in tqdm(range(len(self.de_idxs))):
                diff = self.de_max_len - len(self.de_idxs[i])
                self.de_idxs[i] = self.de_idxs[i] + mask[:diff]
            if rank == 0: print("Saving to bcolz")
            self.de_idxs = bcolz.carray(self.de_idxs,
                                        rootdir=self.de_arr_path,
                                        dtype="int32")
            self.de_idxs.flush()
            self.de_lens = bcolz.carray(self.de_lens,
                                        rootdir=self.de_lens_path,
                                        dtype="int32")
            self.de_lens.flush()
        if max_context is not None and self.de_max_len > max_context:
            if rank == 0:
                print("Truncating context from", self.de_max_len, "to",
                      self.max_context)
            self.de_max_len = self.max_context

        if rank == 0: print("Converting to numpy arrays")
        if self.eng_to_ger:
            self.X = np.asarray(self.en_idxs)
            self.X_lens = np.asarray(self.en_lens)
            self.X_tokenizer = self.en_tokenizer
            self.X_mask_idx = self.en_mask_idx
            self.X_start_idx = self.en_start_idx
            self.X_stop_idx = self.en_stop_idx
            self.X_max_len = self.en_max_len

            self.Y = np.asarray(self.de_idxs)
            self.Y_lens = np.asarray(self.de_lens)
            self.Y_tokenizer = self.de_tokenizer
            self.Y_mask_idx = self.de_mask_idx
            self.Y_start_idx = self.de_start_idx
            self.Y_stop_idx = self.de_stop_idx
            self.Y_max_len = self.de_max_len
        else:
            self.X = np.asarray(self.de_idxs)
            self.X_lens = np.asarray(self.de_lens)
            self.X_tokenizer = self.de_tokenizer
            self.X_mask_idx = self.de_mask_idx
            self.X_start_idx = self.de_start_idx
            self.X_stop_idx = self.de_stop_idx
            self.X_max_len = self.de_max_len

            self.Y = np.asarray(self.en_idxs)
            self.Y_lens = np.asarray(self.en_lens)
            self.Y_tokenizer = self.en_tokenizer
            self.Y_mask_idx = self.en_mask_idx
            self.Y_start_idx = self.en_start_idx
            self.Y_stop_idx = self.en_stop_idx
            self.Y_max_len = self.en_max_len

    def __len__(self):
        return len(self.en_idxs)

    def __getitem__(self, i, l=None):
        if l is None:
            l = self.X_lens[int(i)]
        idxs = np.zeros(1)
        margin = 5
        while idxs.sum() < 25 and margin < 400:
            min_l = l - margin
            max_l = l + margin
            idxs = (self.X_lens > min_l) & (self.X_lens < max_l)
            margin += 5
        max_l = min(np.max(self.X_lens[idxs]), self.max_context)
        if max_l < 50: batch_size = self.batch_size
        elif max_l < 70: batch_size = self.batch_size // 2
        elif max_l < 100: batch_size = self.batch_size // 4
        elif max_l < 120: batch_size = self.batch_size // 8
        elif max_l < 140: batch_size = self.batch_size // 16
        elif max_l < 160: batch_size = self.batch_size // 32
        else: batch_size = self.batch_size // 64
        batch_size = max(16, batch_size)
        perm = np.random.permutation(idxs.sum())[:batch_size]
        max_l = np.max(self.X_lens[idxs][perm])
        x = np.asarray(self.X[idxs][perm, :max_l])
        max_l = np.max(self.Y_lens[idxs][perm])
        y = np.asarray(self.Y[idxs][perm, :max_l])
        return torch.LongTensor(x), torch.LongTensor(y)

    def get_largest_batch(self, size_num):
        l = 10
        if size_num == 1:
            l = 25
        elif size_num == 2:
            l = 400
        elif size_num == 3:
            l = 130
        elif size_num == 4:
            l = 75
        elif size_num == 5:
            l = 44
        elif size_num == 6:
            l = 94
        elif size_num == 7:
            l = 200
        elif size_num == 8:
            l = 300
        return self.__getitem__(0, l)

    def X_idxs2tokens(self, idxs):
        """
        idxs: LongTensor (N,)
            converts an array of token ids back into a sentence
        """
        return self.X_tokenizer.decode(idxs)

    def Y_idxs2tokens(self, idxs):
        """
        idxs: LongTensor (N,)
            converts an array of token ids back into a sentence
        """
        return self.Y_tokenizer.decode(idxs)
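Unlike the newstest class, __getitem__ here already returns a whole length-bucketed (x, y) batch sampled around the anchor sentence's length, so one plausible training loop draws random anchor indices directly rather than wrapping the dataset in a DataLoader. A minimal sketch, assuming train.en / train.de sit under a hypothetical ./data/wmt14 folder (the bcolz cache is built on the first single-process run):

import numpy as np

# max_context is needed by the length bucketing in __getitem__.
dataset = EngGerDataset("./data/wmt14", eng_to_ger=True,
                        vocab_size=37000, max_context=512, batch_size=128)

for step in range(1000):
    i = np.random.randint(len(dataset))
    x, y = dataset[i]  # a full batch, bucketed by source-sentence length
    print(step, x.shape, y.shape)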