def write_batches_to_tarfiles(
    args,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer,
    decoder_tokenizer,
    num_files_in_tar,
    tar_file_ptr,
    tar_file_ctr,
    global_batch_ctr,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches using a TranslationDataset object.
    (2) Writing each minibatch to a pickle file.
    (3) Adding pickle files to a tarfile until it reaches args.num_batches_per_tarfile.
    """

    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        # Note: args.max_seq_length is reused here as the cap for both the
        # source/target length difference and the length ratio.
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    for _, batch in dataset.batches.items():
        global_batch_ctr += 1
        pickle.dump(batch, open(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)), 'wb'))

        if num_files_in_tar == args.num_batches_per_tarfile:
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_ptr = tarfile.open(
                os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
            )
            num_files_in_tar = 0

        tar_file_ptr.add(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))
        num_files_in_tar += 1
        os.remove(os.path.join(args.out_dir, 'batch-%d.pkl' % (global_batch_ctr)))
    return tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr
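A minimal driver sketch, not from the original script: it assumes an `args` namespace with `out_dir` and `num_batches_per_tarfile`, pre-built `encoder_tokenizer`/`decoder_tokenizer`, and a hypothetical `fragments` list of (src, tgt) file pairs, and shows how the tarfile state returned by write_batches_to_tarfiles is threaded across successive corpus fragments.

import os
import tarfile

num_tokens = 16000  # illustrative tokens-per-batch value
global_batch_ctr, num_files_in_tar, tar_file_ctr = 0, 0, 1
# Open the first tarfile; write_batches_to_tarfiles rotates to a new one every
# args.num_batches_per_tarfile batches.
tar_file_ptr = tarfile.open(
    os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
)
for src_fragment, tgt_fragment in fragments:  # hypothetical [(src, tgt), ...] fragment paths
    tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr = write_batches_to_tarfiles(
        args,
        src_fragment,
        tgt_fragment,
        num_tokens,
        encoder_tokenizer,
        decoder_tokenizer,
        num_files_in_tar,
        tar_file_ptr,
        tar_file_ctr,
        global_batch_ctr,
    )
tar_file_ptr.close()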
Example #2
    def _setup_dataloader_from_config(self, cfg: DictConfig):
        if cfg.get("load_from_cached_dataset", False):
            logging.info('Loading from cached dataset %s' % (cfg.src_file_name))
            if cfg.src_file_name != cfg.tgt_file_name:
                raise ValueError("src must be equal to target for cached dataset")
            dataset = pickle.load(open(cfg.src_file_name, 'rb'))
            dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
        else:
            dataset = TranslationDataset(
                dataset_src=str(Path(cfg.src_file_name).expanduser()),
                dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
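A minimal usage sketch, assuming OmegaConf is available and that `model` is an instance of the class defining this method; the keys mirror the ones read above and the paths and values are placeholders. Note that `batch_size=1` is used because each item of a batchified TranslationDataset is already a full token-bucketed batch.

from omegaconf import OmegaConf

train_ds_cfg = OmegaConf.create(
    {
        'src_file_name': '/path/to/train.src',  # placeholder path
        'tgt_file_name': '/path/to/train.tgt',  # placeholder path
        'tokens_in_batch': 16000,
        'clean': True,
        'shuffle': True,
        'num_workers': 2,
        'load_from_cached_dataset': False,
    }
)
# train_dataloader = model._setup_dataloader_from_config(train_ds_cfg)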
Example #3
    def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
        src_file_name = cfg.get('src_file_name')
        tgt_file_name = cfg.get('tgt_file_name')

        if src_file_name is None or tgt_file_name is None:
            raise ValueError(
                'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
            )
        else:
            # convert src_file_name and tgt_file_name to list of strings
            if isinstance(src_file_name, str):
                src_file_list = [src_file_name]
            elif isinstance(src_file_name, ListConfig):
                src_file_list = src_file_name
            else:
                raise ValueError(
                    "cfg.src_file_name must be string or list of strings")
            if isinstance(tgt_file_name, str):
                tgt_file_list = [tgt_file_name]
            elif isinstance(tgt_file_name, ListConfig):
                tgt_file_list = tgt_file_name
            else:
                raise ValueError(
                    "cfg.tgt_file_name must be string or list of strings")
        if len(src_file_list) != len(tgt_file_list):
            raise ValueError(
                'The same number of filepaths must be passed in for source and target validation.'
            )

        dataloaders = []
        prepend_idx = 0
        for idx, src_file in enumerate(src_file_list):
            if self.multilingual:
                prepend_idx = idx
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
                prepend_id=self.multilingual_ids[prepend_idx] if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

            if cfg.shuffle:
                sampler = pt_data.RandomSampler(dataset)
            else:
                sampler = pt_data.SequentialSampler(dataset)

            dataloader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
            dataloaders.append(dataloader)

        return dataloaders
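A minimal sketch of a validation config that could drive this method, assuming OmegaConf and a hypothetical `model` instance; list values become ListConfig objects, which is what the isinstance checks above expect, and one dataloader is returned per (src, tgt) pair.

from omegaconf import OmegaConf

validation_ds_cfg = OmegaConf.create(
    {
        # Placeholder paths; both lists must have the same length.
        'src_file_name': ['/path/to/val1.src', '/path/to/val2.src'],
        'tgt_file_name': ['/path/to/val1.tgt', '/path/to/val2.tgt'],
        'tokens_in_batch': 8000,
        'shuffle': False,
    }
)
# val_dataloaders = model._setup_eval_dataloader_from_config(validation_ds_cfg)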
Example #4
for num_tokens in tokens_in_batch:
    dataset = TranslationDataset(
        dataset_src=str(Path(args.src_fname).expanduser()),
        dataset_tgt=str(Path(args.tgt_fname).expanduser()),
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    print('Batchifying ...')
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)
    start = time.time()
    pickle.dump(
        dataset,
        open(os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % (num_tokens)), 'wb'),
    )
    end = time.time()
    print('Took %f time to pickle' % (end - start))
    start = time.time()
    dataset = pickle.load(
        open(os.path.join(args.out_dir, 'batches.tokens.%d.pkl' % (num_tokens)), 'rb')
    )
    end = time.time()
    print('Took %f time to unpickle' % (end - start))
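A hypothetical argparse preamble, not part of the original snippet, sketching the attributes the loop above reads from `args` and one way `tokens_in_batch` could be supplied; flag names and defaults are assumptions for illustration.

import argparse

parser = argparse.ArgumentParser(description='Pickle token-bucketed TranslationDataset batches')
parser.add_argument('--src_fname', type=str, required=True)
parser.add_argument('--tgt_fname', type=str, required=True)
parser.add_argument('--out_dir', type=str, required=True)
parser.add_argument('--clean', action='store_true')
parser.add_argument('--max_seq_length', type=int, default=512)
parser.add_argument('--min_seq_length', type=int, default=1)
parser.add_argument('--tokens_in_batch', type=int, nargs='+', default=[8000, 12000, 16000])
args = parser.parse_args()
tokens_in_batch = args.tokens_in_batch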