def _pre_load_args(args):
    """Pre-loads arguments from YAML config files, a predefined hyper-parameter
    set and (when available) saved model configurations.

    Merge precedence, lowest to highest: saved model configs, the
    hyper-parameter set, then the config-file arguments.

    Args:
        args: The parsed program arguments (an argparse Namespace-like object).

    Returns:
        A dict of merged arguments.
    """
    config_paths = flatten_string_list(
        getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))
    cfg_file_args = yaml_load_checking(load_from_config_path(config_paths))
    model_dirs = flatten_string_list(
        args.model_dir or cfg_file_args.get("model_dir", None))
    # Fall back to the config file only when the CLI value is exactly None.
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    predefined = get_hyper_parameters(hparams_set)
    # Lift the model-related entries out of the hyper-parameter set; whatever
    # remains is treated as entry-point parameters.
    formatted = {}
    for key in ("model.class", "model", "model.params"):
        if key in predefined:
            formatted[key] = predefined.pop(key)
    if predefined:
        formatted["entry.params"] = predefined
    # Best effort: the model dir may be absent/empty, in which case we merge
    # without the saved model configs.
    try:
        model_cfgs = ModelConfigs.load(model_dirs[0])
        return deep_merge_dict(
            deep_merge_dict(model_cfgs, formatted), cfg_file_args)
    except Exception:
        return deep_merge_dict(formatted, cfg_file_args)
def _flatten_args(flag_list, from_args, backend="tf"):
    """Flattens nested module parameters in `from_args` into a single-level dict.

    For each ModuleFlag, the class selector is normalized onto `cls_key`, and
    known sub-flags buried under the module's `params_key` dict are lifted to
    the top level. `from_args` is not mutated (a deep copy is taken).

    Args:
        flag_list: A list of Flag/ModuleFlag definitions.
        from_args: A dict of (possibly nested) arguments.
        backend: The DL backend key used to look up REGISTRIES.

    Returns:
        A flat dict, with flattened entries taking precedence over leftovers.
    """
    args = copy.deepcopy(from_args)
    flattened_args = {}
    # Pass 1: lift plain flags that already sit at the top level.
    for f in flag_list:
        if isinstance(f, Flag) and f.name in args:
            flattened_args[f.name] = args.pop(f.name)
    # Pass 2: normalize module flags and redistribute their params.
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            # Prefer the cls_key spelling; fall back to the bare name.
            if f.cls_key in args:
                flattened_args[f.cls_key] = args.pop(f.cls_key)
                args.pop(f.name, None)
            elif f.name in args:
                flattened_args[f.cls_key] = args.pop(f.name)
            if f.cls_key in flattened_args and flattened_args[
                    f.cls_key] and args.get(f.params_key, None):
                if hasattr(
                        REGISTRIES[backend][f.module_name][flattened_args[
                            f.cls_key]], "class_or_method_args"):
                    # The selected class declares its own flags: lift each
                    # recognized entry out of the nested params dict.
                    for ff in REGISTRIES[backend][f.module_name][
                            flattened_args[f.cls_key]].class_or_method_args():
                        if isinstance(ff, Flag):
                            # Plain flag: lift only if not already set.
                            if ff.name in args[
                                    f.params_key] and ff.name not in flattened_args:
                                flattened_args[ff.name] = args[
                                    f.params_key].pop(ff.name)
                        else:
                            # Nested ModuleFlag: normalize its class selector
                            # from top-level args first, then from the params.
                            if ff.cls_key in args:
                                flattened_args[ff.cls_key] = args.pop(
                                    ff.cls_key)
                                args.pop(ff.name, None)
                            elif ff.name in args:
                                flattened_args[ff.cls_key] = args.pop(ff.name)
                            elif ff.cls_key in args[f.params_key]:
                                flattened_args[ff.cls_key] = args[
                                    f.params_key].pop(ff.cls_key)
                            elif ff.name in args[f.params_key]:
                                flattened_args[ff.cls_key] = args[
                                    f.params_key].pop(ff.name)
                            # Merge the nested module's params (nested dict
                            # wins over anything already flattened).
                            if ff.params_key in args[f.params_key]:
                                flattened_args[
                                    ff.params_key] = deep_merge_dict(
                                        args[f.params_key][ff.params_key],
                                        flattened_args.get(ff.params_key, {}))
                else:
                    # Class declares no flags: move the params dict wholesale.
                    # NOTE(review): the second pop is a no-op after the first;
                    # looks like leftover code — confirm against upstream.
                    flattened_args[f.params_key] = args.pop(f.params_key)
                    args.pop(f.params_key, None)
    # Anything not recognized above survives via this final merge.
    return deep_merge_dict(flattened_args, args)
def parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                args_preload_func=_args_preload_from_config_files):
    """ Parses flags from argument parser.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable function for pre-loading arguments,
            maybe from config file, hyper parameter set.

    Returns:
        A tuple of (parsed argument dict, remaining unparsed argv).
    """
    cli_args, remaining_argv = arg_parser.parse_known_args()
    preloaded = {} if args_preload_func is None else args_preload_func(cli_args)
    cli_args = yaml_load_checking(cli_args.__dict__)
    parsed = {}
    for flag in flag_list:
        if isinstance(flag, ModuleFlag):
            key = flag.cls_key
            # Collect the module's params: CLI value first, then overlay the
            # preloaded (config-file) params underneath.
            params = {}
            if cli_args.get(flag.params_key, None) is not None:
                params = cli_args[flag.params_key]
            if flag.params_key in preloaded:
                params = deep_merge_dict(preloaded[flag.params_key], params)
            parsed[flag.params_key] = params
        else:
            key = flag.name
        # Resolution order for the flag value itself:
        # explicit CLI > preloaded config > declared default.
        if cli_args.get(key, None) is not None:
            parsed[key] = cli_args[key]
        elif key in preloaded:
            parsed[key] = preloaded[key]
        else:
            parsed[key] = flag.default
    return parsed, remaining_argv
def extend_define_and_parse(flag_name, args, remaining_argv, backend="tf"):
    """Defines and parses the flags declared by second-level submodules.

    For the ModuleFlag registered under `flag_name`, inspects the selected
    class's own ModuleFlags, defines their sub-flags on a fresh parser,
    parses them from `remaining_argv`, and folds the results into
    `args[f.params_key][ff.params_key]`.

    Args:
        flag_name: The name of an already-defined top-level ModuleFlag.
        args: A dict of arguments (mutated in place).
        remaining_argv: Unparsed command-line arguments.
        backend: The DL backend key used to look up REGISTRIES.

    Returns:
        A tuple of (args, remaining_argv).
    """
    f = _DEFINED_FLAGS.get(flag_name, None)
    # FIX: the early exits previously returned the bare `args`, while the
    # normal path returns (args, remaining_argv); callers that unpack the
    # tuple would crash on these paths. Return a consistent tuple.
    if f is None or not isinstance(f, ModuleFlag):
        return args, remaining_argv
    if not hasattr(REGISTRIES[backend][f.module_name][args[f.cls_key]],
                   "class_or_method_args"):
        return args, remaining_argv
    # First sweep: define the sub-flags of every selected nested module.
    arg_parser = argparse.ArgumentParser()
    for ff in REGISTRIES[backend][f.module_name][
            args[f.cls_key]].class_or_method_args():
        if isinstance(ff, ModuleFlag):
            if args[f.params_key].get(ff.cls_key, None):
                this_cls = REGISTRIES[backend][ff.module_name][
                    args[f.params_key][ff.cls_key]]
                if hasattr(this_cls, "class_or_method_args"):
                    for fff in this_cls.class_or_method_args():
                        fff.define(arg_parser)
    parsed_args, remaining_argv = arg_parser.parse_known_args(remaining_argv)
    parsed_args = yaml_load_checking(parsed_args.__dict__)
    # Second sweep: route each parsed value into the nested params dict.
    for ff in REGISTRIES[backend][f.module_name][
            args[f.cls_key]].class_or_method_args():
        if isinstance(ff, ModuleFlag):
            if args[f.params_key].get(ff.cls_key, None):
                this_cls = REGISTRIES[backend][ff.module_name][
                    args[f.params_key][ff.cls_key]]
                if hasattr(this_cls, "class_or_method_args"):
                    if args[f.params_key].get(ff.params_key, None) is None:
                        args[f.params_key][ff.params_key] = {}
                    for fff in this_cls.class_or_method_args():
                        flag_key = fff.name
                        if isinstance(fff, ModuleFlag):
                            flag_key = fff.cls_key
                        # Priority: freshly parsed CLI value > top-level args
                        # (cls_key spelling) > top-level args (name spelling)
                        # > value already nested under the params dict.
                        if parsed_args[flag_key] is not None:
                            args[f.params_key][ff.params_key][
                                flag_key] = parsed_args[flag_key]
                            args.pop(flag_key, None)
                            args.pop(fff.name, None)
                        elif flag_key in args:
                            args[f.params_key][ff.params_key][
                                flag_key] = args.pop(flag_key)
                            args.pop(fff.name, None)
                        elif fff.name in args:
                            args[f.params_key][ff.params_key][
                                flag_key] = args.pop(fff.name)
                        elif fff.name in args[f.params_key][ff.params_key]:
                            # FIX: was `flag_name not in ...`, which tested the
                            # outer function's parameter instead of the local
                            # key being normalized (name -> cls_key).
                            if flag_key not in args[f.params_key][ff.params_key]:
                                args[f.params_key][ff.params_key][
                                    flag_key] = args[f.params_key][
                                        ff.params_key].pop(fff.name)
                        if isinstance(fff, ModuleFlag):
                            # Merge params from all three sources; the parsed
                            # CLI dict has the highest precedence.
                            args[f.params_key][ff.params_key][
                                fff.params_key] = deep_merge_dict(
                                    args[f.params_key][ff.params_key].get(
                                        fff.params_key, {}) or {},
                                    deep_merge_dict(
                                        args.get(fff.params_key, {}) or {},
                                        parsed_args.get(fff.params_key, {})
                                        or {}))
    return args, remaining_argv
def get_data_preprocess_fn(self, mode, data_status=compat.DataStatus.RAW, args=None) -> callable:
    """ Preprocess data sample according to this task.

    Args:
        args: A dict containing dataset arguments.
        mode: A ModeKeys indicating the running mode.
        data_status: The status of the data sample.

    Returns:
        A callable function to collate (process) a data sample.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    truncate_src = args.get("truncate_src", None)
    truncate_trg = args.get("truncate_trg", None)
    max_src_len = args.get("max_src_len", None)
    max_trg_len = args.get("max_trg_len", None)

    def _process_and_truncate(text, dp, trunc, max_len):
        # Run the data pipeline unless the sample is already projected to ids.
        if data_status != compat.DataStatus.PROJECTED:
            text = dp.process(
                text,
                is_processed=(data_status == compat.DataStatus.PROCESSED))
        # Truncation applies only during training; the last token (presumably
        # EOS — TODO confirm) is preserved.
        if mode == compat.ModeKeys.TRAIN and trunc and max_len:
            if isinstance(text, tf.Tensor):
                text = tf.cond(
                    tf.less_equal(tf.size(text), max_len), lambda: text,
                    lambda: tf.concat([text[:(max_len - 1)], text[-1:]],
                                      axis=0))
            elif len(text) > max_len:
                text = text[:(max_len - 1)] + text[-1:]
        return text

    # Inference samples carry no label.
    if mode == compat.ModeKeys.INFER:
        return lambda data: {
            "feature": _process_and_truncate(data["feature"],
                                             self._src_data_pipeline,
                                             truncate_src, max_src_len)
        }
    return lambda data: {
        "feature": _process_and_truncate(data["feature"],
                                         self._src_data_pipeline,
                                         truncate_src, max_src_len),
        "label": _process_and_truncate(data["label"],
                                       self._trg_data_pipeline,
                                       truncate_trg, max_trg_len)
    }
def build_x(args, *extra_args, **kwargs):
    """Builds a registered object from `args`.

    NOTE: this is a closure — `registry_name`, `backend`, `create_fn`,
    `REGISTRIES` and `_verbose_creation` come from the enclosing scope
    (not visible here).

    Args:
        args: Either a dict carrying "class"/"params" (or the
            "<registry>.class"/"<registry>.params" spellings), a class name
            string, or a callable/class itself.
        *extra_args: Positional arguments forwarded to the builder.
        **kwargs: Keyword arguments; recognized flag names are folded into
            the params dict.

    Returns:
        The created object, or None if no class is specified.
    """
    params_ = {}
    if isinstance(args, dict):
        # Accept "class", "<registry>.class" or "<registry>" as the selector.
        cls_ = (args.get("class", None)
                or args.get(f"{registry_name}.class", None)
                or args.get(f"{registry_name}", None))
        params_ = (args.get("params", None) or args.get(
            "{}.params".format(registry_name), {})) or {}
    else:
        cls_ = args
    if cls_ is None:
        return None
    if isinstance(cls_, str):
        # The literal string "none" (any case) means "no module".
        if cls_.lower() == "none":
            return None
        if cls_ not in REGISTRIES[backend][registry_name]:
            raise ValueError("Not registered class name: {}.".format(cls_))
        cls_ = REGISTRIES[backend][registry_name][cls_]
        builder = cls_
    elif callable(cls_):
        builder = cls_
    else:
        raise ValueError("Not supported type: {} for builder.".format(
            type(cls_)))
    # Optionally build via a named factory method instead of the class itself.
    if create_fn is not None:
        assert hasattr(builder, create_fn), "{} has no {} for creation.".format(
            cls_, create_fn)
        builder = getattr(builder, create_fn)
    # NOTE(review): **kwargs can never be None — this check looks like dead
    # code; confirm before removing.
    if kwargs is None:
        kwargs = {}
    assert isinstance(
        params_, dict), f"Not supported type: {type(params_)} for params"
    if hasattr(cls_, "class_or_method_args"):
        # Flag-aware class: fill params from declared flags/defaults and call
        # the builder with the params dict as the first argument.
        for f in cls_.class_or_method_args():
            if isinstance(f, Flag):
                if f.name in kwargs:
                    params_[f.name] = kwargs.pop(f.name)
                elif f.name not in params_:
                    params_[f.name] = f.default
            elif isinstance(f, ModuleFlag) and f.cls_key not in params_:
                params_[f.cls_key] = f.default
            if isinstance(f, ModuleFlag) and f.params_key not in params_:
                params_[f.params_key] = {}
        _verbose_creation(cls_, params_, *extra_args, **kwargs)
        return builder(params_, *extra_args, **kwargs)
    # Plain class/callable: params are expanded as keyword arguments.
    params_ = deep_merge_dict(params_, kwargs, merge_only_exist=False)
    _verbose_creation(cls_, {}, *extra_args, **params_)
    return builder(*extra_args, **params_)
def get_data_preprocess_fn(self, mode, data_status=compat.DataStatus.RAW, args=None) -> callable:
    """ Preprocess data sample according to this task.

    Args:
        args: A dict containing dataset arguments.
        mode: A ModeKeys indicating the running mode.
        data_status: The status of the data sample.

    Returns:
        A callable function to collate (process) a data sample.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    truncate = args.get("truncate", None)
    max_len = args.get("max_len", None)

    def _process_and_truncate(data):
        # Expects a dict with a "tokens" entry.
        text = data["tokens"]
        if data_status != compat.DataStatus.PROJECTED:
            text = self._data_pipeline.process(
                text,
                is_processed=(data_status == compat.DataStatus.PROCESSED))
        # Truncate only during training, keeping the final token (presumably
        # EOS — TODO confirm).
        if mode == compat.ModeKeys.TRAIN and truncate and max_len:
            if compat.is_tf_tensor(text):
                text = tf.cond(
                    tf.less_equal(tf.size(text), max_len), lambda: text,
                    lambda: tf.concat([text[:(max_len - 1)], text[-1:]],
                                      axis=0))
            elif len(text) > max_len:
                text = text[:(max_len - 1)] + text[-1:]
        return {"tokens": text}

    return _process_and_truncate
def create_and_batch_tfds(self, ds: Dataset, mode, args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates a dataset according to the `mode`.

    Args:
        args: A dict containing dataset arguments.
        ds: A neurst.data.datasets.Dataset object.
        mode: A ModeKeys indicating the running mode.
        num_replicas_in_sync: The number of GPUs or other workers. We will
            generate global batches, and each global batch is equally
            divisible by number of replicas.

    Returns:
        A tf.data.Dataset or a INFER_DATA tuple.
    """
    if args is None:
        args = self._args
    else:
        args = deep_merge_dict(self._args, args)
    # EOS ids are used as the padding values for feature/label sequences.
    src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"],
                          dtype=tf.int64)
    trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                          dtype=tf.int64)
    assert isinstance(ds, AbstractParallelDataset), (
        "The dataset for SeqToSeq task must inherit AbstractParallelDataset.")
    dataset = ds.build(map_func=self.get_data_preprocess_fn(
        mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=(mode == compat.ModeKeys.TRAIN),
                       shuffle=(mode == compat.ModeKeys.TRAIN))
    if mode == compat.ModeKeys.INFER:
        logging.info("Creating test dataset.")
        test_dataset = dataset_utils.batch_sequential_dataset(
            dataset=dataset.cache(),
            batch_size=args["batch_size"],
            padding_values={"feature": src_eos},
            num_replicas_in_sync=num_replicas_in_sync,
            drop_remainder=False)
        return test_dataset
    elif mode == compat.ModeKeys.EVAL:
        logging.info("Creating evaluation dataset.")
        return dataset_utils.batch_sequential_dataset(
            dataset.cache(),
            batch_size=args["batch_size"],
            padding_values={"feature": src_eos, "label": trg_eos},
            num_replicas_in_sync=num_replicas_in_sync,
            drop_remainder=False)
    else:
        logging.info("Creating training dataset.")
        if args["cache_dataset"]:
            dataset = dataset.cache()
        # NOTE(review): `shuffer_buffer` looks like a typo for
        # `shuffle_buffer`, but it must match the keyword accepted by
        # dataset_utils.batch_sequential_dataset — confirm before renaming.
        dataset = dataset_utils.batch_sequential_dataset(
            dataset,
            padding_values={"feature": src_eos, "label": trg_eos},
            batch_size=args["batch_size"],
            batch_size_per_gpu=args["batch_size_per_gpu"],
            batch_by_tokens=args["batch_by_tokens"],
            shuffer_buffer=args["shuffle_buffer"],
            data_max_lengths={"feature": args["max_src_len"],
                              "label": args["max_trg_len"]},
            drop_remainder=True,
            num_replicas_in_sync=num_replicas_in_sync)
        return dataset
def create_and_batch_tfds(self, ds: Dataset, mode, args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates a dataset according to the `mode`.

    Args:
        args: A dict containing dataset arguments.
        ds: A neurst.data.datasets.Dataset object.
        mode: A ModeKeys indicating the running mode.
        num_replicas_in_sync: The number of GPUs or other workers. We will
            generate global batches, and each global batch is equally
            divisible by number of replicas.

    Returns:
        A tf.data.Dataset.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    pad = tf.constant(self._data_pipeline.meta["pad_id"], dtype=tf.int64)
    dataset = ds.build(map_func=self.get_data_preprocess_fn(
        mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=(mode == compat.ModeKeys.TRAIN),
                       shuffle=(mode == compat.ModeKeys.TRAIN))
    if mode == compat.ModeKeys.INFER:
        raise NotImplementedError
        # logging.info("Creating test dataset.")
        # return dataset.cache().padded_batch(
        #     dataset_utils.adjust_batch_size(args["batch_size"],
        #         num_replicas_in_sync=num_replicas_in_sync),
        #     padded_shapes={"tokens": [None]},
        #     padding_values={"tokens": pad},
        #     drop_remainder=False)
    elif mode == compat.ModeKeys.EVAL:
        logging.info("Creating evaluation dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(
                args["batch_size"],
                num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={"tokens": [None]},
            padding_values={"tokens": pad},
            drop_remainder=False)
    else:
        logging.info("Creating training dataset.")
        level = args.get("gpu_efficient_level", None)
        logging.info(
            f"Creating training dataset with GPU efficient level={level}.")
        # Rebuild with explicit shard/shuffle for training.
        dataset = ds.build(
            map_func=self.get_data_preprocess_fn(mode, ds.status, args),
            map_output_dtypes=self.inputs_signature(mode)[0],
            auto_shard=True,
            shuffle=True)
        dataset = dataset_utils.clean_dataset_by_length(
            dataset, {"tokens": args["max_len"]})
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        padding_values = {
            "tokens": tf.constant(self._data_pipeline.meta["pad_id"],
                                  dtype=tf.int64)
        }
        if args["max_len"] is None:
            raise RuntimeError("Must provide `max_len` for training.")
        max_len = minimal_multiple(args["max_len"],
                                   EFFICIENT_MULTIPLIER[level])
        batch_size = dataset_utils.adjust_batch_size(
            args["batch_size"],
            args["batch_size_per_gpu"],
            num_replicas_in_sync=num_replicas_in_sync,
            verbose=False)
        if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
            _batch_size = batch_size
            if args["batch_by_tokens"]:
                _batch_size = _batch_size // max_len
            # FIX: the log message previously printed the garbled literal
            # "x 72,174)." instead of the actual padded length.
            logging.info(
                f"Batching dataset with fixed shape: batch={_batch_size} x {max_len}.")
            return dataset.padded_batch(
                _batch_size // num_replicas_in_sync * num_replicas_in_sync,
                padded_shapes={"tokens": [max_len]},
                padding_values=padding_values,
                drop_remainder=True)
        else:
            # Bucket boundaries are multiples of the efficiency multiplier,
            # extended so the last bucket covers `max_len`.
            bucket_boundaries = [
                EFFICIENT_MULTIPLIER[level] * i
                for i in range(1, max_len // EFFICIENT_MULTIPLIER[level] + 1)
            ]
            if bucket_boundaries[-1] < max_len:
                bucket_boundaries.append(
                    minimal_multiple(bucket_boundaries[-1] + 1,
                                     EFFICIENT_MULTIPLIER[level]))
            bucket_boundaries = {"tokens": bucket_boundaries}
            bucket_batch_sizes = dataset_utils.adjust_batch_size(
                batch_size,
                bucket_boundaries=bucket_boundaries
                if args["batch_by_tokens"] else None,
                boundaries_reduce_to_length_fn=lambda x: max(
                    tf.nest.flatten(x)),
                num_replicas_in_sync=num_replicas_in_sync)
            if level != GPU_EFFICIENT_LEVEL.LEVEL0:
                # Round batch sizes down to replica-divisible multiples of
                # the efficiency multiplier.
                if isinstance(bucket_batch_sizes, list):
                    bucket_batch_sizes = [
                        int(
                            maximum_lower_multiple(
                                x // num_replicas_in_sync,
                                EFFICIENT_MULTIPLIER[level])
                            * num_replicas_in_sync)
                        for x in bucket_batch_sizes
                    ]
                else:
                    bucket_batch_sizes = int(
                        maximum_lower_multiple(
                            bucket_batch_sizes // num_replicas_in_sync,
                            EFFICIENT_MULTIPLIER[level])
                        * num_replicas_in_sync)
            return dataset_utils.batch_examples_by_token(
                dataset,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=bucket_batch_sizes,
                padding_values=padding_values,
                example_length_func=lambda x:
                {k: tf.size(v) for k, v in x.items()})
def create_and_batch_tfds(self, ds: Dataset, mode, args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates a dataset according to the `mode`.

    Args:
        args: A dict containing dataset arguments.
        ds: A neurst.data.datasets.Dataset object.
        mode: A ModeKeys indicating the running mode.
        num_replicas_in_sync: The number of GPUs or other workers. We will
            generate global batches, and each global batch is equally
            divisible by number of replicas.

    Returns:
        A tf.data.Dataset.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    # EOS ids double as padding values.
    src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"],
                          dtype=tf.int64)
    trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                          dtype=tf.int64)
    assert isinstance(ds, AbstractParallelDataset), (
        "The dataset for SeqToSeq task must inherit AbstractParallelDataset."
    )
    dataset = ds.build(map_func=self.get_data_preprocess_fn(
        mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=(mode == compat.ModeKeys.TRAIN),
                       shuffle=(mode == compat.ModeKeys.TRAIN))
    if mode == compat.ModeKeys.INFER:
        logging.info("Creating test dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(
                args["batch_size"],
                num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={"feature": [None]},
            padding_values={"feature": src_eos},
            drop_remainder=False)
    elif mode == compat.ModeKeys.EVAL:
        logging.info("Creating evaluation dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(
                args["batch_size"],
                num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={
                "feature": [None],
                "label": [None]
            },
            padding_values={
                "feature": src_eos,
                "label": trg_eos
            },
            drop_remainder=False)
    else:
        logging.info("Creating training dataset.")
        # Drop over-length examples before caching/shuffling.
        dataset = dataset_utils.clean_dataset_by_length(
            dataset, {
                "feature": args["max_src_len"],
                "label": args["max_trg_len"]
            })
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        padding_values = {"feature": src_eos, "label": trg_eos}
        if args["max_src_len"] is None:
            raise RuntimeError("Must provide `max_src_len` for training.")
        if args["max_trg_len"] is None:
            raise RuntimeError("Must provide `max_trg_len` for training.")
        # Pair up source/target bucket boundaries for token-based batching.
        src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
            dataset_utils.create_batch_bucket_boundaries(
                args["max_src_len"]),
            dataset_utils.create_batch_bucket_boundaries(
                args["max_trg_len"]))
        bucket_boundaries = {
            "feature": src_bucket_boundaries,
            "label": trg_bucket_boundaries
        }
        bucket_batch_sizes = dataset_utils.adjust_batch_size(
            args["batch_size"],
            args["batch_size_per_gpu"],
            bucket_boundaries=bucket_boundaries
            if args["batch_by_tokens"] else None,
            boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)
                                                         ),
            num_replicas_in_sync=num_replicas_in_sync)
        return dataset_utils.batch_examples_by_token(
            dataset,
            bucket_boundaries=bucket_boundaries,
            bucket_batch_sizes=bucket_batch_sizes,
            padding_values=padding_values,
            example_length_func=lambda x:
            {k: tf.size(v) for k, v in x.items()})
def create_and_batch_tfds(self, ds: Dataset, mode, args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
    """ Creates a dataset according to the `mode`.

    Args:
        args: A dict containing dataset arguments.
        ds: A neurst.data.datasets.Dataset object.
        mode: A ModeKeys indicating the running mode.
        num_replicas_in_sync: The number of GPUs or other workers. We will
            generate global batches, and each global batch is equally
            divisible by number of replicas.

    Returns:
        A tf.data.Dataset.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    # Padding values: zeros for audio frames/lengths, EOS for transcripts.
    float_zero = tf.constant(0, dtype=tf.float32)
    int_zero = tf.constant(0, dtype=tf.int64)
    trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                          dtype=tf.int64)
    dataset = ds.build(map_func=self.get_data_preprocess_fn(
        mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=(mode == compat.ModeKeys.TRAIN),
                       shuffle=(mode == compat.ModeKeys.TRAIN))
    if mode == compat.ModeKeys.INFER:
        logging.info("Creating test dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(
                args["batch_size"],
                num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={
                "audio": [None],
                "audio_length": []
            },
            padding_values={
                "audio": float_zero,
                "audio_length": int_zero
            },
            drop_remainder=False)
    elif mode == compat.ModeKeys.EVAL:
        logging.info("Creating evaluation dataset.")
        return dataset.cache().padded_batch(
            dataset_utils.adjust_batch_size(
                args["batch_size"],
                num_replicas_in_sync=num_replicas_in_sync),
            padded_shapes={
                "audio": [None],
                "audio_length": [],
                "transcript": [None]
            },
            padding_values={
                "audio": float_zero,
                "audio_length": int_zero,
                "transcript": trg_eos
            },
            drop_remainder=False)
    else:
        logging.info("Creating training dataset.")
        # Audio is a flat frame vector, so the length cap is in raw elements:
        # frames x feature_dim x channels. audio_length=-1 disables that cap.
        dataset = dataset_utils.clean_dataset_by_length(
            dataset, {
                "audio": args["max_src_len"] * self._audio_feature_dim
                * self._audio_feature_channels,
                "audio_length": -1,
                "transcript": args["max_trg_len"]
            })
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        padding_values = {
            "audio": float_zero,
            "audio_length": int_zero,
            "transcript": trg_eos
        }
        if args["max_src_len"] is None:
            raise RuntimeError(
                "`max_src_len` for SpeechToText task must be provided.")
        if args["max_trg_len"] is None:
            raise RuntimeError(
                "`max_trg_len` for SpeechToText task must be provided.")
        max_src_len = args["max_src_len"]
        max_trg_len = minimal_multiple(args["max_trg_len"], 8)
        audio_bucket_boundaries = create_audio_bucket_boundaries(
            max_src_len, args["min_src_bucket_boundary"])
        audio_bucket_boundaries[-1] = minimal_multiple(
            audio_bucket_boundaries[-1], 8)
        batch_size = dataset_utils.adjust_batch_size(
            args["batch_size"],
            args["batch_size_per_gpu"],
            num_replicas_in_sync=num_replicas_in_sync,
            verbose=False)
        batch_size_per_gpu = batch_size // num_replicas_in_sync
        # NOTE(review): this assert message is missing a closing ")" after
        # the value — cosmetic only; string kept byte-identical here.
        assert batch_size_per_gpu > max_src_len, (
            f"batch size per gpu({batch_size_per_gpu} must be greater than "
            f"`max_src_len`={max_src_len}")
        # Per-bucket batch sizes: token-budget divided by the frame bound,
        # optionally rounded to multiples of 8 for GPU efficiency.
        if args["disable_batch_efficiency"]:
            bucket_batch_sizes = [
                int(batch_size_per_gpu // bound * num_replicas_in_sync)
                for bound in audio_bucket_boundaries
            ]
        else:
            bucket_batch_sizes = [
                int(
                    minimal_multiple(batch_size_per_gpu // bound, 8)
                    * num_replicas_in_sync)
                for bound in audio_bucket_boundaries
            ]
        frame_transcript_ratio = args[
            "experimental_frame_transcript_ratio"]
        if frame_transcript_ratio is None:
            logging.info(
                "WARNING: we recommend one to pre-scan the dataset and estimate the ratio: "
                "frame length / transcript length.")
        else:
            # Derive transcript-length bounds per audio bucket by linearly
            # interpolating the frame/transcript ratio across buckets.
            trans_bucket_boundaries = [
                int(bound / (frame_transcript_ratio + i
                             * (max_src_len / max_trg_len
                                - frame_transcript_ratio)
                             / len(audio_bucket_boundaries)))
                for i, bound in enumerate(audio_bucket_boundaries)
            ]
            trans_bucket_boundaries = [
                minimal_multiple(min(i, max_trg_len), 8)
                for i in trans_bucket_boundaries
            ]
            num_buckets = len(trans_bucket_boundaries)
            # Each audio bucket gets two candidate transcript bounds (its own
            # and the next bucket's), forming a 2-D bucket grid.
            true_trans_bucket_boundaries = []
            num_input_shapes = 0
            for idx, (batc, bound, tbound) in enumerate(
                    zip(bucket_batch_sizes, audio_bucket_boundaries,
                        trans_bucket_boundaries)):
                max_trans_len = [
                    tbound,
                    trans_bucket_boundaries[min(
                        idx + 1, len(bucket_batch_sizes) - 1)]
                ]
                num_input_shapes += len(set(max_trans_len))
                true_trans_bucket_boundaries.append(max_trans_len)
            logging.info(
                f"There are {num_input_shapes} input shapes to be compiled:"
            )
            for idx, (batc, bound, tbound) in enumerate(
                    zip(bucket_batch_sizes, audio_bucket_boundaries,
                        true_trans_bucket_boundaries)):
                logging.info(f" - batch={batc}, maximum-frames={bound}, "
                             f"maximum-transcript-length={set(tbound)}")
            true_trans_bucket_boundaries = tf.constant(
                true_trans_bucket_boundaries, dtype=tf.int32)
            true_audio_bucket_boundaries = tf.transpose(
                tf.constant([audio_bucket_boundaries] * 2, dtype=tf.int32))
        bucket_batch_sizes = tf.constant(bucket_batch_sizes,
                                         dtype=tf.int64)
        audio_bucket_boundaries = tf.constant(audio_bucket_boundaries,
                                              dtype=tf.int32)

        def example_to_bucket_id(examples):
            """Return int64 bucket id for this example, calculated based on length."""
            if frame_transcript_ratio is None:
                conditions_c = tf.less_equal(
                    tf.cast(examples["audio_length"], tf.int32),
                    audio_bucket_boundaries)
                return tf.reduce_min(tf.where(conditions_c))
            # 2-D bucketing: the id encodes (audio bucket, transcript bucket).
            conditions_c = tf.logical_and(
                tf.less_equal(tf.cast(examples["audio_length"], tf.int32),
                              true_audio_bucket_boundaries),
                tf.less_equal(tf.size(examples["transcript"]),
                              true_trans_bucket_boundaries))
            minimum_match = tf.where(conditions_c)[0]
            return minimum_match[0] * num_buckets + minimum_match[1]

        def window_size_fn(bucket_id):
            """Return number of examples to be grouped when given a bucket id."""
            if frame_transcript_ratio is None:
                return bucket_batch_sizes[bucket_id]
            return bucket_batch_sizes[bucket_id // num_buckets]

        def batching_fn(bucket_id, grouped_dataset):
            """Batch and add padding to a dataset of elements with similar lengths."""
            bucket_batch_size = window_size_fn(bucket_id)
            # Batch the dataset and add padding so that all input sequences
            # in the examples have the same length, and all target sequences
            # have the same lengths as well. Resulting lengths of inputs and
            # targets can differ.
            return grouped_dataset.padded_batch(
                bucket_batch_size,
                padded_shapes={
                    "audio": ([(audio_bucket_boundaries[bucket_id]
                                if frame_transcript_ratio is None else
                                audio_bucket_boundaries[bucket_id
                                                        // num_buckets])
                               * self._audio_feature_dim
                               * self._audio_feature_channels]),
                    "audio_length": [],
                    "transcript":
                        ([None] if frame_transcript_ratio is None else [
                            true_trans_bucket_boundaries[
                                bucket_id // num_buckets][bucket_id
                                                          % num_buckets]
                        ])
                },
                padding_values=padding_values,
                drop_remainder=True)

        return dataset.apply(
            tf.data.experimental.group_by_window(
                key_func=example_to_bucket_id,
                reduce_func=batching_fn,
                window_size=None,
                window_size_func=window_size_fn))
def get_data_preprocess_fn(self, mode, data_status, args=None) -> callable:
    """ Preprocess data sample according to this task.

    Note: unlike the text tasks, `data_status` here is a dict keyed by
    field name ("audio", "transcript").

    Args:
        args: A dict containing dataset arguments.
        mode: A ModeKeys indicating the running mode.
        data_status: The status of the data sample.

    Returns:
        A callable function to collate (process) a data sample.
    """
    if args is None:
        args = self._args
    else:
        # local_overwrite=False: the caller-provided `args` take precedence.
        args = deep_merge_dict(self._args, args, local_overwrite=False)
    trunc_audio = args.get("truncate_src", None)
    max_audio_len = args.get("max_src_len", None)
    trunc_trg = args.get("truncate_trg", None)
    max_trg_len = args.get("max_trg_len", None)
    # Audio must already be projected to features.
    if data_status["audio"] != compat.DataStatus.PROJECTED:
        raise RuntimeError(
            "We recommend one to preprocess the audio in advance.")

    def _process_audio(audio):
        # The audio is a flat vector of frames x feature_dim x channels.
        if trunc_audio and max_audio_len:
            audio = audio[:max_audio_len * self._audio_feature_dim
                          * self._audio_feature_channels]
        # SpecAugment operates on the 2-D [frames, features] view.
        if self._specaug is not None:
            audio = tf.reshape(audio, [
                -1, self._audio_feature_dim * self._audio_feature_channels
            ])
            audio = tf.reshape(self._specaug(audio), [-1])
        return audio

    def _process_and_truncate_text(text):
        if data_status["transcript"] == compat.DataStatus.RAW:
            text = self._trg_data_pipeline.process(text,
                                                   is_processed=False)
        else:
            assert data_status["transcript"] == compat.DataStatus.PROJECTED
        # Truncate only during training, keeping the final token (presumably
        # EOS — TODO confirm).
        if mode == compat.ModeKeys.TRAIN and trunc_trg and max_trg_len:
            if isinstance(text, tf.Tensor):
                text = tf.cond(
                    tf.less_equal(tf.size(text), max_trg_len),
                    lambda: text, lambda: tf.concat(
                        [text[:(max_trg_len - 1)], text[-1:]], axis=0))
            else:
                if len(text) > max_trg_len:
                    text = text[:(max_trg_len - 1)] + text[-1:]
        return text

    def data_proc(data, with_label):
        feature = _process_audio(data["audio"])
        ret = {
            "audio": feature,
            # Frame count = element count / feature_dim / channels.
            "audio_length": tf.cast(
                (tf.shape(feature)[0] if isinstance(feature, tf.Tensor)
                 else feature.shape[0]) // self._audio_feature_dim
                // self._audio_feature_channels,
                dtype=tf.int64)
        }
        if with_label:
            ret["transcript"] = tf.convert_to_tensor(
                _process_and_truncate_text(data["transcript"]), tf.int64)
        return ret

    # Inference samples carry no transcript.
    if mode == compat.ModeKeys.INFER:
        return lambda data: data_proc(data, False)
    return lambda data: data_proc(data, True)
def intelligent_parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                            args_preload_func=_args_preload_from_config_files, backend="tf"):
    """ Parses flags from argument parser.

    Unlike `parse_flags`, this also parses the flags declared by the
    selected submodule classes (one extra parser pass per ModuleFlag) and
    flattens nested parameters on both the CLI and preloaded sides.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable function for pre-loading arguments,
            maybe from config file, hyper parameter set.
        backend: The DL backend.

    Returns:
        A tuple of (parsed argument dict, remaining unparsed argv).
    """
    program_parsed_args, remaining_argv = arg_parser.parse_known_args()
    cfg_file_args = {}
    if args_preload_func is not None:
        cfg_file_args = args_preload_func(program_parsed_args)
    top_program_parsed_args = _flatten_args(
        flag_list, yaml_load_checking(program_parsed_args.__dict__))
    # A module class chosen on the CLI overrides the preloaded choice, so
    # the preloaded side is flattened against the same class selection.
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            if f.cls_key in top_program_parsed_args and top_program_parsed_args[f.cls_key]:
                cfg_file_args[f.cls_key] = top_program_parsed_args[f.cls_key]
    cfg_file_args = _flatten_args(flag_list, cfg_file_args)
    for f in flag_list:
        if isinstance(f, Flag):
            # CLI value wins; otherwise take (and consume) the preloaded one.
            if top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = cfg_file_args.get(
                    f.name, None)
            cfg_file_args.pop(f.name, None)
        else:
            submodule_cls = (top_program_parsed_args.get(f.cls_key, None)
                             or cfg_file_args.get(f.cls_key, None))
            cfg_file_args.pop(f.cls_key, None)
            if submodule_cls is None:
                continue
            top_program_parsed_args[f.cls_key] = submodule_cls
            if top_program_parsed_args.get(f.params_key, None) is None:
                top_program_parsed_args[f.params_key] = {}
            # Parse the chosen submodule's own flags from the leftover argv.
            module_arg_parser = get_argparser(f.module_name, submodule_cls)
            module_parsed_args, remaining_argv = module_arg_parser.parse_known_args(
                remaining_argv)
            module_parsed_args = yaml_load_checking(
                module_parsed_args.__dict__)
            if hasattr(REGISTRIES[backend][f.module_name][submodule_cls],
                       "class_or_method_args"):
                key_cfg_file_args = _flatten_args(
                    REGISTRIES[backend][f.module_name]
                    [submodule_cls].class_or_method_args(), cfg_file_args)
                for inner_f in REGISTRIES[backend][f.module_name][
                        submodule_cls].class_or_method_args():
                    flag_key = inner_f.name
                    if isinstance(inner_f, ModuleFlag):
                        flag_key = inner_f.cls_key
                        cfg_file_args.pop(flag_key, None)
                    # Priority: submodule CLI > top-level parsed args >
                    # flattened preloaded args. Each branch also drops the
                    # now-consumed entry from the preloaded dicts.
                    if module_parsed_args[flag_key] is not None:
                        top_program_parsed_args[f.params_key][
                            flag_key] = module_parsed_args[flag_key]
                        top_program_parsed_args.pop(flag_key, None)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in top_program_parsed_args:
                        top_program_parsed_args[f.params_key][
                            flag_key] = top_program_parsed_args.pop(flag_key)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in key_cfg_file_args:
                        top_program_parsed_args[f.params_key][
                            flag_key] = key_cfg_file_args.pop(flag_key)
                        cfg_file_args.pop(flag_key, None)
                    if isinstance(inner_f, ModuleFlag):
                        # Merge nested params from all four sources; the
                        # submodule CLI dict has the highest precedence.
                        top_program_parsed_args[f.params_key][
                            inner_f.params_key] = deep_merge_dict(
                                cfg_file_args.pop(inner_f.params_key, {})
                                or {},
                                deep_merge_dict(
                                    top_program_parsed_args[f.params_key].pop(
                                        inner_f.params_key, {}) or {},
                                    deep_merge_dict(
                                        top_program_parsed_args.pop(
                                            inner_f.params_key, {}) or {},
                                        module_parsed_args.pop(
                                            inner_f.params_key, {}) or {})))
    # Anything left over from the preloaded side survives via this merge.
    top_program_parsed_args = deep_merge_dict(cfg_file_args,
                                              top_program_parsed_args)
    # Finally, fall back to declared defaults for unset plain flags.
    for f in flag_list:
        if isinstance(f, Flag):
            if f.name not in top_program_parsed_args or top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = f.default
    return top_program_parsed_args, remaining_argv
def create_and_batch_tfds(self, ds, mode, args=None, num_replicas_in_sync=1):
    """ Builds and batches the training dataset according to the GPU efficient level.

    EVAL/INFER modes are delegated to the base class; training batches are
    shaped by `gpu_efficient_level` so sequence lengths and batch sizes land
    on hardware-friendly multiples.

    Args:
        ds: The dataset wrapper providing `build()` and `status`.
        mode: A `compat.ModeKeys` value.
        args: Optional overrides merged on top of `self._args`.
        num_replicas_in_sync: The number of replicas under synchronous training.

    Returns:
        A batched `tf.data.Dataset`.
    """
    # Only training takes the efficiency-aware path below.
    if mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]:
        return super(Translation, self).create_and_batch_tfds(
            ds, mode, args, num_replicas_in_sync)
    args = (self._args if args is None
            else deep_merge_dict(self._args, args, local_overwrite=False))
    level = args.get("gpu_efficient_level", None)
    auto_scale_batch = args.get("auto_scaling_batch_size", None)
    logging.info(f"Creating training dataset with GPU efficient level={level}.")
    dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0],
                       auto_shard=True, shuffle=True)
    # Drop examples exceeding the configured length limits.
    dataset = dataset_utils.clean_dataset_by_length(
        dataset, {"feature": args["max_src_len"], "label": args["max_trg_len"]})
    if args["cache_dataset"]:
        dataset = dataset.cache()
    if args["shuffle_buffer"]:
        dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
    # Padding values come from each pipeline's EOS id.
    eos_pad = {
        "feature": tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64),
        "label": tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)}
    if args["max_src_len"] is None:
        raise RuntimeError("Must provide `max_src_len` for training.")
    if args["max_trg_len"] is None:
        raise RuntimeError("Must provide `max_trg_len` for training.")
    multiplier = EFFICIENT_MULTIPLIER[level]
    # Round the length caps up to the nearest efficient multiple.
    max_src_len = minimal_multiple(args["max_src_len"], multiplier)
    max_trg_len = minimal_multiple(args["max_trg_len"], multiplier)
    longest = max(max_src_len, max_trg_len)
    batch_size = dataset_utils.adjust_batch_size(
        args["batch_size"], args["batch_size_per_gpu"],
        num_replicas_in_sync=num_replicas_in_sync, verbose=False)
    if auto_scale_batch:
        batch_size = _auto_scale_batch_size(batch_size, level)
        logging.info(f"Auto scaling batch size to {batch_size}.")
    if level == GPU_EFFICIENT_LEVEL.LEVEL5:
        # LEVEL5 = fully static shapes: every batch is padded to the maxima.
        fixed_batch = batch_size // longest if args["batch_by_tokens"] else batch_size
        logging.info("Batching dataset with fixed shape: "
                     f"batch={fixed_batch} x (feature={max_src_len}, label={max_trg_len}).")
        return dataset.padded_batch(
            fixed_batch // num_replicas_in_sync * num_replicas_in_sync,
            padded_shapes={"feature": [max_src_len], "label": [max_trg_len]},
            drop_remainder=True, padding_values=eos_pad)

    def _grid(limit):
        # Boundaries at every multiple of `multiplier` up to (and covering) `limit`.
        bounds = [multiplier * i for i in range(1, limit // multiplier + 1)]
        if bounds[-1] < limit:
            bounds.append(minimal_multiple(bounds[-1] + 1, multiplier))
        return bounds

    src_bounds, trg_bounds = dataset_utils.associated_bucket_boundaries(
        _grid(max_src_len), _grid(max_trg_len))
    bucket_boundaries = {"feature": src_bounds, "label": trg_bounds}
    bucket_batch_sizes = dataset_utils.adjust_batch_size(
        batch_size,
        bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
        boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
        num_replicas_in_sync=num_replicas_in_sync)
    if level != GPU_EFFICIENT_LEVEL.LEVEL0:
        # Snap each per-bucket batch size down to an efficient multiple
        # per replica.
        def _snap(size):
            return int(maximum_lower_multiple(
                size // num_replicas_in_sync, multiplier) * num_replicas_in_sync)

        if isinstance(bucket_batch_sizes, list):
            bucket_batch_sizes = [_snap(x) for x in bucket_batch_sizes]
        else:
            bucket_batch_sizes = _snap(bucket_batch_sizes)
    return dataset_utils.batch_examples_by_token(
        dataset,
        bucket_boundaries=bucket_boundaries,
        bucket_batch_sizes=bucket_batch_sizes,
        padding_values=eos_pad,
        example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
def create_and_batch_tfds(self, ds, mode, args=None, num_replicas_in_sync=1):
    """ With efficient level for training.

    Builds the training dataset and batches it by `gpu_efficient_level`:
    LEVEL5 uses one fixed padded shape; other non-zero levels bucket
    examples by length via `tf.data.experimental.group_by_window`.
    EVAL/INFER modes, or a disabled level, are delegated to the base class.
    """
    if args is None:
        args = self._args
    else:
        args = deep_merge_dict(self._args, args)
    level = args.get("gpu_efficient_level", None)
    auto_scale_batch = args.get("auto_scaling_batch_size", None)
    # No efficiency-aware batching for eval/infer, or when the level is
    # unset/LEVEL0 — fall back to the plain base-class pipeline.
    if (mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]
            or level is None or level == GPU_EFFICIENT_LEVEL.LEVEL0):
        return super(Translation, self).create_and_batch_tfds(ds, mode, args, num_replicas_in_sync)
    # Padding values come from each data pipeline's EOS id.
    padding_values = {
        "feature": tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64),
        "label": tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)
    }
    dataset = ds.build(auto_shard=True,
                       map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                       map_output_dtypes=self.inputs_signature(mode)[0])
    max_src_len = args["max_src_len"]
    max_trg_len = args["max_trg_len"]
    batch_by_tokens = args["batch_by_tokens"]
    # NOTE(review): validation via `assert` is stripped under `python -O`;
    # an explicit raise would be safer — left as-is here.
    assert max_src_len, "Must provide `max_src_len` when `gpu_efficient_level` > 0"
    assert max_trg_len, "Must provide `max_trg_len` when `gpu_efficient_level` > 0"
    logging.info(
        f"Creating training dataset with `gpu_efficient_level`={level}.")
    # Drop examples exceeding the configured length limits.
    dataset = clean_dataset_by_length(dataset, {
        "feature": max_src_len,
        "label": max_trg_len
    })
    if args["cache_dataset"]:
        dataset = dataset.cache()
    if args["shuffle_buffer"]:
        dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
    batch_size_per_gpu = args["batch_size_per_gpu"]
    batch_size = args["batch_size"]
    # Derive the per-GPU batch size from the global one when not given.
    if batch_size_per_gpu is None:
        batch_size_per_gpu = batch_size // num_replicas_in_sync
    if batch_by_tokens:
        # Token-based batching needs at least one full-length example per
        # batch. NOTE(review): the message has an unbalanced "(" — typo in
        # the original text, preserved here.
        assert batch_size_per_gpu > max(max_src_len, max_trg_len), (
            f"batch size per gpu({batch_size_per_gpu} must be greater than "
            f"both `max_src_len`{max_src_len} and `max_trg_len`{max_trg_len}"
        )
    if auto_scale_batch:
        new_batch_size_per_gpu = _auto_scale_batch_size(
            batch_size_per_gpu, level)
        logging.info(
            f"Auto scaling `batch_size_per_gpu` from {batch_size_per_gpu} "
            f"to {new_batch_size_per_gpu}")
        batch_size_per_gpu = new_batch_size_per_gpu
    # Round the length caps up to the nearest efficient multiple.
    max_src_len = minimal_multiple(max_src_len, EFFICIENT_MULTIPLIER[level])
    max_trg_len = minimal_multiple(max_trg_len, EFFICIENT_MULTIPLIER[level])
    max_len = max(max_src_len, max_trg_len)
    if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
        if batch_by_tokens:
            batch_size_per_gpu = batch_size_per_gpu // max_len
        # One fixed padded shape for every batch.
        return dataset.padded_batch(int(
            minimal_multiple(batch_size_per_gpu, EFFICIENT_MULTIPLIER[level])
            * num_replicas_in_sync),
            padded_shapes={
                "feature": [max_src_len],
                "label": [max_trg_len]
            },
            drop_remainder=True,
            padding_values=padding_values)
    else:
        # Bucket boundaries at every efficient multiple up to `max_len`.
        bucket_boundaries = [
            EFFICIENT_MULTIPLIER[level] * i
            for i in range(1, max_len // EFFICIENT_MULTIPLIER[level] + 1)
        ]
        if bucket_boundaries[-1] < max_len:
            bucket_boundaries.append(
                minimal_multiple(bucket_boundaries[-1] + 1,
                                 EFFICIENT_MULTIPLIER[level]))
        # Lower edges of the buckets; kept as a plain Python list — it is
        # captured (and auto-converted) inside `example_to_bucket_id`.
        buckets_min = [0] + bucket_boundaries[:-1]
        if batch_by_tokens:
            # Larger buckets get proportionally smaller batches so the token
            # count per batch stays roughly constant.
            bucket_batch_sizes = [
                int(
                    minimal_multiple(batch_size_per_gpu // bound,
                                     EFFICIENT_MULTIPLIER[level])
                    * num_replicas_in_sync) for bound in bucket_boundaries
            ]
        else:
            bucket_batch_sizes = [
                int(
                    minimal_multiple(batch_size_per_gpu,
                                     EFFICIENT_MULTIPLIER[level])
                    * num_replicas_in_sync)
            ] * len(bucket_boundaries)
        logging.info(
            f"There are {len(bucket_batch_sizes)} input shapes to be compiled:"
        )
        for batc, bound in zip(bucket_batch_sizes, bucket_boundaries):
            logging.info(f" - batch={batc}, maximum-length={bound}")
        # Rebind as tensors for use inside the tf.data closures below.
        bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
        bucket_boundaries = tf.constant(bucket_boundaries, dtype=tf.int32)

        def example_to_bucket_id(examples):
            """Return int64 bucket id for this example, calculated based on length."""
            # Bucket by the longer of the two sides.
            seq_length = tf.cast(
                tf.maximum(tf.size(examples["feature"]),
                           tf.size(examples["label"])), tf.int32)
            # First bucket whose (min, max] range contains seq_length.
            conditions_c = tf.logical_and(
                tf.less(buckets_min, seq_length),
                tf.less_equal(seq_length, bucket_boundaries))
            bucket_id = tf.reduce_min(tf.where(conditions_c))
            return bucket_id

        def window_size_fn(bucket_id):
            """Return number of examples to be grouped when given a bucket id."""
            return bucket_batch_sizes[bucket_id]

        def batching_fn(bucket_id, grouped_dataset):
            """Batch and add padding to a dataset of elements with similar lengths."""
            bucket_batch_size = window_size_fn(bucket_id)
            # Batch the dataset and add padding so that all input sequences in the
            # examples have the same length, and all target sequences have the same
            # lengths as well. Resulting lengths of inputs and targets can differ.
            return grouped_dataset.padded_batch(
                bucket_batch_size,
                padded_shapes={
                    "feature": [bucket_boundaries[bucket_id]],
                    "label": [bucket_boundaries[bucket_id]]
                },
                padding_values=padding_values,
                drop_remainder=True)

        return dataset.apply(
            tf.data.experimental.group_by_window(
                key_func=example_to_bucket_id,
                reduce_func=batching_fn,
                window_size=None,
                window_size_func=window_size_fn))