def dry_run_from_params(params: Params, serialization_dir: str) -> None:
    """Run all training setup steps (vocabulary creation, dataset indexing,
    model construction, parameter freezing) without actually training.

    Useful for validating a configuration and inspecting vocabulary/dataset
    statistics before committing to a full training run.

    Parameters
    ----------
    params : Params
        The experiment configuration. Consumed destructively via ``pop``.
    serialization_dir : str
        Directory where the created vocabulary is written (under a
        ``vocabulary/`` subdirectory). Created if it does not exist.

    Raises
    ------
    ConfigurationError
        If the ``vocabulary/`` subdirectory already exists and is non-empty,
        or if ``datasets_for_vocab_creation`` names an unknown dataset.
    """
    prepare_environment(params)

    vocab_params = params.pop("vocabulary", {})
    os.makedirs(serialization_dir, exist_ok=True)
    vocab_dir = os.path.join(serialization_dir, "vocabulary")

    # BUG FIX: os.listdir never returns None, so the original
    # `os.listdir(vocab_dir) is not None` was always true and this error fired
    # for any existing directory — even an empty one. A truthiness check
    # raises only when the directory actually contains files.
    if os.path.isdir(vocab_dir) and os.listdir(vocab_dir):
        raise ConfigurationError(
            "The 'vocabulary' directory in the provided serialization directory is non-empty"
        )

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    # Every requested dataset key must correspond to a loaded dataset.
    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info(
        "From dataset instances, %s will be considered for vocabulary creation.",
        ", ".join(datasets_for_vocab_creation),
    )

    # Materialize as a list: the instances are iterated twice (vocabulary
    # creation below, then Batch statistics), so a generator would be exhausted.
    instances = [
        instance
        for key, dataset in all_datasets.items()
        for instance in dataset
        if key in datasets_for_vocab_creation
    ]

    vocab = Vocabulary.from_params(vocab_params, instances=instances)
    dataset = Batch(instances)
    dataset.index_instances(vocab)
    dataset.print_statistics()
    vocab.print_statistics()

    logger.info(f"writing the vocabulary to {vocab_dir}.")
    vocab.save_to_files(vocab_dir)

    model = Model.from_params(vocab=vocab, params=params.pop("model"))
    trainer_params = params.pop("trainer")

    # Freeze any parameters whose names match a `no_grad` regex from the
    # trainer configuration, then report what is frozen vs. tunable.
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    log_frozen_and_tunable_parameter_names(model)
def make_vocab_from_params(
    params: Params, serialization_dir: str, print_statistics: bool = False
) -> Vocabulary:
    """Create a ``Vocabulary`` from the datasets in ``params`` and save it.

    Parameters
    ----------
    params : Params
        The experiment configuration. Consumed destructively via ``pop``.
    serialization_dir : str
        Directory where the vocabulary is written (under a ``vocabulary/``
        subdirectory). Created if it does not exist.
    print_statistics : bool, optional (default = ``False``)
        If ``True``, the instances are materialized so dataset and vocabulary
        statistics can be printed after the vocabulary is built.

    Returns
    -------
    Vocabulary
        The vocabulary built from the selected datasets.

    Raises
    ------
    ConfigurationError
        If the ``vocabulary/`` subdirectory already exists and is non-empty,
        or if ``datasets_for_vocab_creation`` names an unknown dataset.
    """
    vocab_params = params.pop("vocabulary", {})
    os.makedirs(serialization_dir, exist_ok=True)
    vocab_dir = os.path.join(serialization_dir, "vocabulary")

    # BUG FIX: os.listdir never returns None, so the original
    # `os.listdir(vocab_dir) is not None` was always true and this error fired
    # for any existing directory — even an empty one. A truthiness check
    # raises only when the directory actually contains files.
    if os.path.isdir(vocab_dir) and os.listdir(vocab_dir):
        raise ConfigurationError(
            "The 'vocabulary' directory in the provided serialization directory is non-empty"
        )

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    # Every requested dataset key must correspond to a loaded dataset.
    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info(
        "From dataset instances, %s will be considered for vocabulary creation.",
        ", ".join(datasets_for_vocab_creation),
    )

    # Lazily stream instances by default to keep memory bounded.
    instances: Iterable[Instance] = (
        instance
        for key, dataset in all_datasets.items()
        if key in datasets_for_vocab_creation
        for instance in dataset
    )

    if print_statistics:
        # Statistics require a second pass over the instances, so the
        # generator must be materialized before vocabulary creation consumes it.
        instances = list(instances)

    vocab = Vocabulary.from_params(vocab_params, instances=instances)

    logger.info(f"writing the vocabulary to {vocab_dir}.")
    vocab.save_to_files(vocab_dir)
    logger.info("done creating vocab")

    if print_statistics:
        dataset = Batch(instances)
        dataset.index_instances(vocab)
        dataset.print_statistics()
        vocab.print_statistics()

    return vocab