def configure(data_dir, run_dir, config, override_values, xm_parameters=None):
  """Sets up the Learned Interpreter code with the specified configuration.

  Applies command-line/launcher overrides, reconciles the run's config with
  any pre-existing "original run" config (from a specified checkpoint or an
  existing run_dir), optionally persists the config, and constructs the
  dataset, model, and adapter for the run.

  Args:
    data_dir: Directory containing the datasets.
    run_dir: Directory for this run's artifacts (config.json, checkpoints).
    config: The config object for the run (supports `.update`,
      `.update_from_flattened_dict`, etc.).
    override_values: Flattened dict of config values applied on top of any
      named overrides.
    xm_parameters: Unused; accepted for launcher compatibility.

  Returns:
    A RunConfiguration bundling mode, method, directories, checkpoint path,
    model, dataset info, and adapter.

  Raises:
    ValueError: If model_config or dataset_config is 'load' but no original
      config file exists.
  """
  del xm_parameters  # Unused.
  seed()

  # Apply any overrides set at the command line or in the launcher.
  if config.overrides != config.default_overrides:
    logging.info('Applying overrides set at command line: %s',
                 config.overrides)
    overrides_lib.apply_overrides(
        config, override_names=config.overrides.split(','))
  config.update_from_flattened_dict(override_values)

  # If a checkpoint is specified, it determines the "original run."
  # Otherwise the run_dir, if already present, determines the "original run."
  config_filepath = os.path.join(run_dir, 'config.json')
  if checkpoint_utils.is_checkpoint_specified(config.checkpoint):
    original_checkpoint_path = checkpoint_utils.get_specified_checkpoint_path(
        run_dir, config.checkpoint)
    original_run_dir = checkpoint_utils.get_run_dir(original_checkpoint_path)
    original_config_filepath = os.path.join(original_run_dir, 'config.json')
  else:
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    original_checkpoint_path = checkpoint_utils.latest_checkpoint(
        checkpoint_dir)
    original_config_filepath = config_filepath
  original_config_exists = gfile.exists(original_config_filepath)

  # Handle any existing configs.
  if original_config_exists:
    original_config = config_utils.load_config(original_config_filepath)

    # Handle the model config.
    if config.runner.model_config == 'load':
      logging.info('Loading the model config from %s',
                   original_config_filepath)
      config.model.update(original_config.model)
      # The representation is tied to the model, so it travels with it.
      config.dataset.representation = original_config.dataset.representation
    elif config.runner.model_config == 'assert':
      same_config = config_utils.equals(config.model, original_config.model)
      # Resolution:
      # Either use a new run_dir, or set model_config to 'load' or 'keep'.
      assert same_config, 'Model config has changed.'
    else:
      assert config.runner.model_config == 'keep'

    # Handle the dataset config.
    if config.runner.dataset_config == 'load':
      logging.info('Loading the data config from %s', original_config_filepath)
      config.dataset.update(original_config.dataset)
    elif config.runner.dataset_config == 'assert':
      same_config = config_utils.equals(config.dataset,
                                        original_config.dataset)
      assert same_config, 'Dataset config has changed.'
    else:
      assert config.runner.dataset_config == 'keep'
  elif (config.runner.model_config == 'load'
        or config.runner.dataset_config == 'load'):
    raise ValueError('Original model config not found.')

  # In interactive mode, force batch size 1.
  if config.runner.mode == 'interact':
    config.dataset.batch_size = 1

  config_exists = gfile.exists(config_filepath)
  # BUG FIX: was `config.runner.mode in 'train'` — a substring test that is
  # also True for '', 't', 'rain', etc. Equality is what was intended.
  if not config_exists and config.runner.mode == 'train':
    gfile.makedirs(run_dir)
    config_utils.save_config(config, config_filepath)

  # Load dataset.
  if config.setup.setup_dataset:
    dataset_info = dataset_utils.get_dataset(data_dir, config)
    info = dataset_info.info
  else:
    dataset_info = None
    info = None

  # Create model.
  if config.setup.setup_model:
    model = models_lib.get_model(info, config)
  else:
    model = None

  adapter = adapters_lib.get_default_adapter(info, config)
  return RunConfiguration(
      mode=config.runner.mode,
      method=config.runner.method,
      run_dir=run_dir,
      data_dir=data_dir,
      original_checkpoint_path=original_checkpoint_path,
      model=model,
      info=info,
      config=config,
      adapter=adapter,
      dataset_info=dataset_info)
def analyze_once(run_configuration):
  """Analyzes the existing model checkpoint.

  Runs inference for the model for each of the evaluation datasets, and
  writes the metrics to disk as JSON.

  Args:
    run_configuration: The setup.RunConfiguration for the run.
  """
  analysis_run_dir = run_configuration.run_dir
  original_checkpoint_path = run_configuration.original_checkpoint_path
  config = run_configuration.config
  data_dir = run_configuration.data_dir
  adapter = run_configuration.adapter
  optimizer = adapter.create_optimizer(run_configuration)

  # Lay out the analysis directory structure next to a snapshot of the
  # checkpoint being analyzed.
  original_run_dir = checkpoint_utils.get_run_dir(original_checkpoint_path)
  original_run_name = os.path.basename(original_run_dir)
  original_config_path = os.path.join(original_run_dir, 'config.json')
  checkpoint_basename = os.path.basename(original_checkpoint_path)
  analysis_dir = os.path.join(analysis_run_dir, checkpoint_basename)
  analysis_file = os.path.join(analysis_dir, 'data.json')
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(analysis_dir)
  checkpoint_path = os.path.join(checkpoint_dir, checkpoint_basename)
  gfile.makedirs(checkpoint_dir)
  gfile.makedirs(analysis_dir)

  # Save a copy of the original checkpoint.
  # overwrite=True so a preempted-and-restarted run does not crash on the
  # already-copied file (consistent with the config copy below).
  gfile.copy(original_checkpoint_path, checkpoint_path, overwrite=True)

  # Save analysis metadata.
  metadata = {
      'name': original_run_name,
      'run_dir': original_run_dir,
      'timestamp': datetime.datetime.now().timestamp(),
  }
  metadata_path = os.path.join(analysis_dir, 'metadata.json')
  with gfile.GFile(metadata_path, 'w') as f:
    f.write(json.dumps(metadata))

  # Save a copy of the original config.
  new_config_path = os.path.join(analysis_run_dir, 'config.json')
  # We set overwrite=True to handle preemption.
  gfile.copy(original_config_path, new_config_path, overwrite=True)

  logging.info('Saving results to %s', analysis_file)

  # Load the datasets to analyze.
  analysis_results = []
  eval_dataset_names = config.launcher.eval_dataset_names.split(',')
  for dataset_name in eval_dataset_names:
    logging.info('Evaluating with checkpoint_path: %s', checkpoint_path)
    logging.info('Evaluating on dataset: %s', dataset_name)
    dataset_info = dataset_utils.get_dataset(data_dir, config, dataset_name)
    # Point the run configuration at the analysis dir and checkpoint snapshot
    # for this evaluation pass.
    run_configuration = dataclasses.replace(
        run_configuration,
        run_dir=analysis_dir,
        original_checkpoint_path=checkpoint_path,
        info=dataset_info.info,
        dataset_info=dataset_info,
    )
    metrics = workflows.predict_once(run_configuration, optimizer)
    logging.info('Done evaluating on dataset: %s', dataset_name)
    results = {
        'dataset_name': dataset_name,
        # .tolist() converts array metrics into JSON-serializable lists.
        'accuracy': metrics['accuracy'].tolist(),
        'denominator': metrics['denominator'].tolist(),
    }
    analysis_results.append(results)

  # BUG FIX: was opened 'wb'; json.dump writes str, which raises TypeError on
  # a binary-mode file in Python 3. Text mode matches the metadata write.
  with gfile.GFile(analysis_file, 'w') as f:
    json.dump(analysis_results, f)