def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path = self.cfg.data
    task_cfg = task_cfg or self.cfg

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    manifest = os.path.join(data_path, "{}.tsv".format(split))
    self.datasets[split] = FileAudioDataset(
        manifest,
        sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
        max_sample_size=self.cfg.max_sample_size,
        min_sample_size=self.cfg.min_sample_size,
        pad=task_cfg.labels is not None or task_cfg.enable_padding,
        normalize=task_cfg.normalize,
        num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
        compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
        **self._get_mask_precompute_kwargs(task_cfg),
    )

    if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0:
        logger.info(
            "Pretraining on TPUs may suffer convergence "
            "issues when training with `mask_channel_prob` value of "
            "0. You may want to set this to a low value close to 0."
        )

    if task_cfg.labels:
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        with open(label_path, "r") as f:
            labels = [
                line
                for i, line in enumerate(f)
                if i in self.datasets[split].line_inds
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            add_to_input=task_cfg.get("autoregressive", False),
        )

def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path = self.cfg.data
    task_cfg = task_cfg or self.cfg

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    text_compression_level = getattr(
        TextCompressionLevel, str(self.cfg.text_compression_level)
    )
    if getattr(task_cfg, "binarized_dataset", False):
        self.datasets[split] = BinarizedAudioDataset(
            data_path,
            split=split,
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
            **self._get_mask_precompute_kwargs(task_cfg),
        )
    else:
        manifest_path = os.path.join(data_path, "{}.tsv".format(split))

        self.datasets[split] = FileAudioDataset(
            manifest_path=manifest_path,
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
            num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
            compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
            text_compression_level=text_compression_level,
            **self._get_mask_precompute_kwargs(task_cfg),
        )

    if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0:
        logger.info(
            "Pretraining on TPUs may suffer convergence "
            "issues when training with `mask_channel_prob` value of "
            "0. You may want to set this to a low value close to 0."
        )

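# The helper self._get_mask_precompute_kwargs(task_cfg) used above is not shown in
# this listing. Below is a hedged sketch of what it plausibly forwards when mask
# indices are precomputed (the TPU path): the mask hyperparameters of the inferred
# wav2vec 2.0 config. The exact field names and structure are assumptions based on
# the `inferred_w2v_config.mask_channel_prob` access above, not the verified source.
def _get_mask_precompute_kwargs(self, cfg):
    if self.cfg.precompute_mask_indices or self.cfg.tpu:
        # assumed: the task config carries an inferred wav2vec 2.0 sub-config
        assert cfg.inferred_w2v_config is not None, "inferred_w2v_config must be set"
        return {
            "mask_prob": cfg.inferred_w2v_config.mask_prob,
            "mask_length": cfg.inferred_w2v_config.mask_length,
            "mask_channel_prob": cfg.inferred_w2v_config.mask_channel_prob,
            "mask_channel_length": cfg.inferred_w2v_config.mask_channel_length,
        }
    return {}
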
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path = self.cfg.data
    task_cfg = task_cfg or self.cfg

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    manifest = os.path.join(data_path, "{}.tsv".format(split))
    self.datasets[split] = FileAudioDataset(
        manifest,
        sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
        max_sample_size=self.cfg.max_sample_size,
        min_sample_size=self.cfg.min_sample_size,
        pad=task_cfg.labels is not None or task_cfg.enable_padding,
        normalize=task_cfg.normalize,
    )

    if task_cfg.labels:
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
        with open(label_path, "r") as f:
            labels = [
                # convert to lower case and restore word boundaries, then encode using bpe
                self.encode("".join(line.split()).lower().replace("|", " "))
                for i, line in enumerate(f)
                if i not in skipped_indices
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            add_to_input=task_cfg.get("autoregressive", False),
        )

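# Worked example of the label normalization above (before BPE encoding). Given a
# letter-level transcript line such as "H E L L O | W O R L D |":
#   "".join(line.split())  -> "HELLO|WORLD|"
#   .lower()               -> "hello|world|"
#   .replace("|", " ")     -> "hello world "
# i.e. the '|' word boundaries are turned back into spaces and the text is
# lower-cased, so it can be re-tokenized with the task's BPE via self.encode(...).
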
def build_model(self, model_cfg: FairseqDataclass):
    model = super().build_model(model_cfg)

    actualized_cfg = getattr(model, "cfg", None)
    if actualized_cfg is not None:
        # if "w2v_args" in actualized_cfg:
        if hasattr(actualized_cfg, "w2v_args"):
            model_cfg.w2v_args = actualized_cfg.w2v_args

    return model

def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path = self.cfg.data
    task_cfg = task_cfg or self.cfg

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    manifest = os.path.join(data_path, "{}.tsv".format(split))
    self.datasets[split] = FileAudioDataset(
        manifest,
        sample_rate=task_cfg.sample_rate,
        max_sample_size=self.cfg.max_sample_size,
        min_sample_size=self.cfg.max_sample_size,
        min_length=self.cfg.min_sample_size,
        pad=task_cfg.labels is not None or task_cfg.enable_padding,
        normalize=task_cfg.normalize,
        num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
        compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
        **self._get_mask_precompute_kwargs(task_cfg),
    )

    if task_cfg.labels:
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        with open(label_path, "r") as f:
            labels = [
                line
                for i, line in enumerate(f)
                if i in self.datasets[split].line_inds
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            add_to_input=task_cfg.autoregressive,
        )

def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=True):
    if remove_missing:
        if is_dataclass(dc):
            target_keys = set(dc.__dataclass_fields__.keys())
        else:
            target_keys = set(dc.keys())

        with open_dict(cfg):
            for k in list(cfg.keys()):
                if k not in target_keys:
                    del cfg[k]

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg

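# Minimal usage sketch for merge_with_parent, assuming omegaconf is installed.
# "ExampleConfig" is a made-up dataclass standing in for a real FairseqDataclass
# subclass; the point is that keys absent from the current dataclass (e.g. stale
# options saved in an old checkpoint) are dropped before the struct-mode merge.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class ExampleConfig:
    sample_rate: int = 16000
    normalize: bool = False


ckpt_cfg = OmegaConf.create({"sample_rate": 8000, "stale_option": 1})
merged = merge_with_parent(ExampleConfig(), ckpt_cfg)
assert merged.sample_rate == 8000  # value taken from the checkpoint config
assert merged.normalize is False   # filled in from the dataclass default
# "stale_option" was removed because remove_missing defaults to True.
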
def build_model(self, model_cfg: FairseqDataclass):
    model = super().build_model(model_cfg)

    if self.cfg.eval_wer and self.cfg.autoregressive:
        self.sequence_generator = self.build_generator(
            [model],
            self.cfg.eval_wer_config,
        )
        if self.cfg.eval_wer_tokenizer:
            self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
        else:
            self.tokenizer = None

    actualized_cfg = getattr(model, "cfg", None)
    if actualized_cfg is not None:
        if "w2v_args" in actualized_cfg:
            model_cfg.w2v_args = actualized_cfg.w2v_args

    return model

def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path = self.cfg.data
    task_cfg = task_cfg or self.cfg

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    manifest = os.path.join(data_path, "{}.tsv".format(split))
    self.datasets[split] = FileAudioDataset(
        manifest,
        sample_rate=task_cfg.sample_rate,
        max_sample_size=self.cfg.max_sample_size,
        min_sample_size=self.cfg.max_sample_size,
        min_length=self.cfg.min_sample_size,
        pad=task_cfg.labels is not None or task_cfg.enable_padding,
        normalize=task_cfg.normalize,
    )

    if task_cfg.labels:
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        labels = []
        with open(label_path, "r") as f:
            for line in f:
                labels.append(line)

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            add_to_input=task_cfg.autoregressive,
        )

def get_kwargs_from_dc(dataclass_instance: FairseqDataclass, k: str) -> Dict[str, Any]:
    """k: dataclass attributes"""
    kwargs = {}

    field_type = dataclass_instance._get_type(k)
    inter_type = interpret_dc_type(field_type)

    field_default = dataclass_instance._get_default(k)

    if isinstance(inter_type, type) and issubclass(inter_type, Enum):
        field_choices = [t.value for t in list(inter_type)]
    else:
        field_choices = None

    field_help = dataclass_instance._get_help(k)
    field_const = dataclass_instance._get_argparse_const(k)

    if isinstance(field_default, str) and field_default.startswith("${"):
        kwargs["default"] = field_default
    else:
        if field_default is MISSING:
            kwargs["required"] = True
        if field_choices is not None:
            kwargs["choices"] = field_choices

        if (
            isinstance(inter_type, type)
            and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
        ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
            if "int" in str(inter_type):
                kwargs["type"] = lambda x: eval_str_list(x, int)
            elif "float" in str(inter_type):
                kwargs["type"] = lambda x: eval_str_list(x, float)
            elif "str" in str(inter_type):
                kwargs["type"] = lambda x: eval_str_list(x, str)
            else:
                raise NotImplementedError(
                    "parsing of type " + str(inter_type) + " is not implemented"
                )
            if field_default is not MISSING:
                kwargs["default"] = (
                    ",".join(map(str, field_default))
                    if field_default is not None
                    else None
                )
        elif (
            isinstance(inter_type, type) and issubclass(inter_type, Enum)
        ) or "Enum" in str(inter_type):
            kwargs["type"] = str
            if field_default is not MISSING:
                if isinstance(field_default, Enum):
                    kwargs["default"] = field_default.value
                else:
                    kwargs["default"] = field_default
        elif inter_type is bool:
            kwargs["action"] = "store_false" if field_default is True else "store_true"
            kwargs["default"] = field_default
        else:
            kwargs["type"] = inter_type
            if field_default is not MISSING:
                kwargs["default"] = field_default

    kwargs["help"] = field_help
    if field_const is not None:
        kwargs["const"] = field_const
        kwargs["nargs"] = "?"

    return kwargs

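# Illustrative expectations for get_kwargs_from_dc (the field definitions here are
# hypothetical, but the mappings follow directly from the branches above):
#   bool field, default False           -> {"action": "store_true", "default": False, "help": ...}
#   bool field, default True            -> {"action": "store_false", "default": True, "help": ...}
#   List[int] field, default (3, 3)     -> {"type": <comma-list parser>, "default": "3,3", "help": ...}
#   Enum field                          -> {"type": str, "choices": [...], "default": <enum value>, "help": ...}
#   default of the form "${other.key}"  -> {"default": "${other.key}", "help": ...} (interpolation left to OmegaConf)
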
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
) -> None:
    """convert a dataclass instance to trailing parser arguments"""

    def argparse_name(name: str):
        if name == "data":
            # normally data is positional args
            return name
        if name == "_name":
            # private member, skip
            return None
        return "--" + name.replace("_", "-")

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """k: dataclass attributes"""
        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices

            if (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if field_default is not MISSING:
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            gen_parser_from_dataclass(parser, field_type(), delete_default)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            if delete_default:
                del kwargs["default"]

        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            pass

def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """
    convert a dataclass instance to trailing parser arguments.

    If `with_prefix` is provided, prefix all the keys in the resulting parser
    with it. It means that we are building a flat namespace from a structured
    dataclass (see transformer_config.py for example).
    """

    def argparse_name(name: str):
        if name == "data" and (with_prefix is None or with_prefix == ""):
            # normally data is positional args, so we don't add the -- nor the prefix
            return name
        if name == "_name":
            # private member, skip
            return None
        full_name = "--" + name.replace("_", "-")
        if with_prefix is not None and with_prefix != "":
            # if a prefix is specified, construct the prefixed arg name
            full_name = with_prefix + "-" + full_name[2:]  # strip -- when composing
        return full_name

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """k: dataclass attributes"""
        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices

            if (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if field_default is not MISSING:
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        # build the help with the hierarchical prefix
        if with_prefix is not None and with_prefix != "" and field_help is not None:
            field_help = with_prefix[2:] + ": " + field_help

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            # for fields that are of type FairseqDataclass, we can recursively
            # add their fields to the namespace (so we add the args from model,
            # task, etc. to the root namespace)
            prefix = None
            if with_prefix is not None:
                # if a prefix is specified, then we don't want to copy the
                # subfields directly to the root namespace, but we prefix them
                # with the name of the current field.
                prefix = field_name
            gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            if delete_default and "default" in kwargs:
                del kwargs["default"]

        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            pass

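# Hedged usage sketch for the prefixed variant. "ExampleEncoderConfig" is an
# illustrative dataclass, not one of fairseq's real configs; it only assumes the
# FairseqDataclass base class providing the _get_* accessors used above.
from argparse import ArgumentParser
from dataclasses import dataclass, field


@dataclass
class ExampleEncoderConfig(FairseqDataclass):
    embed_dim: int = field(default=512, metadata={"help": "embedding dimension"})


parser = ArgumentParser()
gen_parser_from_dataclass(parser, ExampleEncoderConfig(), with_prefix="--encoder")
args = parser.parse_args(["--encoder-embed-dim", "1024"])
assert args.encoder_embed_dim == 1024
# The flag is registered as --encoder-embed-dim and its help string becomes
# "encoder: embedding dimension", flattening the nested config into a flat
# argparse namespace.
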
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
    data_path_parent = self.cfg.data
    task_cfg = task_cfg or self.cfg
    data_path_list = [
        os.path.join(data_path_parent, path)
        for path in os.listdir(data_path_parent)
    ]

    # upgrade old task
    if isinstance(task_cfg, Namespace):
        if not hasattr(task_cfg, "autoregressive"):
            task_cfg.autoregressive = not task_cfg.criterion == "ctc"

    dataset_map = OrderedDict()
    datasets_lengths = []
    for data_path in data_path_list:
        if getattr(task_cfg, "binarized_dataset", False):
            dataset_map[data_path] = BinarizedAudioDataset(
                data_path,
                split=split,
                sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
                **self._get_mask_precompute_kwargs(task_cfg),
            )
        else:
            manifest_path = os.path.join(data_path, "{}.tsv".format(split))
            dataset_map[data_path] = FileAudioDataset(
                manifest_path=manifest_path,
                sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
                **self._get_mask_precompute_kwargs(task_cfg),
            )

        if self.cfg.tpu and task_cfg["mask_channel_prob"] == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0."
            )

        if task_cfg.labels:
            label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
            if os.path.exists(label_path):
                skipped_indices = getattr(
                    dataset_map[data_path], "skipped_indices", set()
                )
                with open(label_path, "r") as f:
                    labels = [
                        line
                        for i, line in enumerate(f)
                        if i not in skipped_indices
                    ]

                assert len(labels) == len(dataset_map[data_path]), (
                    f"labels length ({len(labels)}) and dataset length "
                    f"({len(dataset_map[data_path])}) do not match"
                )

                process_label = LabelEncoder(self.target_dictionary)

                dataset_map[data_path] = AddTargetDataset(
                    dataset_map[data_path],
                    labels,
                    pad=self.target_dictionary.pad(),
                    eos=self.target_dictionary.eos(),
                    batch_targets=True,
                    process_label=process_label,
                    add_to_input=task_cfg.get("autoregressive", False),
                )

        datasets_lengths.append(
            sum(dataset_map[data_path].sizes) / task_cfg.sample_rate / 3600
        )

    datasets_lengths = np.array(datasets_lengths)
    self.sample_probs = self._get_sample_prob(datasets_lengths)
    size_ratio = (self.sample_probs * datasets_lengths.sum()) / datasets_lengths
    for id, data_path in enumerate(data_path_list):
        logger.info(
            "Up/Down Sampling ratio by datasets: {} : {:.2f} to prob:{:.2f}".format(
                data_path.split("/")[-1], size_ratio[id], self.sample_probs[id]
            )
        )

    self.datasets[split] = MultiCorpusSampledDataset(
        dataset_map, sampling_func=self.dataset_sampler
    )
    logger.info("{} {} examples".format(split, len(self.datasets[split])))

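# The multi-corpus path above relies on two helpers not shown in this listing:
# self._get_sample_prob and self.dataset_sampler. Below is a hedged sketch of a
# temperature-smoothed _get_sample_prob; the config field name "sampling_alpha"
# and the exact formula are assumptions, not the verified source.
def _get_sample_prob(self, dataset_lens):
    # Normalize per-corpus durations (in hours) into probabilities, then smooth
    # with an exponent in (0, 1] so that small corpora are up-sampled relative
    # to their raw share of the data.
    prob = dataset_lens / dataset_lens.sum()
    smoothed = prob ** self.cfg.sampling_alpha
    return smoothed / smoothed.sum()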