Example #1
def _pre_load_args(args):
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(
                getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))
    model_dirs = flatten_string_list(args.model_dir
                                     or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    predefined_parameters = get_hyper_parameters(hparams_set)
    formatted_parameters = {}
    if "model.class" in predefined_parameters:
        formatted_parameters["model.class"] = predefined_parameters.pop(
            "model.class")
    if "model" in predefined_parameters:
        formatted_parameters["model"] = predefined_parameters.pop("model")
    if "model.params" in predefined_parameters:
        formatted_parameters["model.params"] = predefined_parameters.pop(
            "model.params")
    if len(predefined_parameters) > 0:
        formatted_parameters["entry.params"] = predefined_parameters

    try:
        model_cfgs = ModelConfigs.load(model_dirs[0])
        return deep_merge_dict(
            deep_merge_dict(model_cfgs, formatted_parameters), cfg_file_args)
    except Exception:
        return deep_merge_dict(formatted_parameters, cfg_file_args)
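
The snippets here lean heavily on deep_merge_dict. A minimal sketch of the merge contract they assume (the second dict wins on conflicts; nested dicts merge recursively); this is an illustration, not the library source:

# Sketch of the assumed contract: `overrides` takes precedence over `base`,
# and nested dicts are merged recursively instead of being replaced wholesale.
def deep_merge_dict(base, overrides):
    merged = dict(base)
    for key, value in overrides.items():
        if (key in merged and isinstance(merged[key], dict)
                and isinstance(value, dict)):
            merged[key] = deep_merge_dict(merged[key], value)
        else:
            merged[key] = value
    return merged

# Precedence in _pre_load_args: saved model configs < predefined hparams < config file.
model_cfgs = {"entry.params": {"batch_size": 32, "beam_size": 4}}
hparams = {"entry.params": {"batch_size": 64}}
cfg_file = {"entry.params": {"beam_size": 8}}
merged = deep_merge_dict(deep_merge_dict(model_cfgs, hparams), cfg_file)
assert merged == {"entry.params": {"batch_size": 64, "beam_size": 8}}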
Example #2
def _flatten_args(flag_list, from_args, backend="tf"):
    args = copy.deepcopy(from_args)
    flattened_args = {}
    for f in flag_list:
        if isinstance(f, Flag) and f.name in args:
            flattened_args[f.name] = args.pop(f.name)
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            if f.cls_key in args:
                flattened_args[f.cls_key] = args.pop(f.cls_key)
                args.pop(f.name, None)
            elif f.name in args:
                flattened_args[f.cls_key] = args.pop(f.name)
            if (f.cls_key in flattened_args and flattened_args[f.cls_key]
                    and args.get(f.params_key, None)):
                # Promote flags declared by the registered class out of the
                # module's params dict.
                registered_cls = REGISTRIES[backend][f.module_name][
                    flattened_args[f.cls_key]]
                if hasattr(registered_cls, "class_or_method_args"):
                    for ff in registered_cls.class_or_method_args():
                        if isinstance(ff, Flag):
                            if (ff.name in args[f.params_key]
                                    and ff.name not in flattened_args):
                                flattened_args[ff.name] = args[
                                    f.params_key].pop(ff.name)
                        else:
                            if ff.cls_key in args:
                                flattened_args[ff.cls_key] = args.pop(
                                    ff.cls_key)
                                args.pop(ff.name, None)
                            elif ff.name in args:
                                flattened_args[ff.cls_key] = args.pop(ff.name)
                            elif ff.cls_key in args[f.params_key]:
                                flattened_args[ff.cls_key] = args[
                                    f.params_key].pop(ff.cls_key)
                            elif ff.name in args[f.params_key]:
                                flattened_args[ff.cls_key] = args[
                                    f.params_key].pop(ff.name)
                            if ff.params_key in args[f.params_key]:
                                flattened_args[ff.params_key] = deep_merge_dict(
                                    args[f.params_key][ff.params_key],
                                    flattened_args.get(ff.params_key, {}))
                else:
                    flattened_args[f.params_key] = args.pop(f.params_key)
                args.pop(f.params_key, None)
    return deep_merge_dict(flattened_args, args)
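
The cls_key / params_key naming that _flatten_args navigates, illustrated with hypothetical values (a ModuleFlag named "model" exposes cls_key "model.class" and params_key "model.params", as the other examples here suggest):

nested = {"model.class": "transformer",
          "model.params": {"hidden_size": 512, "num_layers": 6}}
declared = {"hidden_size", "num_layers"}  # from class_or_method_args(), assumed
# Flattening promotes flags declared by the registered class to the top level:
flat = {"model.class": nested["model.class"]}
flat.update({k: v for k, v in nested["model.params"].items() if k in declared})
assert flat == {"model.class": "transformer", "hidden_size": 512, "num_layers": 6}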
Example #3
def parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                args_preload_func=_args_preload_from_config_files):
    """ Parses flags from argument parser.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable for pre-loading arguments, e.g., from a
            config file or a predefined hyper-parameter set.
    """
    program_parsed_args, remaining_argv = arg_parser.parse_known_args()
    cfg_file_args = {}
    if args_preload_func is not None:
        cfg_file_args = args_preload_func(program_parsed_args)
    program_parsed_args = yaml_load_checking(program_parsed_args.__dict__)
    top_program_parsed_args = {}
    for f in flag_list:
        flag_key = f.name
        if isinstance(f, ModuleFlag):
            flag_key = f.cls_key
            top_program_parsed_args[f.params_key] = {}
            if program_parsed_args.get(f.params_key, None) is not None:
                top_program_parsed_args[f.params_key] = program_parsed_args[f.params_key]
            if f.params_key in cfg_file_args:
                top_program_parsed_args[f.params_key] = deep_merge_dict(
                    cfg_file_args[f.params_key], top_program_parsed_args[f.params_key])
        if program_parsed_args.get(flag_key, None) is not None:
            top_program_parsed_args[flag_key] = program_parsed_args[flag_key]
        elif flag_key in cfg_file_args:
            top_program_parsed_args[flag_key] = cfg_file_args[flag_key]
        else:
            top_program_parsed_args[flag_key] = f.default

    return top_program_parsed_args, remaining_argv
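
Per flag, parse_flags reduces to a three-level precedence. A standalone sketch of that rule (resolve_flag is a hypothetical helper, not part of the library):

def resolve_flag(cli_value, cfg_file_args, name, default):
    # Explicit CLI value > config-file value > declared default.
    if cli_value is not None:
        return cli_value
    if name in cfg_file_args:
        return cfg_file_args[name]
    return default

assert resolve_flag(64, {"batch_size": 32}, "batch_size", 8) == 64
assert resolve_flag(None, {"batch_size": 32}, "batch_size", 8) == 32
assert resolve_flag(None, {}, "batch_size", 8) == 8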
Example #4
def extend_define_and_parse(flag_name, args, remaining_argv, backend="tf"):
    f = _DEFINED_FLAGS.get(flag_name, None)
    if f is None or not isinstance(f, ModuleFlag):
        return args, remaining_argv
    if not hasattr(REGISTRIES[backend][f.module_name][args[f.cls_key]], "class_or_method_args"):
        return args, remaining_argv
    arg_parser = argparse.ArgumentParser()
    for ff in REGISTRIES[backend][f.module_name][args[f.cls_key]].class_or_method_args():
        if isinstance(ff, ModuleFlag):
            if args[f.params_key].get(ff.cls_key, None):
                this_cls = REGISTRIES[backend][ff.module_name][args[f.params_key][ff.cls_key]]
                if hasattr(this_cls, "class_or_method_args"):
                    for fff in this_cls.class_or_method_args():
                        fff.define(arg_parser)
    parsed_args, remaining_argv = arg_parser.parse_known_args(remaining_argv)
    parsed_args = yaml_load_checking(parsed_args.__dict__)
    for ff in REGISTRIES[backend][f.module_name][args[f.cls_key]].class_or_method_args():
        if isinstance(ff, ModuleFlag):
            if args[f.params_key].get(ff.cls_key, None):
                this_cls = REGISTRIES[backend][ff.module_name][args[f.params_key][ff.cls_key]]
                if hasattr(this_cls, "class_or_method_args"):
                    if args[f.params_key].get(ff.params_key, None) is None:
                        args[f.params_key][ff.params_key] = {}
                    for fff in this_cls.class_or_method_args():
                        flag_key = fff.name
                        if isinstance(fff, ModuleFlag):
                            flag_key = fff.cls_key
                        if parsed_args[flag_key] is not None:
                            args[f.params_key][ff.params_key][flag_key] = parsed_args[flag_key]
                            args.pop(flag_key, None)
                            args.pop(fff.name, None)
                        elif flag_key in args:
                            args[f.params_key][ff.params_key][flag_key] = args.pop(flag_key)
                            args.pop(fff.name, None)
                        elif fff.name in args:
                            args[f.params_key][ff.params_key][flag_key] = args.pop(fff.name)
                        elif fff.name in args[f.params_key][ff.params_key]:
                            if flag_key not in args[f.params_key][ff.params_key]:
                                args[f.params_key][ff.params_key][flag_key] = args[f.params_key][ff.params_key].pop(
                                    fff.name)
                        if isinstance(fff, ModuleFlag):
                            args[f.params_key][ff.params_key][fff.params_key] = deep_merge_dict(
                                args[f.params_key][ff.params_key].get(fff.params_key, {}) or {},
                                deep_merge_dict(args.get(fff.params_key, {}) or {},
                                                parsed_args.get(fff.params_key, {}) or {}))

    return args, remaining_argv
Example #5
    def get_data_preprocess_fn(self,
                               mode,
                               data_status=compat.DataStatus.RAW,
                               args=None) -> callable:
        """ Preprocess data sample according to this task.

        Args:
            args: A dict containing dataset arguments.
            mode: A ModeKeys indicating the running mode.
            data_status: The status of the data sample.

        Returns: A callable function to collate (process) a data sample.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        truncate_src = args.get("truncate_src", None)
        truncate_trg = args.get("truncate_trg", None)
        max_src_len = args.get("max_src_len", None)
        max_trg_len = args.get("max_trg_len", None)

        def _process_and_truncate(text, dp, trunc, max_len):
            if data_status != compat.DataStatus.PROJECTED:
                text = dp.process(
                    text,
                    is_processed=(data_status == compat.DataStatus.PROCESSED))
            if mode == compat.ModeKeys.TRAIN and trunc and max_len:
                if isinstance(text, tf.Tensor):
                    text = tf.cond(
                        tf.less_equal(tf.size(text), max_len), lambda: text,
                        lambda: tf.concat([text[:(max_len - 1)], text[-1:]],
                                          axis=0))
                elif len(text) > max_len:
                    text = text[:(max_len - 1)] + text[-1:]
            return text

        if mode == compat.ModeKeys.INFER:
            return lambda data: {
                "feature":
                _process_and_truncate(data["feature"], self._src_data_pipeline,
                                      truncate_src, max_src_len)
            }
        return lambda data: {
            "feature":
            _process_and_truncate(data["feature"], self._src_data_pipeline,
                                  truncate_src, max_src_len),
            "label":
            _process_and_truncate(data["label"], self._trg_data_pipeline,
                                  truncate_trg, max_trg_len)
        }
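
The tf.cond branch above truncates to max_len while keeping the final token, so a trailing EOS id survives truncation. The same rule in plain Python (a sketch for intuition, not library code):

def truncate_keep_last(ids, max_len):
    # Keep the first max_len - 1 tokens plus the last one (e.g. EOS).
    if len(ids) <= max_len:
        return ids
    return ids[:max_len - 1] + ids[-1:]

assert truncate_keep_last([5, 6, 7, 8, 2], 3) == [5, 6, 2]
assert truncate_keep_last([5, 2], 3) == [5, 2]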
Example #6
 def build_x(args, *extra_args, **kwargs):
     params_ = {}
     if isinstance(args, dict):
         cls_ = (args.get("class", None)
                 or args.get(f"{registry_name}.class", None)
                 or args.get(f"{registry_name}", None))
         params_ = (args.get("params", None) or args.get(
             "{}.params".format(registry_name), {})) or {}
     else:
         cls_ = args
     if cls_ is None:
         return None
     if isinstance(cls_, str):
         if cls_.lower() == "none":
             return None
         if cls_ not in REGISTRIES[backend][registry_name]:
             raise ValueError("Not registered class name: {}.".format(cls_))
         cls_ = REGISTRIES[backend][registry_name][cls_]
         builder = cls_
     elif callable(cls_):
         builder = cls_
     else:
         raise ValueError("Not supported type: {} for builder.".format(
             type(cls_)))
     if create_fn is not None:
         assert hasattr(builder, create_fn), (
             f"{cls_} has no {create_fn} for creation.")
         builder = getattr(builder, create_fn)
     if kwargs is None:
         kwargs = {}
     assert isinstance(
         params_, dict), f"Not supported type: {type(params_)} for params"
     if hasattr(cls_, "class_or_method_args"):
         for f in cls_.class_or_method_args():
             if isinstance(f, Flag):
                 if f.name in kwargs:
                     params_[f.name] = kwargs.pop(f.name)
                 elif f.name not in params_:
                     params_[f.name] = f.default
             elif isinstance(f, ModuleFlag) and f.cls_key not in params_:
                 params_[f.cls_key] = f.default
             if isinstance(f, ModuleFlag) and f.params_key not in params_:
                 params_[f.params_key] = {}
         _verbose_creation(cls_, params_, *extra_args, **kwargs)
         return builder(params_, *extra_args, **kwargs)
     params_ = deep_merge_dict(params_, kwargs, merge_only_exist=False)
     _verbose_creation(cls_, {}, *extra_args, **params_)
     return builder(*extra_args, **params_)
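
build_x resolves string class names through a backend/registry-name mapping. A minimal self-contained sketch of that pattern (register_optimizer and Adam are hypothetical; only the nested-dict shape of REGISTRIES is taken from the code above):

REGISTRIES = {"tf": {"optimizer": {}}}

def register_optimizer(name):
    def decorator(cls):
        REGISTRIES["tf"]["optimizer"][name] = cls
        return cls
    return decorator

@register_optimizer("adam")
class Adam:
    def __init__(self, params):
        self.lr = params.get("lr", 1e-3)

# build_x-style resolution: a {"class": ..., "params": ...} dict becomes an instance.
args = {"class": "adam", "params": {"lr": 5e-4}}
opt = REGISTRIES["tf"]["optimizer"][args["class"]](args["params"])
assert opt.lr == 5e-4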
Example #7
    def get_data_preprocess_fn(self,
                               mode,
                               data_status=compat.DataStatus.RAW,
                               args=None) -> callable:
        """ Preprocess data sample according to this task.

        Args:
            args: A dict containing dataset arguments.
            mode: A ModeKeys indicating the running mode.
            data_status: The status of the data sample.

        Returns: A callable function to collate (process) a data sample.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        truncate = args.get("truncate", None)
        max_len = args.get("max_len", None)

        def _process_and_truncate(data):
            text = data["tokens"]
            if data_status != compat.DataStatus.PROJECTED:
                text = self._data_pipeline.process(
                    text,
                    is_processed=(data_status == compat.DataStatus.PROCESSED))
            if mode == compat.ModeKeys.TRAIN and truncate and max_len:
                if compat.is_tf_tensor(text):
                    text = tf.cond(
                        tf.less_equal(tf.size(text), max_len), lambda: text,
                        lambda: tf.concat([text[:(max_len - 1)], text[-1:]],
                                          axis=0))
                elif len(text) > max_len:
                    text = text[:(max_len - 1)] + text[-1:]
            return {"tokens": text}

        return _process_and_truncate
Example #8
    def create_and_batch_tfds(self, ds: Dataset, mode,
                              args=None, num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a dataset according to the `mode`.

        Args:
            args: A dict containing dataset arguments.
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by the number of replicas.

        Returns:
            A tf.data.Dataset or an INFER_DATA tuple.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args)
        src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64)
        trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)

        assert isinstance(ds, AbstractParallelDataset), (
            "The dataset for SeqToSeq task must inherit AbstractParallelDataset.")

        dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=(mode == compat.ModeKeys.TRAIN),
                           shuffle=(mode == compat.ModeKeys.TRAIN))

        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            test_dataset = dataset_utils.batch_sequential_dataset(
                dataset=dataset.cache(),
                batch_size=args["batch_size"],
                padding_values={"feature": src_eos},
                num_replicas_in_sync=num_replicas_in_sync,
                drop_remainder=False)

            return test_dataset
        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset_utils.batch_sequential_dataset(
                dataset.cache(),
                batch_size=args["batch_size"],
                padding_values={"feature": src_eos, "label": trg_eos},
                num_replicas_in_sync=num_replicas_in_sync,
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            if args["cache_dataset"]:
                dataset = dataset.cache()
            dataset = dataset_utils.batch_sequential_dataset(
                dataset,
                padding_values={"feature": src_eos, "label": trg_eos},
                batch_size=args["batch_size"],
                batch_size_per_gpu=args["batch_size_per_gpu"],
                batch_by_tokens=args["batch_by_tokens"],
                shuffer_buffer=args["shuffle_buffer"],
                data_max_lengths={"feature": args["max_src_len"], "label": args["max_trg_len"]},
                drop_remainder=True,
                num_replicas_in_sync=num_replicas_in_sync)
            return dataset
Example #9
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a dataset according to the `mode`.

        Args:
            args: A dict containing dataset arguments.
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by the number of replicas.

        Returns:
            A tf.data.Dataset.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        pad = tf.constant(self._data_pipeline.meta["pad_id"], dtype=tf.int64)
        dataset = ds.build(map_func=self.get_data_preprocess_fn(
            mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=(mode == compat.ModeKeys.TRAIN),
                           shuffle=(mode == compat.ModeKeys.TRAIN))

        if mode == compat.ModeKeys.INFER:
            raise NotImplementedError
            # logging.info("Creating test dataset.")
            # return dataset.cache().padded_batch(
            #     dataset_utils.adjust_batch_size(args["batch_size"],
            #                                     num_replicas_in_sync=num_replicas_in_sync),
            #     padded_shapes={"tokens": [None]},
            #     padding_values={"tokens": pad},
            #     drop_remainder=False)
        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={"tokens": [None]},
                padding_values={"tokens": pad},
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            level = args.get("gpu_efficient_level", None)
            logging.info(
                f"Creating training dataset with GPU efficient level={level}.")
            dataset = ds.build(
                map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                map_output_dtypes=self.inputs_signature(mode)[0],
                auto_shard=True,
                shuffle=True)
            dataset = dataset_utils.clean_dataset_by_length(
                dataset, {"tokens": args["max_len"]})
            if args["cache_dataset"]:
                dataset = dataset.cache()
            if args["shuffle_buffer"]:
                dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
            padding_values = {
                "tokens":
                tf.constant(self._data_pipeline.meta["pad_id"], dtype=tf.int64)
            }
            if args["max_len"] is None:
                raise RuntimeError("Must provide `max_len` for training.")
            max_len = minimal_multiple(args["max_len"],
                                       EFFICIENT_MULTIPLIER[level])
            batch_size = dataset_utils.adjust_batch_size(
                args["batch_size"],
                args["batch_size_per_gpu"],
                num_replicas_in_sync=num_replicas_in_sync,
                verbose=False)
            if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
                _batch_size = batch_size
                if args["batch_by_tokens"]:
                    _batch_size = _batch_size // max_len
                logging.info(
                    f"Batching dataset with fixed shape: batch={_batch_size} x (tokens={max_len}).")
                return dataset.padded_batch(
                    _batch_size // num_replicas_in_sync * num_replicas_in_sync,
                    padded_shapes={"tokens": [max_len]},
                    padding_values=padding_values,
                    drop_remainder=True)
            else:
                bucket_boundaries = [
                    EFFICIENT_MULTIPLIER[level] * i
                    for i in range(1, max_len // EFFICIENT_MULTIPLIER[level] +
                                   1)
                ]
                if bucket_boundaries[-1] < max_len:
                    bucket_boundaries.append(
                        minimal_multiple(bucket_boundaries[-1] + 1,
                                         EFFICIENT_MULTIPLIER[level]))
                bucket_boundaries = {"tokens": bucket_boundaries}
                bucket_batch_sizes = dataset_utils.adjust_batch_size(
                    batch_size,
                    bucket_boundaries=bucket_boundaries
                    if args["batch_by_tokens"] else None,
                    boundaries_reduce_to_length_fn=lambda x: max(
                        tf.nest.flatten(x)),
                    num_replicas_in_sync=num_replicas_in_sync)
                if level != GPU_EFFICIENT_LEVEL.LEVEL0:
                    if isinstance(bucket_batch_sizes, list):
                        bucket_batch_sizes = [
                            int(
                                maximum_lower_multiple(
                                    x // num_replicas_in_sync,
                                    EFFICIENT_MULTIPLIER[level]) *
                                num_replicas_in_sync)
                            for x in bucket_batch_sizes
                        ]
                    else:
                        bucket_batch_sizes = int(
                            maximum_lower_multiple(
                                bucket_batch_sizes // num_replicas_in_sync,
                                EFFICIENT_MULTIPLIER[level]) *
                            num_replicas_in_sync)
                return dataset_utils.batch_examples_by_token(
                    dataset,
                    bucket_boundaries=bucket_boundaries,
                    bucket_batch_sizes=bucket_batch_sizes,
                    padding_values=padding_values,
                    example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
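
The shape rounding above relies on minimal_multiple and maximum_lower_multiple. Sketches of their behavior as inferred from how they are called (not the library definitions):

def minimal_multiple(x, multiple):
    # Smallest multiple of `multiple` that is >= x.
    return -(-x // multiple) * multiple

def maximum_lower_multiple(x, multiple):
    # Largest multiple of `multiple` that is <= x.
    return x // multiple * multiple

assert minimal_multiple(100, 8) == 104
assert maximum_lower_multiple(100, 8) == 96

# Bucket boundaries then step by the multiplier up to the rounded max_len:
boundaries = [8 * i for i in range(1, 104 // 8 + 1)]
assert boundaries[0] == 8 and boundaries[-1] == 104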
Example #10
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a dataset according to the `mode`.

        Args:
            args: A dict containing dataset arguments.
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by the number of replicas.

        Returns:
            A tf.data.Dataset.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        src_eos = tf.constant(self._src_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)
        trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)

        assert isinstance(ds, AbstractParallelDataset), (
            "The dataset for SeqToSeq task must inherit AbstractParallelDataset."
        )

        dataset = ds.build(map_func=self.get_data_preprocess_fn(
            mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=(mode == compat.ModeKeys.TRAIN),
                           shuffle=(mode == compat.ModeKeys.TRAIN))

        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={"feature": [None]},
                padding_values={"feature": src_eos},
                drop_remainder=False)
        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={
                    "feature": [None],
                    "label": [None]
                },
                padding_values={
                    "feature": src_eos,
                    "label": trg_eos
                },
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            dataset = dataset_utils.clean_dataset_by_length(
                dataset, {
                    "feature": args["max_src_len"],
                    "label": args["max_trg_len"]
                })
            if args["cache_dataset"]:
                dataset = dataset.cache()
            if args["shuffle_buffer"]:
                dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
            padding_values = {"feature": src_eos, "label": trg_eos}
            if args["max_src_len"] is None:
                raise RuntimeError("Must provide `max_src_len` for training.")
            if args["max_trg_len"] is None:
                raise RuntimeError("Must provide `max_trg_len` for training.")
            src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
                dataset_utils.create_batch_bucket_boundaries(
                    args["max_src_len"]),
                dataset_utils.create_batch_bucket_boundaries(
                    args["max_trg_len"]))

            bucket_boundaries = {
                "feature": src_bucket_boundaries,
                "label": trg_bucket_boundaries
            }
            bucket_batch_sizes = dataset_utils.adjust_batch_size(
                args["batch_size"],
                args["batch_size_per_gpu"],
                bucket_boundaries=bucket_boundaries
                if args["batch_by_tokens"] else None,
                boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
                num_replicas_in_sync=num_replicas_in_sync)
            return dataset_utils.batch_examples_by_token(
                dataset,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=bucket_batch_sizes,
                padding_values=padding_values,
                example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()})
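
With batch_by_tokens enabled, adjust_batch_size above converts a token budget into per-bucket example counts, so buckets of longer sequences hold fewer examples. A sketch of that arithmetic with hypothetical numbers:

token_budget = 4096
bucket_boundaries = [8, 16, 32, 64]
bucket_batch_sizes = [token_budget // bound for bound in bucket_boundaries]
assert bucket_batch_sizes == [512, 256, 128, 64]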
Example #11
    def create_and_batch_tfds(self,
                              ds: Dataset,
                              mode,
                              args=None,
                              num_replicas_in_sync=1) -> tf.data.Dataset:
        """ Creates a dataset according to the `mode`.

        Args:
            args: A dict containing dataset arguments.
            ds: A neurst.data.datasets.Dataset object.
            mode: A ModeKeys indicating the running mode.
            num_replicas_in_sync: The number of GPUs or other workers. We will generate global
                batches, and each global batch is equally divisible by the number of replicas.

        Returns:
            A tf.data.Dataset.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        float_zero = tf.constant(0, dtype=tf.float32)
        int_zero = tf.constant(0, dtype=tf.int64)
        trg_eos = tf.constant(self._trg_data_pipeline.meta["eos_id"],
                              dtype=tf.int64)

        dataset = ds.build(map_func=self.get_data_preprocess_fn(
            mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0],
                           auto_shard=(mode == compat.ModeKeys.TRAIN),
                           shuffle=(mode == compat.ModeKeys.TRAIN))

        if mode == compat.ModeKeys.INFER:
            logging.info("Creating test dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={
                    "audio": [None],
                    "audio_length": []
                },
                padding_values={
                    "audio": float_zero,
                    "audio_length": int_zero
                },
                drop_remainder=False)

        elif mode == compat.ModeKeys.EVAL:
            logging.info("Creating evaluation dataset.")
            return dataset.cache().padded_batch(
                dataset_utils.adjust_batch_size(
                    args["batch_size"],
                    num_replicas_in_sync=num_replicas_in_sync),
                padded_shapes={
                    "audio": [None],
                    "audio_length": [],
                    "transcript": [None]
                },
                padding_values={
                    "audio": float_zero,
                    "audio_length": int_zero,
                    "transcript": trg_eos
                },
                drop_remainder=False)
        else:
            logging.info("Creating training dataset.")
            dataset = dataset_utils.clean_dataset_by_length(
                dataset, {
                    "audio": args["max_src_len"] * self._audio_feature_dim *
                    self._audio_feature_channels,
                    "audio_length": -1,
                    "transcript": args["max_trg_len"]
                })
            if args["cache_dataset"]:
                dataset = dataset.cache()
            if args["shuffle_buffer"]:
                dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
            padding_values = {
                "audio": float_zero,
                "audio_length": int_zero,
                "transcript": trg_eos
            }
            if args["max_src_len"] is None:
                raise RuntimeError(
                    "`max_src_len` for SpeechToText task must be provided.")
            if args["max_trg_len"] is None:
                raise RuntimeError(
                    "`max_trg_len` for SpeechToText task must be provided.")
            max_src_len = args["max_src_len"]
            max_trg_len = minimal_multiple(args["max_trg_len"], 8)
            audio_bucket_boundaries = create_audio_bucket_boundaries(
                max_src_len, args["min_src_bucket_boundary"])
            audio_bucket_boundaries[-1] = minimal_multiple(
                audio_bucket_boundaries[-1], 8)
            batch_size = dataset_utils.adjust_batch_size(
                args["batch_size"],
                args["batch_size_per_gpu"],
                num_replicas_in_sync=num_replicas_in_sync,
                verbose=False)
            batch_size_per_gpu = batch_size // num_replicas_in_sync
            assert batch_size_per_gpu > max_src_len, (
                f"batch size per gpu({batch_size_per_gpu} must be greater than "
                f"`max_src_len`={max_src_len}")
            if args["disable_batch_efficiency"]:
                bucket_batch_sizes = [
                    int(batch_size_per_gpu // bound * num_replicas_in_sync)
                    for bound in audio_bucket_boundaries
                ]
            else:
                bucket_batch_sizes = [
                    int(
                        minimal_multiple(batch_size_per_gpu // bound, 8) *
                        num_replicas_in_sync)
                    for bound in audio_bucket_boundaries
                ]
            frame_transcript_ratio = args[
                "experimental_frame_transcript_ratio"]
            if frame_transcript_ratio is None:
                logging.info(
                    "WARNING: we recommend one to pre-scan the dataset and estimate the ratio: "
                    "frame length / transcript length.")
            else:
                trans_bucket_boundaries = [
                    int(bound /
                        (frame_transcript_ratio + i *
                         (max_src_len / max_trg_len - frame_transcript_ratio) /
                         len(audio_bucket_boundaries)))
                    for i, bound in enumerate(audio_bucket_boundaries)
                ]
                trans_bucket_boundaries = [
                    minimal_multiple(min(i, max_trg_len), 8)
                    for i in trans_bucket_boundaries
                ]
                num_buckets = len(trans_bucket_boundaries)
                true_trans_bucket_boundaries = []
                num_input_shapes = 0
                for idx, (batc, bound, tbound) in enumerate(
                        zip(bucket_batch_sizes, audio_bucket_boundaries,
                            trans_bucket_boundaries)):
                    max_trans_len = [
                        tbound, trans_bucket_boundaries[min(
                            idx + 1,
                            len(bucket_batch_sizes) - 1)]
                    ]
                    num_input_shapes += len(set(max_trans_len))
                    true_trans_bucket_boundaries.append(max_trans_len)
                logging.info(
                    f"There are {num_input_shapes} input shapes to be compiled:"
                )
                for idx, (batc, bound, tbound) in enumerate(
                        zip(bucket_batch_sizes, audio_bucket_boundaries,
                            true_trans_bucket_boundaries)):
                    logging.info(f"   - batch={batc}, maximum-frames={bound}, "
                                 f"maximum-transcript-length={set(tbound)}")
                true_trans_bucket_boundaries = tf.constant(
                    true_trans_bucket_boundaries, dtype=tf.int32)
                true_audio_bucket_boundaries = tf.transpose(
                    tf.constant([audio_bucket_boundaries] * 2, dtype=tf.int32))

            bucket_batch_sizes = tf.constant(bucket_batch_sizes,
                                             dtype=tf.int64)
            audio_bucket_boundaries = tf.constant(audio_bucket_boundaries,
                                                  dtype=tf.int32)

            def example_to_bucket_id(examples):
                """Return int64 bucket id for this example, calculated based on length."""
                if frame_transcript_ratio is None:
                    conditions_c = tf.less_equal(
                        tf.cast(examples["audio_length"], tf.int32),
                        audio_bucket_boundaries)
                    return tf.reduce_min(tf.where(conditions_c))
                conditions_c = tf.logical_and(
                    tf.less_equal(tf.cast(examples["audio_length"], tf.int32),
                                  true_audio_bucket_boundaries),
                    tf.less_equal(tf.size(examples["transcript"]),
                                  true_trans_bucket_boundaries))
                minimum_match = tf.where(conditions_c)[0]
                return minimum_match[0] * num_buckets + minimum_match[1]

            def window_size_fn(bucket_id):
                """Return number of examples to be grouped when given a bucket id."""
                if frame_transcript_ratio is None:
                    return bucket_batch_sizes[bucket_id]
                return bucket_batch_sizes[bucket_id // num_buckets]

            def batching_fn(bucket_id, grouped_dataset):
                """Batch and add padding to a dataset of elements with similar lengths."""
                bucket_batch_size = window_size_fn(bucket_id)

                # Batch the dataset and add padding so that all input sequences in the
                # examples have the same length, and all target sequences have the same
                # lengths as well. Resulting lengths of inputs and targets can differ.
                return grouped_dataset.padded_batch(
                    bucket_batch_size,
                    padded_shapes={
                        "audio":
                        ([(audio_bucket_boundaries[bucket_id]
                           if frame_transcript_ratio is None else
                           audio_bucket_boundaries[bucket_id // num_buckets]) *
                          self._audio_feature_dim *
                          self._audio_feature_channels]),
                        "audio_length": [],
                        "transcript":
                        ([None] if frame_transcript_ratio is None else [
                            true_trans_bucket_boundaries[
                                bucket_id // num_buckets][bucket_id %
                                                          num_buckets]
                        ])
                    },
                    padding_values=padding_values,
                    drop_remainder=True)

            return dataset.apply(
                tf.data.experimental.group_by_window(
                    key_func=example_to_bucket_id,
                    reduce_func=batching_fn,
                    window_size=None,
                    window_size_func=window_size_fn))
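
example_to_bucket_id above packs the pair (audio bucket, transcript bucket) into one integer, which window_size_fn and batching_fn unpack with // and %. The packing in plain Python (a sketch):

def flat_bucket_id(audio_bucket, trans_bucket, num_buckets):
    return audio_bucket * num_buckets + trans_bucket

def unpack_bucket_id(bucket_id, num_buckets):
    return bucket_id // num_buckets, bucket_id % num_buckets

assert unpack_bucket_id(flat_bucket_id(3, 2, num_buckets=5), num_buckets=5) == (3, 2)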
Example #12
    def get_data_preprocess_fn(self, mode, data_status, args=None) -> callable:
        """ Preprocess data sample according to this task.

        Args:
            args: A dict containing dataset arguments.
            mode: A ModeKeys indicating the running mode.
            data_status: The status of the data sample.

        Returns: A callable function to collate (process) a data sample.
        """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args, local_overwrite=False)
        trunc_audio = args.get("truncate_src", None)
        max_audio_len = args.get("max_src_len", None)
        trunc_trg = args.get("truncate_trg", None)
        max_trg_len = args.get("max_trg_len", None)

        if data_status["audio"] != compat.DataStatus.PROJECTED:
            raise RuntimeError(
                "Audio features are expected to be preprocessed (projected) in advance.")

        def _process_audio(audio):
            if trunc_audio and max_audio_len:
                audio = audio[:max_audio_len * self._audio_feature_dim *
                              self._audio_feature_channels]
            if self._specaug is not None:
                audio = tf.reshape(audio, [
                    -1, self._audio_feature_dim * self._audio_feature_channels
                ])
                audio = tf.reshape(self._specaug(audio), [-1])
            return audio

        def _process_and_truncate_text(text):
            if data_status["transcript"] == compat.DataStatus.RAW:
                text = self._trg_data_pipeline.process(text,
                                                       is_processed=False)
            else:
                assert data_status["transcript"] == compat.DataStatus.PROJECTED
            if mode == compat.ModeKeys.TRAIN and trunc_trg and max_trg_len:
                if isinstance(text, tf.Tensor):
                    text = tf.cond(
                        tf.less_equal(tf.size(text), max_trg_len),
                        lambda: text, lambda: tf.concat(
                            [text[:(max_trg_len - 1)], text[-1:]], axis=0))
                else:
                    if len(text) > max_trg_len:
                        text = text[:(max_trg_len - 1)] + text[-1:]
            return text

        def data_proc(data, with_label):
            feature = _process_audio(data["audio"])
            ret = {
                "audio":
                feature,
                "audio_length":
                tf.cast(
                    (tf.shape(feature)[0] if isinstance(feature, tf.Tensor)
                     else feature.shape[0]) // self._audio_feature_dim //
                    self._audio_feature_channels,
                    dtype=tf.int64)
            }
            if with_label:
                ret["transcript"] = tf.convert_to_tensor(
                    _process_and_truncate_text(data["transcript"]), tf.int64)
            return ret

        if mode == compat.ModeKeys.INFER:
            return lambda data: data_proc(data, False)
        return lambda data: data_proc(data, True)
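
audio_length above recovers the frame count from a flattened feature vector: frames = len(audio) // (feature_dim * channels). A quick check with assumed dimensions (e.g. 80-dim log-mel features, one channel):

feature_dim, channels, frames = 80, 1, 120
flat_audio = [0.0] * (frames * feature_dim * channels)
assert len(flat_audio) // feature_dim // channels == frames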
Example #13
def intelligent_parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                            args_preload_func=_args_preload_from_config_files,
                            backend="tf"):
    """ Parses flags from argument parser.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable for pre-loading arguments, e.g., from a
            config file or a predefined hyper-parameter set.
        backend: The DL backend.
    """
    program_parsed_args, remaining_argv = arg_parser.parse_known_args()
    cfg_file_args = {}
    if args_preload_func is not None:
        cfg_file_args = args_preload_func(program_parsed_args)
    top_program_parsed_args = _flatten_args(flag_list,
                                            yaml_load_checking(program_parsed_args.__dict__))
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            if f.cls_key in top_program_parsed_args and top_program_parsed_args[f.cls_key]:
                cfg_file_args[f.cls_key] = top_program_parsed_args[f.cls_key]
    cfg_file_args = _flatten_args(flag_list, cfg_file_args)
    for f in flag_list:
        if isinstance(f, Flag):
            if top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = cfg_file_args.get(f.name, None)
            cfg_file_args.pop(f.name, None)
        else:
            submodule_cls = (top_program_parsed_args.get(f.cls_key, None)
                             or cfg_file_args.get(f.cls_key, None))
            cfg_file_args.pop(f.cls_key, None)
            if submodule_cls is None:
                continue
            top_program_parsed_args[f.cls_key] = submodule_cls
            if top_program_parsed_args.get(f.params_key, None) is None:
                top_program_parsed_args[f.params_key] = {}
            module_arg_parser = get_argparser(f.module_name, submodule_cls)
            module_parsed_args, remaining_argv = module_arg_parser.parse_known_args(remaining_argv)
            module_parsed_args = yaml_load_checking(module_parsed_args.__dict__)

            if hasattr(REGISTRIES[backend][f.module_name][submodule_cls], "class_or_method_args"):
                key_cfg_file_args = _flatten_args(
                    REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args(), cfg_file_args)
                for inner_f in REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args():
                    flag_key = inner_f.name
                    if isinstance(inner_f, ModuleFlag):
                        flag_key = inner_f.cls_key
                        cfg_file_args.pop(flag_key, None)
                    if module_parsed_args[flag_key] is not None:
                        top_program_parsed_args[f.params_key][flag_key] = module_parsed_args[flag_key]
                        top_program_parsed_args.pop(flag_key, None)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in top_program_parsed_args:
                        top_program_parsed_args[f.params_key][flag_key] = top_program_parsed_args.pop(flag_key)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in key_cfg_file_args:
                        top_program_parsed_args[f.params_key][flag_key] = key_cfg_file_args.pop(flag_key)
                        cfg_file_args.pop(flag_key, None)

                    if isinstance(inner_f, ModuleFlag):
                        top_program_parsed_args[f.params_key][inner_f.params_key] = deep_merge_dict(
                            cfg_file_args.pop(inner_f.params_key, {}) or {},
                            deep_merge_dict(top_program_parsed_args[f.params_key].pop(inner_f.params_key, {}) or {},
                                            deep_merge_dict(top_program_parsed_args.pop(inner_f.params_key, {}) or {},
                                                            module_parsed_args.pop(inner_f.params_key, {}) or {})))
    top_program_parsed_args = deep_merge_dict(cfg_file_args, top_program_parsed_args)
    for f in flag_list:
        if isinstance(f, Flag):
            if f.name not in top_program_parsed_args or top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = f.default
    return top_program_parsed_args, remaining_argv
Example #14
 def create_and_batch_tfds(self, ds, mode, args=None, num_replicas_in_sync=1):
     """ With efficient level for training. """
     if mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]:
         return super(Translation, self).create_and_batch_tfds(
             ds, mode, args, num_replicas_in_sync)
     if args is None:
         args = self._args
     else:
         args = deep_merge_dict(self._args, args, local_overwrite=False)
     level = args.get("gpu_efficient_level", None)
     auto_scale_batch = args.get("auto_scaling_batch_size", None)
     logging.info(f"Creating training dataset with GPU efficient level={level}.")
     dataset = ds.build(map_func=self.get_data_preprocess_fn(mode, ds.status, args),
                        map_output_dtypes=self.inputs_signature(mode)[0],
                        auto_shard=True, shuffle=True)
     dataset = dataset_utils.clean_dataset_by_length(
         dataset, {"feature": args["max_src_len"], "label": args["max_trg_len"]})
     if args["cache_dataset"]:
         dataset = dataset.cache()
     if args["shuffle_buffer"]:
         dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
     padding_values = {"feature": tf.constant(self._src_data_pipeline.meta["eos_id"], dtype=tf.int64),
                       "label": tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)}
     if args["max_src_len"] is None:
         raise RuntimeError("Must provide `max_src_len` for training.")
     if args["max_trg_len"] is None:
         raise RuntimeError("Must provide `max_trg_len` for training.")
     max_src_len = minimal_multiple(args["max_src_len"], EFFICIENT_MULTIPLIER[level])
     max_trg_len = minimal_multiple(args["max_trg_len"], EFFICIENT_MULTIPLIER[level])
     max_len = max(max_src_len, max_trg_len)
     batch_size = dataset_utils.adjust_batch_size(args["batch_size"], args["batch_size_per_gpu"],
                                                  num_replicas_in_sync=num_replicas_in_sync,
                                                  verbose=False)
     if auto_scale_batch:
         batch_size = _auto_scale_batch_size(batch_size, level)
         logging.info(f"Auto scaling batch size to {batch_size}.")
     if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
         _batch_size = batch_size
         if args["batch_by_tokens"]:
             _batch_size = _batch_size // max_len
         logging.info("Batching dataset with fixed shape: "
                      f"batch={_batch_size} x (feature={max_src_len}, label={max_trg_len}).")
         return dataset.padded_batch(
             _batch_size // num_replicas_in_sync * num_replicas_in_sync,
             padded_shapes={"feature": [max_src_len], "label": [max_trg_len]},
             drop_remainder=True, padding_values=padding_values)
     else:
         src_bucket_boundaries = [EFFICIENT_MULTIPLIER[level] * i for i in
                                  range(1, max_src_len // EFFICIENT_MULTIPLIER[level] + 1)]
         if src_bucket_boundaries[-1] < max_src_len:
             src_bucket_boundaries.append(minimal_multiple(src_bucket_boundaries[-1] + 1,
                                                           EFFICIENT_MULTIPLIER[level]))
         trg_bucket_boundaries = [EFFICIENT_MULTIPLIER[level] * i for i in
                                  range(1, max_trg_len // EFFICIENT_MULTIPLIER[level] + 1)]
         if trg_bucket_boundaries[-1] < max_trg_len:
             trg_bucket_boundaries.append(minimal_multiple(trg_bucket_boundaries[-1] + 1,
                                                           EFFICIENT_MULTIPLIER[level]))
         src_bucket_boundaries, trg_bucket_boundaries = dataset_utils.associated_bucket_boundaries(
             src_bucket_boundaries, trg_bucket_boundaries)
         bucket_boundaries = {
             "feature": src_bucket_boundaries,
             "label": trg_bucket_boundaries
         }
         bucket_batch_sizes = dataset_utils.adjust_batch_size(
             batch_size,
             bucket_boundaries=bucket_boundaries if args["batch_by_tokens"] else None,
             boundaries_reduce_to_length_fn=lambda x: max(tf.nest.flatten(x)),
             num_replicas_in_sync=num_replicas_in_sync)
         if level != GPU_EFFICIENT_LEVEL.LEVEL0:
             if isinstance(bucket_batch_sizes, list):
                 bucket_batch_sizes = [
                     int(maximum_lower_multiple(x // num_replicas_in_sync,
                                                EFFICIENT_MULTIPLIER[level]) * num_replicas_in_sync)
                     for x in bucket_batch_sizes]
             else:
                 bucket_batch_sizes = int(maximum_lower_multiple(
                     bucket_batch_sizes // num_replicas_in_sync,
                     EFFICIENT_MULTIPLIER[level]) * num_replicas_in_sync)
         return dataset_utils.batch_examples_by_token(
             dataset,
             bucket_boundaries=bucket_boundaries,
             bucket_batch_sizes=bucket_batch_sizes,
             padding_values=padding_values,
             example_length_func=lambda x: {k: tf.size(v) for k, v in x.items()}
         )
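
The LEVEL5 static-batch branch above reduces to simple arithmetic: with batch_by_tokens, a token budget becomes a fixed example count, floored to a multiple of the replica count. Hypothetical numbers:

batch_size, max_len, num_replicas_in_sync = 32768, 128, 8
_batch_size = batch_size // max_len                          # 256 examples per step
_batch_size = _batch_size // num_replicas_in_sync * num_replicas_in_sync
assert _batch_size == 256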
Example #15
    def create_and_batch_tfds(self,
                              ds,
                              mode,
                              args=None,
                              num_replicas_in_sync=1):
        """ With efficient level for training. """
        if args is None:
            args = self._args
        else:
            args = deep_merge_dict(self._args, args)
        level = args.get("gpu_efficient_level", None)
        auto_scale_batch = args.get("auto_scaling_batch_size", None)
        if (mode in [compat.ModeKeys.INFER, compat.ModeKeys.EVAL]
                or level is None or level == GPU_EFFICIENT_LEVEL.LEVEL0):
            return super(Translation,
                         self).create_and_batch_tfds(ds, mode, args,
                                                     num_replicas_in_sync)
        padding_values = {
            "feature":
            tf.constant(self._src_data_pipeline.meta["eos_id"],
                        dtype=tf.int64),
            "label":
            tf.constant(self._trg_data_pipeline.meta["eos_id"], dtype=tf.int64)
        }
        dataset = ds.build(auto_shard=True,
                           map_func=self.get_data_preprocess_fn(
                               mode, ds.status, args),
                           map_output_dtypes=self.inputs_signature(mode)[0])
        max_src_len = args["max_src_len"]
        max_trg_len = args["max_trg_len"]
        batch_by_tokens = args["batch_by_tokens"]
        assert max_src_len, "Must provide `max_src_len` when `gpu_efficient_level` > 0"
        assert max_trg_len, "Must provide `max_trg_len` when `gpu_efficient_level` > 0"
        logging.info(
            f"Creating training dataset with `gpu_efficient_level`={level}.")
        dataset = clean_dataset_by_length(dataset, {
            "feature": max_src_len,
            "label": max_trg_len
        })
        if args["cache_dataset"]:
            dataset = dataset.cache()
        if args["shuffle_buffer"]:
            dataset = dataset.shuffle(buffer_size=args["shuffle_buffer"])
        batch_size_per_gpu = args["batch_size_per_gpu"]
        batch_size = args["batch_size"]
        if batch_size_per_gpu is None:
            batch_size_per_gpu = batch_size // num_replicas_in_sync
        if batch_by_tokens:
            assert batch_size_per_gpu > max(max_src_len, max_trg_len), (
                f"batch size per gpu({batch_size_per_gpu} must be greater than "
                f"both `max_src_len`{max_src_len} and `max_trg_len`{max_trg_len}"
            )
        if auto_scale_batch:
            new_batch_size_per_gpu = _auto_scale_batch_size(
                batch_size_per_gpu, level)
            logging.info(
                f"Auto scaling `batch_size_per_gpu` from {batch_size_per_gpu} "
                f"to {new_batch_size_per_gpu}")
            batch_size_per_gpu = new_batch_size_per_gpu
        max_src_len = minimal_multiple(max_src_len,
                                       EFFICIENT_MULTIPLIER[level])
        max_trg_len = minimal_multiple(max_trg_len,
                                       EFFICIENT_MULTIPLIER[level])
        max_len = max(max_src_len, max_trg_len)
        if level == GPU_EFFICIENT_LEVEL.LEVEL5:  # static batch
            if batch_by_tokens:
                batch_size_per_gpu = batch_size_per_gpu // max_len
            return dataset.padded_batch(int(
                minimal_multiple(batch_size_per_gpu,
                                 EFFICIENT_MULTIPLIER[level]) *
                num_replicas_in_sync),
                                        padded_shapes={
                                            "feature": [max_src_len],
                                            "label": [max_trg_len]
                                        },
                                        drop_remainder=True,
                                        padding_values=padding_values)
        else:
            bucket_boundaries = [
                EFFICIENT_MULTIPLIER[level] * i
                for i in range(1, max_len // EFFICIENT_MULTIPLIER[level] + 1)
            ]
            if bucket_boundaries[-1] < max_len:
                bucket_boundaries.append(
                    minimal_multiple(bucket_boundaries[-1] + 1,
                                     EFFICIENT_MULTIPLIER[level]))
            buckets_min = [0] + bucket_boundaries[:-1]
            if batch_by_tokens:
                bucket_batch_sizes = [
                    int(
                        minimal_multiple(batch_size_per_gpu // bound,
                                         EFFICIENT_MULTIPLIER[level]) *
                        num_replicas_in_sync) for bound in bucket_boundaries
                ]
            else:
                bucket_batch_sizes = [
                    int(
                        minimal_multiple(batch_size_per_gpu,
                                         EFFICIENT_MULTIPLIER[level]) *
                        num_replicas_in_sync)
                ] * len(bucket_boundaries)

            logging.info(
                f"There are {len(bucket_batch_sizes)} input shapes to be compiled:"
            )
            for batc, bound in zip(bucket_batch_sizes, bucket_boundaries):
                logging.info(f"   - batch={batc}, maximum-length={bound}")
            bucket_batch_sizes = tf.constant(bucket_batch_sizes,
                                             dtype=tf.int64)
            bucket_boundaries = tf.constant(bucket_boundaries, dtype=tf.int32)

            def example_to_bucket_id(examples):
                """Return int64 bucket id for this example, calculated based on length."""
                seq_length = tf.cast(
                    tf.maximum(tf.size(examples["feature"]),
                               tf.size(examples["label"])), tf.int32)

                conditions_c = tf.logical_and(
                    tf.less(buckets_min, seq_length),
                    tf.less_equal(seq_length, bucket_boundaries))
                bucket_id = tf.reduce_min(tf.where(conditions_c))
                return bucket_id

            def window_size_fn(bucket_id):
                """Return number of examples to be grouped when given a bucket id."""
                return bucket_batch_sizes[bucket_id]

            def batching_fn(bucket_id, grouped_dataset):
                """Batch and add padding to a dataset of elements with similar lengths."""
                bucket_batch_size = window_size_fn(bucket_id)

                # Batch the dataset and add padding so that all input sequences in the
                # examples have the same length, and all target sequences have the same
                # lengths as well. Resulting lengths of inputs and targets can differ.
                return grouped_dataset.padded_batch(
                    bucket_batch_size,
                    padded_shapes={
                        "feature": [bucket_boundaries[bucket_id]],
                        "label": [bucket_boundaries[bucket_id]]
                    },
                    padding_values=padding_values,
                    drop_remainder=True)

            return dataset.apply(
                tf.data.experimental.group_by_window(
                    key_func=example_to_bucket_id,
                    reduce_func=batching_fn,
                    window_size=None,
                    window_size_func=window_size_fn))
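
example_to_bucket_id above, rendered in plain Python: pick the first bucket whose (min, max] range contains the longer of the feature/label lengths. A sketch, not the TensorFlow version:

def example_to_bucket_id(feature_len, label_len, buckets_min, bucket_boundaries):
    seq_length = max(feature_len, label_len)
    for bucket_id, (lo, hi) in enumerate(zip(buckets_min, bucket_boundaries)):
        if lo < seq_length <= hi:
            return bucket_id
    raise ValueError("Sequence longer than the largest bucket boundary.")

assert example_to_bucket_id(10, 14, buckets_min=[0, 8, 16],
                            bucket_boundaries=[8, 16, 24]) == 1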