def from_params(cls, vocab: Vocabulary, params: Params) -> 'TokenCharactersEncoder':  # type: ignore
    """
    Build the encoder from a ``Params`` object.

    Reads the ``"embedding"`` and ``"encoder"`` sub-configurations and an
    optional ``"dropout"`` float (default ``0.0``), then verifies that no
    unconsumed keys remain in ``params``.
    """
    # pylint: disable=arguments-differ
    char_embedding_params: Params = params.pop("embedding")
    # Embedding.from_params() falls back to the "tokens" namespace; for this
    # class the default namespace has to be "token_characters" instead.
    char_embedding_params.setdefault("vocab_namespace", "token_characters")
    char_embedding = Embedding.from_params(vocab, char_embedding_params)

    char_encoder_params: Params = params.pop("encoder")
    char_encoder = Seq2VecEncoder.from_params(char_encoder_params)

    dropout_rate = params.pop_float("dropout", 0.0)
    params.assert_empty(cls.__name__)
    return cls(char_embedding, char_encoder, dropout_rate)
def extend_from_instances(self, params: Params, instances: Iterable['adi.Instance'] = ()) -> None:
    """
    Extends an already generated vocabulary using a collection of instances.

    The extension options are popped from ``params`` (which must be empty
    afterwards), every instance contributes its token counts, and the
    accumulated counts plus options are handed to ``self._extend``.
    """
    # A dict literal is evaluated top to bottom, so the keys are popped in
    # the order they are listed here.
    extension_options = {
        "min_count": params.pop("min_count", None),
        "max_vocab_size": pop_max_vocab_size(params),
        "non_padded_namespaces": params.pop("non_padded_namespaces",
                                            DEFAULT_NON_PADDED_NAMESPACES),
        "pretrained_files": params.pop("pretrained_files", {}),
        "min_pretrained_embeddings": params.pop("min_pretrained_embeddings", None),
        "only_include_pretrained_words": params.pop_bool("only_include_pretrained_words", False),
        "tokens_to_add": params.pop("tokens_to_add", None),
    }
    params.assert_empty("Vocabulary - from dataset")

    logger.info("Fitting token dictionary from dataset.")
    # namespace -> token -> count, filled in by each instance.
    token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    for item in Tqdm.tqdm(instances):
        item.count_vocab_items(token_counts)

    self._extend(counter=token_counts, **extension_options)
def from_params(cls, vocab: Vocabulary, params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
    """
    Build the text field embedder from a ``Params`` object.

    Token embedders are expected under the ``'token_embedders'`` key. The
    legacy layout, where every remaining top-level key is treated as an
    embedder configuration, is still accepted but emits a
    ``DeprecationWarning``.
    """
    # pylint: disable=arguments-differ,bad-super-call

    # Historically this method read the embedders as top-level key-value
    # pairs, which does not line up with the constructor's
    # `token_embedders: Dict[str, TokenEmbedder]` argument and breaks the
    # 'configuration wizard' and configuration checks. The explicit
    # 'token_embedders' key is the supported form going forward.
    indexer_map = params.pop("embedder_to_indexer_map", None)
    if indexer_map is not None:
        indexer_map = indexer_map.as_dict(quiet=True)
    unmatched_ok = bool(params.pop("allow_unmatched_keys", False))

    explicit_embedders = params.pop('token_embedders', None)

    if explicit_embedders is None:
        # Legacy layout: warn, then consume every key that is still left.
        warnings.warn(
            DeprecationWarning(
                "the token embedders for BasicTextFieldEmbedder should now "
                "be specified as a dict under the 'token_embedders' key, "
                "not as top-level key-value pairs"))
        # Snapshot the keys first because popping mutates `params`.
        token_embedders = {
            key: Embedding.from_params(vocab=vocab, params=params.pop(key))
            for key in list(params.keys())
        }
    else:
        # Explicitly specified embedders: build one per entry.
        token_embedders = {}
        for embedder_name, embedder_config in explicit_embedders.items():
            token_embedders[embedder_name] = Embedding.from_params(
                vocab=vocab, params=embedder_config)

    # TODO(pitrack): replace this line?
    # params.assert_empty(cls.__name__)
    return cls(token_embedders, indexer_map, unmatched_ok)
def create_serialization_dir(params: Params) -> None:
    """
    Prepare the serialization directory named in the experiment configuration.

    The directory and the recover flag are read from
    ``params['environment']['serialization_dir']`` and
    ``params['environment']['recover']``.

    If the directory is missing or empty it is created and the configuration
    is written into it. If it already holds files, the run must be a
    recovery and the stored configuration must equal ``params``.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.

    Raises
    ------
    ConfigurationError
        If the directory is non-empty without ``recover``, if ``recover`` is
        set but there is nothing on disk, if the stored config file is
        missing, or if it differs from ``params``.
    """
    serialization_dir = params['environment']['serialization_dir']
    recover = params['environment']['recover']

    # Fresh run: nothing usable on disk yet.
    if not (os.path.exists(serialization_dir) and os.listdir(serialization_dir)):
        if recover:
            raise ConfigurationError(
                f"--recover specified but serialization_dir ({serialization_dir}) "
                "does not exist. There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
        return

    # From here on the directory exists and contains files.
    if not recover:
        raise ConfigurationError(
            f"Serialization directory ({serialization_dir}) already exists and is "
            f"not empty. Specify --recover to recover training from existing output."
        )

    logger.info(f"Recovering from prior training at {serialization_dir}.")

    stored_config_path = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(stored_config_path):
        raise ConfigurationError(
            "The serialization directory already exists but doesn't "
            "contain a config.json. You probably gave the wrong directory."
        )

    stored_params = Params.from_file(stored_config_path)
    if params != stored_params:
        raise ConfigurationError(
            "Training configuration does not match the configuration we're "
            "recovering from.")

    # In the recover mode, we don't need to reload the pre-trained embeddings.
    remove_pretrained_embedding_params(params)
def load_archive(archive_file: str,
                 device=None,
                 weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.

    Parameters
    ----------
    archive_file: ``str``
        The archive file to load the model from. If the resolved path is a
        directory it is used directly as the serialization directory.
    weights_file: ``str``, optional (default = None)
        The weights file to use.  If unspecified, the default weights file
        inside the archive is used.
    device: ``None`` or PyTorch device object.

    Returns
    -------
    An ``Archive`` holding the loaded model and its configuration.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(
            f"loading archive file {archive_file} from cache at {resolved_archive_file}"
        )

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info(
                f"extracting archive file {resolved_archive_file} to temp dir {tempdir}"
            )
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive; only load archives from trusted sources.
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir

        # Load config
        config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
        config.loading_from_archive = True

        weights_path = weights_file or os.path.join(serialization_dir, _WEIGHTS_NAME)

        # Instantiate model.
        # NOTE(review): `config` is passed as is (no duplicate is made here)
        # and the same object is returned inside the Archive.
        model = Model.load(config,
                           weights_file=weights_path,
                           serialization_dir=serialization_dir,
                           device=device)
    finally:
        # Clean up the temp dir even when extraction or loading fails, so a
        # failed load does not leave the extracted archive behind.
        if tempdir:
            shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
def pop_max_vocab_size(params: Params) -> Union[int, Dict[str, int]]:
    """
    Pop ``"max_vocab_size"`` from ``params`` and normalize it.

    max_vocab_size is allowed to be either an int or a Dict[str, int] (or
    nothing). But it could also be a string representing an int (in the case
    of environment variable substitution).

    Returns
    -------
    ``None`` when the key is absent, a plain dict when the value is a
    ``Params`` object, the value converted with ``int()`` when that
    conversion succeeds, and otherwise the value unchanged.
    """
    size = params.pop("max_vocab_size", None)
    if size is None:
        return None
    if isinstance(size, Params):
        # This is the Dict[str, int] case.
        return size.as_dict()
    # This is the int / str case.
    try:
        return int(size)
    except (TypeError, ValueError, OverflowError):
        # Not convertible: hand the raw value back, as before. Only the
        # errors int() raises are caught, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return size
def train_model(params: Params):
    """
    Run a full training experiment described by a :class:`Params` object and
    save the results.

    The steps are: set up the environment and serialization directory, load
    the data, build and save the vocabulary, build the model, freeze the
    parameters matched by the trainer's ``no_grad`` regexes, train, archive
    the run, and finally reload the best epoch weights into the model.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.

    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.
    """
    # Set up the environment.
    env_params = params['environment']
    environment.set_seed(env_params)
    create_serialization_dir(params)
    environment.prepare_global_logging(env_params)

    # GPU selection (environment.check_for_gpu / environment.occupy_gpu) is
    # currently disabled here; training is pinned to the CPU device.
    device = torch.device('cpu')
    params['trainer']['device'] = device

    # Load data.
    data_params = params['data']
    dataset = dataset_from_params(data_params)
    train_data = dataset['train']
    dev_data = dataset.get('dev')
    test_data = dataset.get('test')

    # Vocabulary and iterator are created here.
    vocab_params = params.get('vocab', {})
    vocab = Vocabulary.from_instances(instances=train_data, **vocab_params)
    # Initializing the model can have side effect of expanding the vocabulary
    vocab_dir = os.path.join(env_params['serialization_dir'], "vocabulary")
    vocab.save_to_files(vocab_dir)

    iterators = iterator_from_params(vocab, data_params['iterator'])
    train_iterator, dev_iterater, test_iterater = iterators

    # Build the model.
    model_params = params['model']
    model_cls = getattr(Models, model_params['model_type'])
    model = model_cls.from_params(vocab, model_params)
    logger.info(model)

    # Train
    trainer_params = params['trainer']
    no_grad_regexes = trainer_params['no_grad']
    # Freeze every parameter whose name matches at least one regex.
    for param_name, weight in model.named_parameters():
        for pattern in no_grad_regexes:
            if re.search(pattern, param_name):
                weight.requires_grad_(False)
                break

    frozen_names, tunable_names = \
        environment.get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen (without gradient):")
    for frozen_name in frozen_names:
        logger.info(frozen_name)
    logger.info("Following parameters are Tunable (with gradient):")
    for tunable_name in tunable_names:
        logger.info(tunable_name)

    trainer = Trainer.from_params(model, train_data, dev_data,
                                  train_iterator, dev_iterater, trainer_params)

    serialization_dir = trainer_params['serialization_dir']
    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        weights_path = os.path.join(serialization_dir, _DEFAULT_WEIGHTS)
        if os.path.exists(weights_path):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    best_model_state = torch.load(os.path.join(serialization_dir, 'best.th'))
    best_model = model
    if not isinstance(best_model, torch.nn.DataParallel):
        # Drop a leading "module." from each key when the model is not
        # wrapped in DataParallel.
        best_model_state = {
            re.sub(r'^module\.', '', key): value
            for key, value in best_model_state.items()
        }
    best_model.load_state_dict(best_model_state)

    return best_model
# NOTE(review): the text up to `return best_model` repeats the tail of `train_model`; the script entry point (`if __name__ == "__main__":`) parses one positional `params` argument, loads it with `Params.from_file`, logs it and calls `train_model`.
# if we have completed an epoch, try to create a model archive. if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)): logger.info("Training interrupted by the user. Attempting to create " "a model archive using the current best epoch weights.") archive_model(serialization_dir) raise # Now tar up results archive_model(serialization_dir) logger.info("Loading the best epoch weights.") best_model_state_path = os.path.join(serialization_dir, 'best.th') best_model_state = torch.load(best_model_state_path) best_model = model if not isinstance(best_model, torch.nn.DataParallel): best_model_state = {re.sub(r'^module\.', '', k):v for k, v in best_model_state.items()} best_model.load_state_dict(best_model_state) return best_model if __name__ == "__main__": parser = argparse.ArgumentParser('train.py') parser.add_argument('params', help='Parameters YAML file.') args = parser.parse_args() params = Params.from_file(args.params) logger.info(params) train_model(params)
def _applicable_extras(from_params_fn, extras: Dict[str, Any]) -> Dict[str, Any]:
    """Return the subset of ``extras`` that ``from_params_fn`` can accept."""
    if takes_arg(from_params_fn, 'extras'):
        # The callee accepts **extras, so everything can be passed along.
        return extras
    # Otherwise only supply the ones that are actual args; any additional
    # ones would cause a TypeError.
    return {k: v for k, v in extras.items() if takes_arg(from_params_fn, k)}


def create_kwargs(cls: Type[T], params: Params, **extras) -> Dict[str, Any]:
    """
    Given some class, a `Params` object, and potentially other keyword arguments,
    create a dict of keyword args suitable for passing to the class's constructor.

    The function does this by finding the class's constructor, matching the constructor
    arguments to entries in the `params` object, and instantiating values for the parameters
    using the type annotation and possibly a from_params method.

    Any values that are provided in the `extras` will just be used as is.
    For instance, you might provide an existing `Vocabulary` this way.

    Raises
    ------
    ConfigurationError
        If a constructor argument whose type has a ``from_params`` method has
        no default and no entry in ``params``.
    """
    # Get the signature of the constructor.
    signature = inspect.signature(cls.__init__)
    kwargs: Dict[str, Any] = {}

    # Iterate over all the constructor parameters and their annotations.
    for name, param in signature.parameters.items():
        # Skip "self". You're not *required* to call the first parameter "self",
        # so in theory this logic is fragile, but if you don't call the self parameter
        # "self" you kind of deserve what happens.
        if name == "self":
            continue

        # If the annotation is a compound type like typing.Dict[str, int],
        # it will have an __origin__ field indicating `typing.Dict`
        # and an __args__ field indicating `(str, int)`. We capture both.
        annotation = remove_optional(param.annotation)
        origin = getattr(annotation, '__origin__', None)
        args = getattr(annotation, '__args__', [])

        # The parameter is optional if its default value is not the "no default" sentinel.
        # Identity is used on purpose: `!=` would invoke the default's own
        # comparison, which need not return a plain bool (e.g. array-like defaults).
        default = param.default
        optional = default is not _NO_DEFAULT

        # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
        # We check the provided `extras` for these and just use them if they exist.
        if name in extras:
            kwargs[name] = extras[name]

        # The next case is when the parameter type is itself constructible from_params.
        elif hasattr(annotation, 'from_params'):
            if name in params:
                # Our params have an entry for this, so we use that.
                subparams = params.pop(name)
                subextras = _applicable_extras(annotation.from_params, extras)

                # In some cases we allow a string instead of a param dict, so
                # we need to handle that case separately.
                if isinstance(subparams, str):
                    kwargs[name] = annotation.by_name(subparams)()
                else:
                    kwargs[name] = annotation.from_params(params=subparams, **subextras)
            elif not optional:
                # Not optional and not supplied, that's an error!
                raise ConfigurationError(
                    f"expected key {name} for {cls.__name__}")
            else:
                kwargs[name] = default

        # If the parameter type is a Python primitive, just pop it off
        # using the correct casting pop_xyz operation.
        elif annotation == str:
            kwargs[name] = (params.pop(name, default)
                            if optional else params.pop(name))
        elif annotation == int:
            kwargs[name] = (params.pop_int(name, default)
                            if optional else params.pop_int(name))
        elif annotation == bool:
            kwargs[name] = (params.pop_bool(name, default)
                            if optional else params.pop_bool(name))
        elif annotation == float:
            kwargs[name] = (params.pop_float(name, default)
                            if optional else params.pop_float(name))

        # This is special logic for handling types like Dict[str, TokenIndexer], which it creates by
        # instantiating each value from_params and returning the resulting dict.
        elif origin in (Dict, dict) and len(args) == 2 and hasattr(args[-1], 'from_params'):
            value_cls = annotation.__args__[-1]
            # Filter the extras the same way as for a single from_params-constructible
            # value, so a custom from_params without **extras is not handed
            # keyword arguments it cannot accept.
            value_extras = _applicable_extras(value_cls.from_params, extras)

            value_dict = {}
            for key, value_params in params.pop(name, Params({})).items():
                value_dict[key] = value_cls.from_params(params=value_params, **value_extras)

            kwargs[name] = value_dict

        else:
            # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
            if optional:
                kwargs[name] = params.pop(name, default)
            else:
                kwargs[name] = params.pop(name)

    params.assert_empty(cls.__name__)
    return kwargs
def from_params(cls: Type[T], params: Params, **extras) -> T:
    """
    Automatic implementation of `from_params`.

    Every class that subclasses `FromParams` (or `Registrable`, which itself
    subclasses `FromParams`) inherits this method. It covers the "obvious"
    construction: pop parameters off ``params`` and pass them to the
    constructor under the same names. Classes that need more involved logic
    override this method with their own `from_params`.

    Returns ``None`` when ``params`` is ``None``.
    """
    # pylint: disable=protected-access
    from stog.utils.registrable import Registrable  # import here to avoid circular imports

    logger.info(
        f"instantiating class {cls} from params {getattr(params, 'params', params)} "
        f"and extras {extras}")

    if params is None:
        return None

    subclass_registry = Registrable._registry.get(cls)

    if subclass_registry is None:
        # This is not a base class, so convert our params and extras into a dict of kwargs.
        if cls.__init__ == object.__init__:
            # No explicit constructor: pass no kwargs at all. Otherwise
            # create_kwargs would inspect object.__init__, see *args and
            # **kwargs, and go looking for those in the params.
            init_kwargs: Dict[str, Any] = {}
        else:
            # This class has a constructor, so create kwargs for it.
            init_kwargs = create_kwargs(cls, params, **extras)
        return cls(**init_kwargs)  # type: ignore

    # ``cls`` has registered subclasses, so it inherits from Registrable; the
    # cast keeps mypy happy and the disable keeps pylint happy.
    # pylint: disable=no-member
    registrable_cls = cast(Type[Registrable], cls)
    has_default = registrable_cls.default_implementation is not None
    chosen_type = params.pop_choice(
        "type",
        choices=registrable_cls.list_available(),
        default_to_first_choice=has_default)
    subclass = subclass_registry[chosen_type]

    # subclass.from_params is either this "free" implementation, which
    # accepts `**extras` (so nothing can be assumed about what it needs), or
    # a custom one that does not and must only receive the arguments it names.
    if not takes_arg(subclass.from_params, 'extras'):
        extras = {
            key: value
            for key, value in extras.items()
            if takes_arg(subclass.from_params, key)
        }

    return subclass.from_params(params=params, **extras)
def from_params(cls, params: Params, instances: Iterable['adi.Instance'] = None):  # type: ignore
    """
    There are two possible ways to build a vocabulary; from a
    collection of instances, using :func:`Vocabulary.from_instances`, or
    from a pre-saved vocabulary, using :func:`Vocabulary.from_files`.
    You can also extend pre-saved vocabulary with collection of instances
    using this method. This method wraps these options, allowing their
    specification from a ``Params`` object, generated from a JSON
    configuration file.

    Parameters
    ----------
    params: Params, required.
        Recognized keys include ``type``, ``extend``, ``directory_path`` and
        the options forwarded to :func:`Vocabulary.from_instances`.
    instances: Iterable['adi.Instance'], optional
        If ``params`` doesn't contain a ``directory_path`` key,
        the ``Vocabulary`` can be built directly from a collection of
        instances (i.e. a dataset). If ``extend`` key is set False,
        dataset instances will be ignored and final vocabulary will be
        one loaded from ``directory_path``. If ``extend`` key is set True,
        dataset instances will be used to extend the vocabulary loaded
        from ``directory_path`` and that will be final vocabulary used.

    Returns
    -------
    A ``Vocabulary``.

    Raises
    ------
    ConfigurationError
        If neither ``directory_path`` nor instances are given, or if
        ``extend`` is set without both of them.
    """
    # pylint: disable=arguments-differ
    # Vocabulary is ``Registrable`` so that you can configure a custom subclass,
    # but (unlike most of our registrables) almost everyone will want to use the
    # base implementation. So instead of having an abstract ``VocabularyBase`` or
    # such, we just add the logic for instantiating a registered subclass here,
    # so that most users can continue doing what they were doing.
    vocab_type = params.pop("type", None)
    if vocab_type is not None:
        return cls.by_name(vocab_type).from_params(params=params, instances=instances)

    extend = params.pop("extend", False)
    vocabulary_directory = params.pop("directory_path", None)
    if not vocabulary_directory and not instances:
        # The message names the key that is actually read above.
        raise ConfigurationError("You must provide either a Params object containing a "
                                 "directory_path key or a Dataset to build a vocabulary from.")
    if extend and not instances:
        raise ConfigurationError("'extend' is true but there are not instances passed to extend.")
    if extend and not vocabulary_directory:
        raise ConfigurationError("'extend' is true but there is not 'directory_path' to extend from.")

    if vocabulary_directory and instances:
        if extend:
            logger.info("Loading Vocab from files and extending it with dataset.")
        else:
            logger.info("Loading Vocab from files instead of dataset.")

    if vocabulary_directory:
        vocab = Vocabulary.from_files(vocabulary_directory)
        if not extend:
            params.assert_empty("Vocabulary - from files")
            return vocab
    if extend:
        # `extend` implies a directory was given (checked above), so `vocab`
        # is bound here; the remaining params are consumed by the extension.
        vocab.extend_from_instances(params, instances=instances)
        return vocab

    # No directory: build the vocabulary from the instances alone.
    min_count = params.pop("min_count", None)
    max_vocab_size = pop_max_vocab_size(params)
    non_padded_namespaces = params.pop("non_padded_namespaces", DEFAULT_NON_PADDED_NAMESPACES)
    pretrained_files = params.pop("pretrained_files", {})
    min_pretrained_embeddings = params.pop("min_pretrained_embeddings", None)
    only_include_pretrained_words = params.pop_bool("only_include_pretrained_words", False)
    tokens_to_add = params.pop("tokens_to_add", None)
    params.assert_empty("Vocabulary - from dataset")
    return Vocabulary.from_instances(instances=instances,
                                     min_count=min_count,
                                     max_vocab_size=max_vocab_size,
                                     non_padded_namespaces=non_padded_namespaces,
                                     pretrained_files=pretrained_files,
                                     only_include_pretrained_words=only_include_pretrained_words,
                                     tokens_to_add=tokens_to_add,
                                     min_pretrained_embeddings=min_pretrained_embeddings)