def __init__(self, args):
    """ Initializes the dataset. """
    super(MixedTrainDataset, self).__init__()
    self._data_files = args["data_files"]
    if isinstance(self._data_files, str):
        self._data_files = yaml.load(self._data_files, Loader=yaml.FullLoader)
    assert isinstance(self._data_files, dict)
    self._data_sampler = build_data_sampler(args)
    common_properties = args["common_properties"]
    if common_properties is None:
        common_properties = {}
    elif isinstance(common_properties, str):
        common_properties = yaml.load(common_properties, Loader=yaml.FullLoader)
    assert isinstance(common_properties, dict)
    self._custom_dss = dict()
    self._status = None
    for name, ds in self._data_files.items():
        self._custom_dss[name] = build_dataset(
            args["data_class"], **ds, **common_properties)
        if self._status is None:
            self._status = self._custom_dss[name].status
        else:
            assert self._status == self._custom_dss[name].status, (
                "The status of each dataset is supposed to be the same.")
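
# A minimal sketch of the args dict that __init__ above consumes, assuming the
# upstream flag parser yields a plain mapping. Only the key names appear in the
# code above; every value here (dataset names, the "data_path" kwarg, the
# class name, the common properties) is an illustrative assumption, not a
# value fixed by this repository.
_example_mixed_args = {
    "data_files": {
        # Each entry expands to build_dataset(data_class, **entry, **common_properties).
        "corpus_a": {"data_path": "/path/to/corpus_a"},
        "corpus_b": {"data_path": "/path/to/corpus_b"},
    },
    "data_class": "ParallelTextDataset",
    # May also be given as a YAML string; merged into every per-dataset config.
    "common_properties": {"shuffle_dataset": True},
}
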
def run_experiment(args, remaining_argv):
    strategy = training_utils.handle_distribution_strategy(
        args["distribution_strategy"])
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    training_utils.startup_env(
        dtype=args["dtype"],
        enable_check_numerics=args["enable_check_numerics"],
        enable_xla=args["enable_xla"])
    # Initialize parameters for quantization.
    if args.get("quant_params", None) is None:
        args["quant_params"] = {}
    QuantLayer.global_init(args["enable_quant"], **args["quant_params"])
    # Create exps: trainer, evaluator, etc.
    with training_utils.get_strategy_scope(strategy):
        task = build_task(args)
        custom_dataset = build_dataset(args)
        try:
            model = task.build_model(args)
            training_utils.validate_unique_varname(model.weights)
        except AttributeError:
            model = None
        entry = build_exp(args,
                          strategy=strategy,
                          model=model,
                          task=task,
                          model_dir=args["model_dir"],
                          custom_dataset=custom_dataset)
        entry.run()
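
# A hedged sketch of the args dict run_experiment above reads. Only the key
# names are taken from the code above; all values are illustrative. Note that
# a missing or None "quant_params" is normalized to {} before being splatted
# into QuantLayer.global_init.
_example_experiment_args = {
    "distribution_strategy": "mirrored",
    "dtype": "float16",
    "enable_check_numerics": False,
    "enable_xla": False,
    "enable_quant": False,
    "quant_params": None,
    "model_dir": "/tmp/model_dir",
}
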
def __init__(self, args):
    """ Initializes the multiple dataset.

    Args:
        args: containing `multiple_datasets`, which is like
            {
                "data0": {"dataset.class": "", "dataset.params": ""},
                "data1": {"dataset.class": "", "dataset.params": ""},
                ......
            }
    """
    super(MultipleDataset, self).__init__()
    self._datasets = {name: build_dataset(dsargs)
                      for name, dsargs in args["multiple_datasets"].items()}
    self._sample_weights = dict()
    if args["sample_weights"]:
        assert isinstance(args["sample_weights"], dict)
    else:
        args["sample_weights"] = {}
    weight_sum = 0.
    for name in self._datasets:
        self._sample_weights[name] = args["sample_weights"].get(name, 1.)
        weight_sum += self._sample_weights[name]
    for name in self._datasets:
        self._sample_weights[name] /= weight_sum
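
# Worked example of the sample-weight normalization above: datasets missing
# from sample_weights default to 1.0, then every weight is divided by the sum.
# Dataset names and weights are illustrative.
_raw_weights = {"data0": 2.0}        # user-provided sample_weights
_dataset_names = ["data0", "data1"]  # keys of multiple_datasets
_weights = {n: _raw_weights.get(n, 1.0) for n in _dataset_names}
_total = sum(_weights.values())      # 3.0
_weights = {n: w / _total for n, w in _weights.items()}
assert _weights == {"data0": 2.0 / 3.0, "data1": 1.0 / 3.0}
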
def _main(_):
    # Define and parse program flags.
    arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=True)
    args, remaining_argv = flags_core.intelligent_parse_flags(
        FLAG_LIST, arg_parser)
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    dataset = build_dataset(args)
    feature_extractor = build_feature_extractor(args)
    if dataset is None:
        raise ValueError("dataset must be provided.")
    main(dataset, feature_extractor)
def _main(_):
    # Define and parse program flags.
    arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=True)
    args, remaining_argv = flags_core.intelligent_parse_flags(
        FLAG_LIST, arg_parser)
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    dataset = build_dataset(args)
    if dataset is None:
        raise ValueError("dataset must be provided.")
    main(dataset=dataset,
         output_transcript_file=args["output_transcript_file"],
         output_translation_file=args["output_translation_file"])
def build(self, strategy, task, model):
    """ Initializes. """
    self._strategy = strategy
    self._criterion: Criterion = build_criterion(
        self.args["eval_criterion.class"],
        **self.args["eval_criterion.params"])
    # Bail out early if no criterion was configured; otherwise set_model
    # would be called on None.
    if self._criterion is None:
        logging.info("WARNING: no criterion is provided in CriterionValidator "
                     "for validation process.")
        self._validate_criterion = False
        return self
    self._criterion.set_model(model)
    self._custom_dataset = build_dataset(
        self.args["eval_dataset.class"],
        **self.args["eval_dataset.params"])
    if self._custom_dataset is None:
        logging.info("WARNING: no validation dataset is provided "
                     "in CriterionValidator for validation process.")
        self._validate_criterion = False
        return self
    from neurst.exps.evaluator import Evaluator
    with training_utils.get_strategy_scope(strategy):
        self._criterion_model = Evaluator.build_evaluation_model(
            task, model, self._criterion)
        self._eval_tfds = training_utils.build_datasets(
            compat.ModeKeys.EVAL, strategy, self._custom_dataset,
            task, True, self._eval_task_args)
    self._criterion_metric = self._criterion.as_metric()
    if isinstance(self._custom_dataset, MultipleDataset):
        self._criterion_recorder = {
            name: training_utils.TrainingStatusRecorder(
                model=model, task=task, metric=self._criterion_metric)
            for name in self._custom_dataset.datasets
        }
        self._avg_criterion_recorder = training_utils.TrainingStatusRecorder(
            model=model, task=task, metric=self._criterion_metric)
        self._mixed_criterion_recorder = training_utils.TrainingStatusRecorder(
            model=model, task=task, metric=self._criterion_metric)
    else:
        self._criterion_recorder = training_utils.TrainingStatusRecorder(
            model=model, task=task, metric=self._criterion_metric)
    self._criterion_start_time = time.time()
    return self
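
# For MultipleDataset, build() above keeps three kinds of recorders: one per
# sub-dataset, one aggregate over sub-datasets, and one for the mixed stream.
# Assuming the aggregate is the sample-weight-average of per-dataset criterion
# values (an assumption; the code above only creates the recorders), a
# self-contained sketch of that reduction:
_per_dataset_criterion = {"data0": 2.4, "data1": 3.0}  # illustrative values
_sample_weights = {"data0": 2.0 / 3.0, "data1": 1.0 / 3.0}
_avg = sum(v * _sample_weights[n] for n, v in _per_dataset_criterion.items())
assert abs(_avg - 2.6) < 1e-9  # 2.4 * 2/3 + 3.0 * 1/3
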
def _main(_):
    # Define and parse program flags.
    arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=True)
    args, remaining_argv = flags_core.intelligent_parse_flags(
        FLAG_LIST, arg_parser)
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    task = build_task(args)
    dataset = build_dataset(args)
    if dataset is None:
        raise ValueError("dataset must be provided.")
    main(processor_id=args["processor_id"],
         num_processors=args["num_processors"],
         num_output_shards=args["num_output_shards"],
         output_range_begin=args["output_range_begin"],
         output_range_end=args["output_range_end"],
         output_template=args["output_template"],
         progressbar=args["progressbar"],
         dataset=dataset,
         task=task)
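
# One plausible reading of the sharding flags above, assuming output_template
# takes (shard_id, num_output_shards) and that shards in
# [output_range_begin, output_range_end) are handled by this processor. Both
# the template string and the flag values are assumed examples, not values
# fixed by this code.
_output_template = "train.tfrecords-%5.5d-of-%5.5d"
_num_output_shards = 8
for _shard_id in range(0, 2):  # output_range_begin=0, output_range_end=2
    print(_output_template % (_shard_id, _num_output_shards))
# -> train.tfrecords-00000-of-00008
# -> train.tfrecords-00001-of-00008
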
def run_experiment(args, remaining_argv):
    strategy = training_utils.handle_distribution_strategy(
        args["distribution_strategy"])
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    training_utils.startup_env(
        dtype=args["dtype"],
        enable_check_numerics=args["enable_check_numerics"],
        enable_xla=args["enable_xla"])
    # Create exps: trainer, evaluator, etc.
    with training_utils.get_strategy_scope(strategy):
        task = build_task(args)
        custom_dataset = build_dataset(args)
        try:
            model = task.build_model(args)
        except AttributeError:
            model = None
        entry = build_exp(args,
                          strategy=strategy,
                          model=model,
                          task=task,
                          model_dir=args["model_dir"],
                          custom_dataset=custom_dataset)
        entry.run()
def _main(_):
    # Define and parse program flags.
    arg_parser = flags_core.define_flags(FLAG_LIST)
    args, remaining_argv = flags_core.parse_flags(FLAG_LIST, arg_parser)
    flags_core.verbose_flags(FLAG_LIST, args, remaining_argv)
    strategy = training_utils.handle_distribution_strategy(
        args["distribution_strategy"])
    training_utils.startup_env(
        dtype=args["dtype"],
        enable_xla=False,
        enable_check_numerics=args["enable_check_numerics"])
    asr_task, asr_model = _build_task_model(
        strategy, args["asr_model_dir"], batch_size=args["batch_size"])
    mt_task, mt_model = _build_task_model(
        strategy, args["mt_model_dir"], batch_size=args["batch_size"])
    audio_dataset = build_dataset(args)

    # ========= ASR ==========
    asr_output_file = args["asr_output_file"]
    if asr_output_file is None:
        asr_output_file = "ram://asr_output_file"
    logging.info("Creating ASR generator.")
    with training_utils.get_strategy_scope(strategy):
        asr_generator = build_exp(
            {
                "class": SequenceGenerator,
                "params": {
                    "output_file": asr_output_file,
                    "search_method.class": args["asr_search_method.class"],
                    "search_method.params": args["asr_search_method.params"],
                }
            },
            strategy=strategy,
            model=asr_model,
            task=asr_task,
            model_dir=args["asr_model_dir"],
            custom_dataset=audio_dataset)
    asr_generator.run()
    if hasattr(audio_dataset, "transcripts") and audio_dataset.transcripts is not None:
        asr_metric = asr_task.get_eval_metric(args, "asr_metric")
        with tf.io.gfile.GFile(asr_output_file, "r") as fp:
            metric_result = asr_metric([line.strip() for line in fp],
                                       audio_dataset.transcripts)
        logging.info("Evaluation Result of ASR:")
        for k, v in metric_result.items():
            logging.info("   %s=%.2f", k, v)

    # ========= MT ==========
    logging.info("Creating MT generator.")
    mt_reference_file = "ram://mt_reference_file"
    with tf.io.gfile.GFile(mt_reference_file, "w") as fw:
        for x in audio_dataset.targets:
            fw.write(x.strip() + "\n")
    with training_utils.get_strategy_scope(strategy):
        mt_generator = build_exp(
            {
                "class": SequenceGenerator,
                "params": {
                    "output_file": args["mt_output_file"],
                    "search_method.class": args["mt_search_method.class"],
                    "search_method.params": args["mt_search_method.params"],
                    "metric.class": args["mt_metric.class"],
                    "metric.params": args["mt_metric.params"]
                }
            },
            strategy=strategy,
            model=mt_model,
            task=mt_task,
            model_dir=args["mt_model_dir"],
            custom_dataset=build_dataset({
                "class": ParallelTextDataset,
                "params": {
                    "src_file": asr_output_file,
                    "trg_file": mt_reference_file
                }
            }))
    mt_generator.run()
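
# The "ram://" paths above use TensorFlow's in-memory filesystem, so the ASR
# hypotheses and MT references are handed between the two cascade stages
# without touching disk. A minimal, runnable sketch of that round trip (the
# path name is illustrative):
import tensorflow as tf

with tf.io.gfile.GFile("ram://demo_cascade_file", "w") as fw:
    fw.write("hello world\n")
with tf.io.gfile.GFile("ram://demo_cascade_file", "r") as fp:
    assert fp.read() == "hello world\n"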