예제 #1
0
def write_batches_to_tarfiles(
    args,
    src_fname,
    tgt_fname,
    num_tokens,
    encoder_tokenizer,
    decoder_tokenizer,
    num_files_in_tar,
    tar_file_ptr,
    tar_file_ctr,
    global_batch_ctr,
):
    """
    Writes the current fragment of the overall parallel corpus to tarfiles by:
    (1) Creating minibatches using a TranslationDataset object.
    (2) Writing each minibatch to a pickle file ``batch-<n>.pkl`` in args.out_dir.
    (3) Adding each pickle file to the open tarfile and then deleting the loose
        pickle. Once the current tarfile holds args.num_batches_per_tarfile files
        it is closed and ``batches.tokens.<num_tokens>.<tar_file_ctr>.tar`` is
        opened in its place.

    Args:
        args: parsed arguments; reads out_dir, clean, max_seq_length,
            min_seq_length and num_batches_per_tarfile.
        src_fname: path to the source-side text of this fragment.
        tgt_fname: path to the target-side text of this fragment.
        num_tokens: tokens_in_batch for TranslationDataset; also used in the
            tarfile name.
        encoder_tokenizer: tokenizer passed to ``batchify`` for the source side.
        decoder_tokenizer: tokenizer passed to ``batchify`` for the target side.
        num_files_in_tar: number of files already added to tar_file_ptr.
        tar_file_ptr: the currently open, writable tarfile.
        tar_file_ctr: index of the currently open tarfile.
        global_batch_ctr: number of batches written so far across fragments.

    Returns:
        Tuple (tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr)
        holding the updated state, so the caller can continue with the next
        fragment.
    """
    dataset = TranslationDataset(
        dataset_src=src_fname,
        dataset_tgt=tgt_fname,
        tokens_in_batch=num_tokens,
        clean=args.clean,
        max_seq_length=args.max_seq_length,
        min_seq_length=args.min_seq_length,
        # NOTE(review): max_seq_length is also used as the length-difference and
        # length-ratio limits here; confirm this is intended.
        max_seq_length_diff=args.max_seq_length,
        max_seq_length_ratio=args.max_seq_length,
        cache_ids=False,
        cache_data_per_node=False,
        use_cache=False,
    )
    dataset.batchify(encoder_tokenizer, decoder_tokenizer)

    for _, batch in dataset.batches.items():
        global_batch_ctr += 1
        batch_path = os.path.join(args.out_dir, 'batch-%d.pkl' % global_batch_ctr)
        # Close (and thereby flush) the pickle before tarfile reads it back;
        # leaving the handle for garbage collection could archive a truncated file.
        with open(batch_path, 'wb') as batch_file:
            pickle.dump(batch, batch_file)

        if num_files_in_tar == args.num_batches_per_tarfile:
            # Current tarfile is full: roll over to the next one.
            tar_file_ctr += 1
            tar_file_ptr.close()
            tar_file_ptr = tarfile.open(
                os.path.join(args.out_dir, 'batches.tokens.%d.%d.tar' % (num_tokens, tar_file_ctr)), 'w'
            )
            num_files_in_tar = 0

        tar_file_ptr.add(batch_path)
        num_files_in_tar += 1
        # The pickle now lives inside the tarfile; remove the loose copy.
        os.remove(batch_path)
    return tar_file_ptr, global_batch_ctr, num_files_in_tar, tar_file_ctr
예제 #2
0
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     if cfg.get("load_from_cached_dataset", False):
         logging.info('Loading from cached dataset %s' % (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError("src must be equal to target for cached dataset")
         dataset = pickle.load(open(cfg.src_file_name, 'rb'))
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
     elif cfg.get("load_from_tarred_dataset", False):
         logging.info('Loading from tarred dataset %s' % (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError("src must be equal to target for tarred dataset")
         if cfg.get("metadata_path", None) is None:
             raise FileNotFoundError("Could not find metadata path in config")
         dataset = TarredTranslationDataset(
             text_tar_filepaths=cfg.src_file_name,
             metadata_path=cfg.metadata_path,
             encoder_tokenizer=self.encoder_tokenizer,
             decoder_tokenizer=self.decoder_tokenizer,
             shuffle_n=cfg.get("tar_shuffle_n", 100),
             shard_strategy=cfg.get("shard_strategy", "scatter"),
             global_rank=self.global_rank,
             world_size=self.world_size,
             reverse_lang_direction=cfg.get("reverse_lang_direction", False),
         )
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             num_workers=cfg.get("num_workers", 2),
             pin_memory=cfg.get("pin_memory", False),
             drop_last=cfg.get("drop_last", False),
         )
     else:
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction", False),
         )
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )
예제 #3
0
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     if cfg.get("load_from_cached_dataset", False):
         logging.info('Loading from cached dataset %s' %
                      (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError(
                 "src must be equal to target for cached dataset")
         dataset = pickle.load(open(cfg.src_file_name, 'rb'))
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction",
                                                  False)
     else:
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )
예제 #4
0
    def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
        src_file_name = cfg.get('src_file_name')
        tgt_file_name = cfg.get('tgt_file_name')

        if src_file_name is None or tgt_file_name is None:
            raise ValueError(
                'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
            )
        else:
            # convert src_file_name and tgt_file_name to list of strings
            if isinstance(src_file_name, str):
                src_file_list = [src_file_name]
            elif isinstance(src_file_name, ListConfig):
                src_file_list = src_file_name
            else:
                raise ValueError(
                    "cfg.src_file_name must be string or list of strings")
            if isinstance(tgt_file_name, str):
                tgt_file_list = [tgt_file_name]
            elif isinstance(tgt_file_name, ListConfig):
                tgt_file_list = tgt_file_name
            else:
                raise ValueError(
                    "cfg.tgt_file_name must be string or list of strings")
        if len(src_file_list) != len(tgt_file_list):
            raise ValueError(
                'The same number of filepaths must be passed in for source and target validation.'
            )

        dataloaders = []
        prepend_idx = 0
        for idx, src_file in enumerate(src_file_list):
            if self.multilingual:
                prepend_idx = idx
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction",
                                               False),
                prepend_id=self.multilingual_ids[prepend_idx]
                if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

            if cfg.shuffle:
                sampler = pt_data.RandomSampler(dataset)
            else:
                sampler = pt_data.SequentialSampler(dataset)

            dataloader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
            dataloaders.append(dataloader)

        return dataloaders
예제 #5
0
    def _setup_dataloader_from_config(self, cfg: DictConfig):
        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            metadata_file_list = cfg.get('metadata_file')
            tar_files_list = cfg.get('tar_files', None)
            if isinstance(metadata_file_list, str):
                metadata_file_list = [metadata_file_list]
            if tar_files_list is not None and isinstance(tar_files_list, str):
                tar_files_list = [tar_files_list]
            if tar_files_list is not None and len(tar_files_list) != len(
                    metadata_file_list):
                raise ValueError(
                    'The config must have the same number of tarfile paths and metadata file paths.'
                )

            datasets = []
            for idx, metadata_file in enumerate(metadata_file_list):
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if tar_files_list is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                else:
                    tar_files = tar_files_list[idx]
                    if metadata.get('tar_files') is not None:
                        logging.info(
                            f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}'
                        )

                dataset = TarredTranslationDataset(
                    text_tar_filepaths=tar_files,
                    metadata_path=metadata_file,
                    encoder_tokenizer=self.encoder_tokenizer,
                    decoder_tokenizer=self.decoder_tokenizer,
                    shuffle_n=cfg.get("tar_shuffle_n", 100),
                    shard_strategy=cfg.get("shard_strategy", "scatter"),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]
        else:
            src_file_list = cfg.src_file_name
            tgt_file_list = cfg.tgt_file_name
            if isinstance(src_file_list, str):
                src_file_list = [src_file_list]
            if isinstance(tgt_file_list, str):
                tgt_file_list = [tgt_file_list]

            if len(src_file_list) != len(tgt_file_list):
                raise ValueError(
                    'The same number of filepaths must be passed in for source and target.'
                )

            datasets = []
            for idx, src_file in enumerate(src_file_list):
                dataset = TranslationDataset(
                    dataset_src=str(Path(src_file).expanduser()),
                    dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                    tokens_in_batch=cfg.tokens_in_batch,
                    clean=cfg.get("clean", False),
                    max_seq_length=cfg.get("max_seq_length", 512),
                    min_seq_length=cfg.get("min_seq_length", 1),
                    max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                    max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                    cache_ids=cfg.get("cache_ids", False),
                    cache_data_per_node=cfg.get("cache_data_per_node", False),
                    use_cache=cfg.get("use_cache", False),
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                dataset.batchify(self.encoder_tokenizer,
                                 self.decoder_tokenizer)
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    shuffle=cfg.get('shuffle'),
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]

        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=None if cfg.get("use_tarred_dataset", False) else sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
예제 #6
0
                                      tokenizer_model=encoder_tokenizer_model,
                                      bpe_dropout=args.bpe_dropout)

    decoder_tokenizer = get_tokenizer(tokenizer_name='yttm',
                                      tokenizer_model=decoder_tokenizer_model,
                                      bpe_dropout=args.bpe_dropout)

    tokens_in_batch = [int(item) for item in args.tokens_in_batch.split(',')]
    for num_tokens in tokens_in_batch:
        dataset = TranslationDataset(
            dataset_src=str(Path(args.src_fname).expanduser()),
            dataset_tgt=str(Path(args.tgt_fname).expanduser()),
            tokens_in_batch=num_tokens,
            clean=args.clean,
            max_seq_length=args.max_seq_length,
            min_seq_length=args.min_seq_length,
            max_seq_length_diff=args.max_seq_length,
            max_seq_length_ratio=args.max_seq_length,
            cache_ids=False,
            cache_data_per_node=False,
            use_cache=False,
        )
        print('Batchifying ...')
        dataset.batchify(encoder_tokenizer, decoder_tokenizer)
        start = time.time()
        pickle.dump(
            dataset,
            open(
                os.path.join(args.out_dir,
                             'batches.tokens.%d.pkl' % (num_tokens)), 'wb'))
        end = time.time()