Exemple #1
0
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Build the audio dataset for ``split`` and store it in ``self.datasets``.

        The manifest is read from ``<cfg.data>/<split>.tsv``. When
        ``task_cfg.labels`` is set, ``<cfg.data>/<split>.<labels>`` is read as
        well and the audio dataset is wrapped in an ``AddTargetDataset``.

        Args:
            split: split name, used to build the manifest and label paths.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != 'ctc'

        on_tpu = self.cfg.tpu
        audio = FileAudioDataset(
            os.path.join(root, "{}.tsv".format(split)),
            sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(on_tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices or on_tpu),
            **self._get_mask_precompute_kwargs(task_cfg),
        )
        # Register the raw audio dataset right away so it is in place even if
        # label loading below fails.
        self.datasets[split] = audio

        if on_tpu and task_cfg['mask_channel_prob'] == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0.")

        if not task_cfg.labels:
            return

        label_path = os.path.join(data_path := root, f"{split}.{task_cfg.labels}") \
            if False else os.path.join(root, f"{split}.{task_cfg.labels}")
        with open(label_path, "r") as f:
            # Keep only the label lines whose index the audio dataset lists
            # in `line_inds`.
            kept = audio.line_inds
            labels = [line for idx, line in enumerate(f) if idx in kept]

        assert len(labels) == len(audio), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(audio)}) do not match")

        dictionary = self.target_dictionary
        self.datasets[split] = AddTargetDataset(
            audio,
            labels,
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            batch_targets=True,
            process_label=LabelEncoder(dictionary),
            add_to_input=task_cfg.get('autoregressive', False),
        )
Exemple #2
0
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Build the audio dataset for ``split`` and store it in ``self.datasets``.

        A ``BinarizedAudioDataset`` rooted at ``cfg.data`` is built when
        ``task_cfg.binarized_dataset`` is truthy; otherwise a
        ``FileAudioDataset`` is built from ``<cfg.data>/<split>.tsv``.

        Args:
            split: split name to load.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != "ctc"

        # Resolve the configured level name to the TextCompressionLevel member.
        text_compression_level = getattr(TextCompressionLevel,
                                         str(self.cfg.text_compression_level))

        # Options shared by both dataset flavours.
        shared_kwargs = dict(
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices
                                  or self.cfg.tpu),
            **self._get_mask_precompute_kwargs(task_cfg),
        )

        if getattr(task_cfg, "binarized_dataset", False):
            dataset = BinarizedAudioDataset(root, split=split, **shared_kwargs)
        else:
            dataset = FileAudioDataset(
                manifest_path=os.path.join(root, "{}.tsv".format(split)),
                text_compression_level=text_compression_level,
                **shared_kwargs,
            )
        self.datasets[split] = dataset

        zero_channel_mask = (
            self.cfg.tpu
            and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0)
        if zero_channel_mask:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0.")
Exemple #3
0
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Build the audio dataset for ``split`` and store it in ``self.datasets``.

        The manifest is read from ``<cfg.data>/<split>.tsv``. When
        ``task_cfg.labels`` is set, each kept line of
        ``<cfg.data>/<split>.<labels>`` is normalised (whitespace removed,
        lower-cased, ``|`` turned into a space), passed through
        ``self.encode`` and attached as a target via ``AddTargetDataset``.

        Args:
            split: split name, used to build the manifest and label paths.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != 'ctc'

        audio = FileAudioDataset(
            os.path.join(root, "{}.tsv".format(split)),
            sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
        )
        self.datasets[split] = audio

        if not task_cfg.labels:
            return

        label_path = os.path.join(root, f"{split}.{task_cfg.labels}")
        # Label lines at these indices are dropped; an empty set is used when
        # the dataset does not expose `skipped_indices`.
        skipped = getattr(audio, 'skipped_indices', set())

        labels = []
        with open(label_path, "r") as f:
            for idx, raw in enumerate(f):
                if idx in skipped:
                    continue
                # strip all whitespace, lower-case, then map '|' to a space
                text = ''.join(raw.split()).lower().replace('|', ' ')
                labels.append(self.encode(text))

        assert len(labels) == len(audio), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(audio)}) do not match")

        dictionary = self.target_dictionary
        self.datasets[split] = AddTargetDataset(
            audio,
            labels,
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            batch_targets=True,
            process_label=LabelEncoder(dictionary),
            add_to_input=task_cfg.get('autoregressive', False),
        )
Exemple #4
0
    def build_model(self, model_cfg: FairseqDataclass):
        """Build the model via the parent task and sync ``w2v_args`` back.

        If the built model exposes a ``cfg`` attribute that has ``w2v_args``,
        that value is copied onto ``model_cfg``.

        Args:
            model_cfg: model configuration handed to the parent builder;
                may be mutated as described above.

        Returns:
            The model returned by ``super().build_model``.
        """
        model = super().build_model(model_cfg)

        realized_cfg = getattr(model, "cfg", None)
        if realized_cfg is not None and hasattr(realized_cfg, "w2v_args"):
            model_cfg.w2v_args = realized_cfg.w2v_args

        return model
Exemple #5
0
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Build the audio dataset for ``split`` and store it in ``self.datasets``.

        The manifest is read from ``<cfg.data>/<split>.tsv``. When
        ``task_cfg.labels`` is set, ``<cfg.data>/<split>.<labels>`` is read as
        well and the audio dataset is wrapped in an ``AddTargetDataset``.

        Args:
            split: split name, used to build the manifest and label paths.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != 'ctc'

        on_tpu = self.cfg.tpu
        audio = FileAudioDataset(
            os.path.join(root, "{}.tsv".format(split)),
            sample_rate=task_cfg.sample_rate,
            max_sample_size=self.cfg.max_sample_size,
            # NOTE(review): min_sample_size is fed from cfg.max_sample_size
            # while cfg.min_sample_size goes to min_length. Preserved as-is;
            # confirm against FileAudioDataset that this is intended.
            min_sample_size=self.cfg.max_sample_size,
            min_length=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(on_tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices or on_tpu),
            **self._get_mask_precompute_kwargs(task_cfg),
        )
        self.datasets[split] = audio

        if not task_cfg.labels:
            return

        label_path = os.path.join(root, f"{split}.{task_cfg.labels}")
        with open(label_path, "r") as f:
            # Keep only the label lines whose index the audio dataset lists
            # in `line_inds`.
            kept = audio.line_inds
            labels = [line for idx, line in enumerate(f) if idx in kept]

        assert len(labels) == len(audio), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(audio)}) do not match")

        dictionary = self.target_dictionary
        self.datasets[split] = AddTargetDataset(
            audio,
            labels,
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            batch_targets=True,
            process_label=LabelEncoder(dictionary),
            add_to_input=task_cfg.autoregressive,
        )
Exemple #6
0
def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=True):
    """Merge ``cfg`` on top of ``dc`` and keep ``cfg``'s parent link.

    Args:
        dc: dataclass (or mapping-like object with ``keys()``) that provides
            the base values and the set of allowed keys.
        cfg: config whose values are merged over ``dc``. When
            ``remove_missing`` is true it is modified in place: keys that
            ``dc`` does not define are deleted.
        remove_missing: whether to drop keys of ``cfg`` unknown to ``dc``
            before merging.

    Returns:
        The merged config, with struct mode enabled and ``_parent`` copied
        from ``cfg``.
    """
    if remove_missing:
        if is_dataclass(dc):
            allowed = set(dc.__dataclass_fields__.keys())
        else:
            allowed = set(dc.keys())

        with open_dict(cfg):
            # Collect first, then delete, so the config is never mutated
            # while its keys are being walked.
            unknown = [key for key in list(cfg.keys()) if key not in allowed]
            for key in unknown:
                del cfg[key]

    merged_cfg = OmegaConf.merge(dc, cfg)
    # Re-attach the merged node to the parent the input config hung from.
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
    def build_model(self, model_cfg: FairseqDataclass):
        """Build the model and, if configured, the WER evaluation helpers.

        When both ``cfg.eval_wer`` and ``cfg.autoregressive`` are truthy, a
        sequence generator is built for the model and ``self.tokenizer`` is
        set (to ``None`` when ``cfg.eval_wer_tokenizer`` is falsy). If the
        built model exposes a ``cfg`` containing ``w2v_args``, that value is
        copied onto ``model_cfg``.

        Args:
            model_cfg: model configuration handed to the parent builder;
                may be mutated as described above.

        Returns:
            The model returned by ``super().build_model``.
        """
        model = super().build_model(model_cfg)
        cfg = self.cfg

        if cfg.eval_wer and cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                cfg.eval_wer_config,
            )
            self.tokenizer = (
                encoders.build_tokenizer(cfg.eval_wer_tokenizer)
                if cfg.eval_wer_tokenizer else None)

        realized_cfg = getattr(model, "cfg", None)
        if realized_cfg is not None and "w2v_args" in realized_cfg:
            model_cfg.w2v_args = realized_cfg.w2v_args

        return model
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Build the audio dataset for ``split`` and store it in ``self.datasets``.

        The manifest is read from ``<cfg.data>/<split>.tsv``. When
        ``task_cfg.labels`` is set, every line of
        ``<cfg.data>/<split>.<labels>`` is used as a label and the audio
        dataset is wrapped in an ``AddTargetDataset``.

        Args:
            split: split name, used to build the manifest and label paths.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        root = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace) and not hasattr(
                task_cfg, "autoregressive"):
            task_cfg.autoregressive = task_cfg.criterion != 'ctc'

        audio = FileAudioDataset(
            os.path.join(root, "{}.tsv".format(split)),
            sample_rate=task_cfg.sample_rate,
            max_sample_size=self.cfg.max_sample_size,
            # NOTE(review): min_sample_size is fed from cfg.max_sample_size
            # while cfg.min_sample_size goes to min_length. Preserved as-is;
            # confirm against FileAudioDataset that this is intended.
            min_sample_size=self.cfg.max_sample_size,
            min_length=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
        )
        self.datasets[split] = audio

        if not task_cfg.labels:
            return

        label_path = os.path.join(root, f"{split}.{task_cfg.labels}")
        with open(label_path, "r") as f:
            # One label per line, in file order (line endings are kept).
            labels = list(f)

        dictionary = self.target_dictionary
        self.datasets[split] = AddTargetDataset(
            audio,
            labels,
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            batch_targets=True,
            process_label=LabelEncoder(dictionary),
            add_to_input=task_cfg.autoregressive,
        )
Exemple #9
0
    def get_kwargs_from_dc(dataclass_instance: FairseqDataclass,
                           k: str) -> Dict[str, Any]:
        """Translate dataclass field ``k`` into ``add_argument`` keyword args.

        Args:
            dataclass_instance: object queried for the field's type, default,
                help text and argparse const.
            k: name of the dataclass attribute.

        Returns:
            A dict that always contains ``"help"`` and, depending on the
            field, ``"default"``, ``"required"``, ``"choices"``, ``"type"``,
            ``"action"``, ``"const"`` and ``"nargs"``.

        Raises:
            NotImplementedError: for a List/Tuple field whose element type is
                not recognised as int, float or str.
        """
        inter_type = interpret_dc_type(dataclass_instance._get_type(k))
        field_default = dataclass_instance._get_default(k)

        is_enum_class = (isinstance(inter_type, type)
                         and issubclass(inter_type, Enum))
        field_choices = ([member.value for member in list(inter_type)]
                         if is_enum_class else None)

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        kwargs = {}
        if isinstance(field_default, str) and field_default.startswith("${"):
            # Defaults starting with "${" are passed through untouched.
            kwargs["default"] = field_default
        else:
            has_default = field_default is not MISSING
            if not has_default:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices

            # The type is matched both structurally and by its printed name.
            type_repr = str(inter_type)
            is_sequence = (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List)
                     or issubclass(inter_type, Tuple))
            ) or "List" in type_repr or "Tuple" in type_repr

            if is_sequence:
                # Check order matters: the first element-type name found in
                # the printed type wins.
                if "int" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError("parsing of type " +
                                              type_repr +
                                              " is not implemented")
                if has_default:
                    # Sequence defaults are rendered as a comma-joined string.
                    kwargs["default"] = (
                        None if field_default is None else ",".join(
                            map(str, field_default)))
            elif is_enum_class or "Enum" in type_repr:
                kwargs["type"] = str
                if has_default:
                    kwargs["default"] = (field_default.value if isinstance(
                        field_default, Enum) else field_default)
            elif inter_type is bool:
                # A flag flips the default: True -> store_false, else store_true.
                kwargs["action"] = ("store_false"
                                    if field_default is True else "store_true")
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if has_default:
                    kwargs["default"] = field_default

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs
Exemple #10
0
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
) -> None:
    """Add one parser argument per attribute of ``dataclass_instance``.

    Attributes whose type is a ``FairseqDataclass`` subclass are expanded
    recursively into the same parser. Arguments that argparse rejects with
    ``ArgumentError`` are skipped.

    Args:
        parser: parser that receives the arguments.
        dataclass_instance: dataclass instance whose attributes are converted.
        delete_default: when true, no ``default`` is passed to
            ``add_argument`` for any field.
    """
    def argparse_name(name: str):
        """Return the CLI name for ``name``, or ``None`` to skip the field."""
        if name == "data":
            # normally data is positional args
            return name
        if name == "_name":
            # private member, skip
            return None
        return "--" + name.replace("_", "-")

    def get_kwargs_from_dc(dataclass_instance: FairseqDataclass,
                           k: str) -> Dict[str, Any]:
        """Translate dataclass field ``k`` into ``add_argument`` keyword args.

        Raises:
            NotImplementedError: for a List/Tuple field whose element type is
                not recognised as int, float or str.
        """
        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            # Defaults starting with "${" are passed through untouched; the
            # caller decides whether to keep them.
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices
            if (isinstance(inter_type, type) and
                (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
                ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                # Check order matters: the first element-type name found in
                # the printed type wins.
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError("parsing of type " +
                                              str(inter_type) +
                                              " is not implemented")
                if field_default is not MISSING:
                    # Sequence defaults are rendered as a comma-joined string.
                    kwargs["default"] = (",".join(map(str, field_default)) if
                                         field_default is not None else None)
            elif (isinstance(inter_type, type) and issubclass(
                    inter_type, Enum)) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                # A flag flips the default: True -> store_false, else store_true.
                kwargs["action"] = ("store_false"
                                    if field_default is True else "store_true")
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        if inspect.isclass(field_type) and issubclass(field_type,
                                                      FairseqDataclass):
            # Nested dataclass: add its fields to the same parser.
            gen_parser_from_dataclass(parser, field_type(), delete_default)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            default = kwargs["default"]
            is_interpolation = (isinstance(default, str)
                                and default.startswith("${"))
            if is_interpolation and kwargs["help"] is None:
                # this is a field with a name that will be added elsewhere
                continue
            # Remove the default exactly once. Deleting it for the
            # interpolation case and then again for `delete_default` used to
            # raise KeyError when both conditions held.
            if is_interpolation or delete_default:
                del kwargs["default"]
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            # Best effort: arguments argparse refuses (e.g. an option string
            # that is already registered) are skipped.
            pass
Exemple #11
0
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """Add one parser argument per attribute of ``dataclass_instance``.

    Attributes whose type is a ``FairseqDataclass`` subclass are expanded
    recursively into the same parser. When ``with_prefix`` is given, every
    generated option name is prefixed with it, which flattens a structured
    dataclass into a single namespace. Arguments that argparse rejects with
    ``ArgumentError`` are skipped.

    Args:
        parser: parser that receives the arguments.
        dataclass_instance: dataclass instance whose attributes are converted.
        delete_default: when true, no ``default`` is passed to
            ``add_argument`` for any field.
        with_prefix: option-name prefix; ``None`` or ``''`` means no prefix.
    """
    def argparse_name(name: str):
        """Return the CLI name for ``name``, or ``None`` to skip the field."""
        if name == "_name":
            # private member, skip
            return None
        dashed = name.replace("_", "-")
        if with_prefix:
            # prefixed names are composed from the prefix and the bare name
            return with_prefix + "-" + dashed
        if name == "data":
            # normally data is positional args, so no "--" is added
            return name
        return "--" + dashed

    def get_kwargs_from_dc(dataclass_instance: FairseqDataclass,
                           k: str) -> Dict[str, Any]:
        """Translate dataclass field ``k`` into ``add_argument`` keyword args.

        Raises:
            NotImplementedError: for a List/Tuple field whose element type is
                not recognised as int, float or str.
        """
        inter_type = interpret_dc_type(dataclass_instance._get_type(k))
        field_default = dataclass_instance._get_default(k)

        is_enum_class = (isinstance(inter_type, type)
                         and issubclass(inter_type, Enum))
        field_choices = ([member.value for member in list(inter_type)]
                         if is_enum_class else None)

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        kwargs = {}
        if isinstance(field_default, str) and field_default.startswith("${"):
            # Defaults starting with "${" are passed through untouched; the
            # caller decides whether to keep them.
            kwargs["default"] = field_default
        else:
            has_default = field_default is not MISSING
            if not has_default:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices

            # The type is matched both structurally and by its printed name.
            type_repr = str(inter_type)
            is_sequence = (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List)
                     or issubclass(inter_type, Tuple))
            ) or "List" in type_repr or "Tuple" in type_repr

            if is_sequence:
                # Check order matters: the first element-type name found in
                # the printed type wins.
                if "int" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError("parsing of type " +
                                              type_repr +
                                              " is not implemented")
                if has_default:
                    # Sequence defaults are rendered as a comma-joined string.
                    kwargs["default"] = (
                        None if field_default is None else ",".join(
                            map(str, field_default)))
            elif is_enum_class or "Enum" in type_repr:
                kwargs["type"] = str
                if has_default:
                    kwargs["default"] = (field_default.value if isinstance(
                        field_default, Enum) else field_default)
            elif inter_type is bool:
                # A flag flips the default: True -> store_false, else store_true.
                kwargs["action"] = ("store_false"
                                    if field_default is True else "store_true")
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if has_default:
                    kwargs["default"] = field_default

        # build the help with the hierarchical prefix (its first two
        # characters are dropped)
        if with_prefix and field_help is not None:
            field_help = with_prefix[2:] + ': ' + field_help

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue

        if inspect.isclass(field_type) and issubclass(field_type,
                                                      FairseqDataclass):
            # Nested dataclass: recurse. Without a prefix its fields go
            # straight into the root namespace; with one they are prefixed by
            # the current field's name instead.
            nested_prefix = field_name if with_prefix is not None else None
            gen_parser_from_dataclass(parser, field_type(), delete_default,
                                      nested_prefix)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            default = kwargs["default"]
            is_interpolation = (isinstance(default, str)
                                and default.startswith("${"))
            if is_interpolation and kwargs["help"] is None:
                # this is a field with a name that will be added elsewhere
                continue
            if is_interpolation or delete_default:
                del kwargs["default"]
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            pass
    def load_dataset(self,
                     split: str,
                     task_cfg: FairseqDataclass = None,
                     **kwargs):
        """Load ``split`` from every corpus directory under ``self.cfg.data``.

        Each sub-directory of ``cfg.data`` is treated as one corpus. For each
        corpus a ``BinarizedAudioDataset`` (when ``task_cfg.binarized_dataset``
        is truthy) or a ``FileAudioDataset`` built from ``<corpus>/<split>.tsv``
        is created. When ``task_cfg.labels`` is set and the corpus has a
        ``<split>.<labels>`` file, the dataset is wrapped in an
        ``AddTargetDataset``; corpora without that file are left unwrapped.
        The per-corpus sizes are passed to ``self._get_sample_prob``, the
        result is stored in ``self.sample_probs``, and all corpora are
        combined into a ``MultiCorpusSampledDataset`` stored in
        ``self.datasets[split]``.

        Args:
            split: split name to load from every corpus.
            task_cfg: config to read dataset options from; ``self.cfg`` is
                used when this is falsy.
            **kwargs: accepted for interface compatibility; not used here.
        """
        data_path_parent = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # os.listdir returns entries in arbitrary order: sort them so the
        # corpus order (and the index alignment with self.sample_probs) is
        # reproducible, and skip stray files, which cannot be a corpus.
        data_path_list = [
            path for path in (os.path.join(data_path_parent, name)
                              for name in sorted(os.listdir(data_path_parent)))
            if os.path.isdir(path)
        ]

        # upgrade old task: Namespace configs may predate `autoregressive`
        if isinstance(task_cfg, Namespace):
            if not hasattr(task_cfg, "autoregressive"):
                task_cfg.autoregressive = not task_cfg.criterion == "ctc"

        # Resolve the sample rate once so that dataset construction and the
        # duration computation below use the same value.
        sample_rate = task_cfg.get("sample_rate", self.cfg.sample_rate)

        # This check does not depend on the corpus, so warn once rather than
        # once per corpus.
        if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0.")

        dataset_map = OrderedDict()
        datasets_lengths = []
        for data_path in data_path_list:
            # Options shared by both dataset flavours.
            shared_kwargs = dict(
                sample_rate=sample_rate,
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices
                                      or self.cfg.tpu),
                **self._get_mask_precompute_kwargs(task_cfg),
            )
            if getattr(task_cfg, "binarized_dataset", False):
                dataset = BinarizedAudioDataset(data_path,
                                                split=split,
                                                **shared_kwargs)
            else:
                dataset = FileAudioDataset(
                    manifest_path=os.path.join(data_path,
                                               "{}.tsv".format(split)),
                    **shared_kwargs,
                )

            if task_cfg.labels:
                label_path = os.path.join(data_path,
                                          f"{split}.{task_cfg.labels}")
                # A corpus without a label file stays unlabelled.
                if os.path.exists(label_path):
                    skipped_indices = getattr(dataset, "skipped_indices",
                                              set())

                    with open(label_path, "r") as f:
                        labels = [
                            line for i, line in enumerate(f)
                            if i not in skipped_indices
                        ]

                    assert len(labels) == len(dataset), (
                        f"labels length ({len(labels)}) and dataset length "
                        f"({len(dataset)}) do not match")

                    dataset = AddTargetDataset(
                        dataset,
                        labels,
                        pad=self.target_dictionary.pad(),
                        eos=self.target_dictionary.eos(),
                        batch_targets=True,
                        process_label=LabelEncoder(self.target_dictionary),
                        add_to_input=task_cfg.get("autoregressive", False),
                    )

            dataset_map[data_path] = dataset
            # Total size / sample rate / 3600: presumably hours of audio,
            # assuming `sizes` holds per-item sample counts (TODO confirm).
            datasets_lengths.append(sum(dataset.sizes) / sample_rate / 3600)

        datasets_lengths = np.array(datasets_lengths)
        self.sample_probs = self._get_sample_prob(datasets_lengths)
        # Ratio between each corpus' share under sampling and its own size.
        size_ratio = (self.sample_probs *
                      datasets_lengths.sum()) / datasets_lengths
        for idx, data_path in enumerate(data_path_list):
            logger.info(
                "Up/Down Sampling ratio by datasets: {} : {:.2f} to prob:{:.2f}"
                .format(os.path.basename(data_path), size_ratio[idx],
                        self.sample_probs[idx]))

        self.datasets[split] = MultiCorpusSampledDataset(
            dataset_map, sampling_func=self.dataset_sampler)
        logger.info('{} {} examples'.format(split, len(self.datasets[split])))