def test_api(temp_output_dir, train_opts, nuqe_opts, atol):
    """Train a NuQE model through the Python API and sanity-check predictions.

    The merged train/model options are written to ``config.yaml`` inside
    ``temp_output_dir``, training is launched from that file, the resulting
    model is loaded back and used to predict on the test files referenced by
    ``nuqe_opts``.

    Args:
        temp_output_dir: Directory in which the config file is written.
        train_opts: Training options; ``model`` and
            ``checkpoint_keep_only_best`` are overwritten on this object.
        nuqe_opts: Model options providing ``test_source``, ``test_target``
            and ``test_alignments`` file paths.
        atol: Absolute tolerance for the comparison against the reference
            mean prediction.
    """
    from kiwi import train, load_model

    def read_lines(path):
        # Use a context manager so each handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(path) as handle:
            return handle.readlines()

    train_opts.model = 'nuqe'
    train_opts.checkpoint_keep_only_best = 1
    all_opts = merge_namespaces(train_opts, nuqe_opts)

    config_file = Path(temp_output_dir, 'config.yaml')
    save_config_file(all_opts, config_file)

    train_run_info = train(config_file)
    predicter = load_model(train_run_info.model_path)

    examples = {
        constants.SOURCE: read_lines(nuqe_opts.test_source),
        constants.TARGET: read_lines(nuqe_opts.test_target),
        constants.ALIGNMENTS: read_lines(nuqe_opts.test_alignments),
    }

    predictions = predicter.predict(examples, batch_size=train_opts.batch_size)
    predictions = predictions[constants.TARGET_TAGS]

    # Per-sentence means first, then the mean over sentences.
    avg_of_avgs = np.mean(list(map(np.mean, predictions)))
    max_prob = max(map(max, predictions))
    min_prob = min(map(min, predictions))

    np.testing.assert_allclose(avg_of_avgs, 0.572441, atol=atol)
    # Every predicted value must lie in [0, 1], with the mean in between.
    assert 0 <= min_prob <= avg_of_avgs <= max_prob <= 1
def train_from_options(options):
    """Run the whole training pipeline for an already-parsed configuration.

    Args:
        options (Namespace): Configuration exposing ``pipeline`` (pipeline
            options), ``model`` (model options) and ``model_api`` (the model
            class handed to ``run``). May be ``None``.

    Returns:
        ``None`` when ``options`` is ``None``; otherwise the ``TrainRunInfo``
        built from the trainer returned by ``run``.
    """
    if options is None:
        return None

    pipeline_opts = options.pipeline
    model_opts = options.model
    model_class = options.model_api

    tracking_kwargs = {
        'run_uuid': pipeline_opts.run_uuid,
        'experiment_name': pipeline_opts.experiment_name,
        'run_name': pipeline_opts.run_name,
        'tracking_uri': pipeline_opts.mlflow_tracking_uri,
        'always_log_artifacts': pipeline_opts.mlflow_always_log_artifacts,
    }

    # Everything up to building the run info happens inside the tracking run.
    with tracking_logger.configure(**tracking_kwargs):
        run_dir = setup(
            output_dir=pipeline_opts.output_dir,
            seed=pipeline_opts.seed,
            gpu_id=pipeline_opts.gpu_id,
            debug=pipeline_opts.debug,
            quiet=pipeline_opts.quiet,
        )

        merged = merge_namespaces(pipeline_opts, model_opts)
        log(
            run_dir,
            config_options=vars(merged),
            save_config=pipeline_opts.save_config,
        )

        train_info = TrainRunInfo(
            run(model_class, run_dir, pipeline_opts, model_opts)
        )

    teardown(pipeline_opts)
    return train_info
def run_from_options(options):
    """Run the pipeline described by ``options`` under a tracking run.

    Args:
        options (Namespace): Configuration exposing ``meta`` (must provide
            ``splits``) and ``pipeline``; the latter in turn exposes
            ``pipeline``, ``model`` and ``model_api``. May be ``None``, in
            which case nothing is done.

    Returns:
        None.
    """
    if options is None:
        return None

    meta_opts = options.meta
    nested = options.pipeline
    pipeline_opts = nested.pipeline
    model_opts = nested.model
    model_class = nested.model_api

    tracking_kwargs = {
        'run_uuid': pipeline_opts.run_uuid,
        'experiment_name': pipeline_opts.experiment_name,
        'run_name': pipeline_opts.run_name,
        'tracking_uri': pipeline_opts.mlflow_tracking_uri,
        'always_log_artifacts': pipeline_opts.mlflow_always_log_artifacts,
    }

    with tracking_logger.configure(**tracking_kwargs):
        run_dir = train.setup(
            output_dir=pipeline_opts.output_dir,
            debug=pipeline_opts.debug,
            quiet=pipeline_opts.quiet,
        )

        # Meta, pipeline and model options are logged together in one file.
        merged = merge_namespaces(meta_opts, pipeline_opts, model_opts)
        train.log(
            run_dir,
            config_options=vars(merged),
            config_file_name='jackknife_config.yml',
        )

        run(
            model_class,
            run_dir,
            pipeline_opts,
            model_opts,
            splits=meta_opts.splits,
        )

    teardown(pipeline_opts)
def parse(self, args):
    """Parse pipeline options and, if available, model-specific options.

    Args:
        args (list): Command-line style argument strings.

    Returns:
        ``None`` when ``args`` is exactly one help flag (the pipeline
        parser's help is printed). Otherwise a ``Namespace`` with:
        ``pipeline`` (pipeline options), ``model`` (model options or
        ``None``), ``model_api`` (the selected model parser's
        ``api_module`` or ``None``) and ``all_options`` (pipeline and model
        options merged).

    Raises:
        KeyError: If the selected model has no registered parser, or if any
            argument is left unrecognized after parsing.
    """
    if len(args) == 1 and args[0] in ('-h', '--help'):
        self._parser.print_help()
        return None

    # Parse train pipeline options
    pipeline_options, extra_args = self._parser.parse_known_args(args)
    config_option, _ = self._config_option_parser.parse_known_args(args)

    options = Namespace()
    options.pipeline = pipeline_options
    options.model = None
    options.model_api = None

    # Parse specific model options if there are model parsers
    if self._models is not None:
        if pipeline_options.model not in self._models:
            raise KeyError(
                'Invalid model: {}'.format(pipeline_options.model)
            )

        # A parsed namespace is always truthy, so check the actual value:
        # forwarding ``--config None`` would put a non-string into the
        # argument list handed to the model parser.
        if config_option is not None and config_option.config is not None:
            extra_args = ['--config', config_option.config] + extra_args

        model_parser = self._models[pipeline_options.model]
        model_options, remaining_args = model_parser.parse_known_args(
            extra_args
        )

        options.model = model_options
        # Retrieve the respective API for the selected model
        options.model_api = model_parser.api_module
    else:
        remaining_args = extra_args

    # ``options.model`` stays None when no model parsers are registered;
    # only real namespaces are merged.
    # NOTE(review): assumes merge_namespaces does not accept None - confirm.
    namespaces = [
        ns for ns in (options.pipeline, options.model) if ns is not None
    ]
    options.all_options = merge_namespaces(*namespaces)

    if remaining_args:
        raise KeyError('Unrecognized options: {}'.format(remaining_args))

    return options