def from_params(cls, vocab: Vocabulary, params: Params) -> 'TokenCharactersEncoder':  # type: ignore
    # pylint: disable=arguments-differ
    embedding_params: Params = params.pop("embedding")
    # Embedding.from_params() uses "tokens" as the default namespace, but we need to change
    # that to be "token_characters" by default.
    embedding_params.setdefault("vocab_namespace", "token_characters")
    embedding = Embedding.from_params(vocab, embedding_params)
    encoder_params: Params = params.pop("encoder")
    encoder = Seq2VecEncoder.from_params(encoder_params)
    dropout = params.pop_float("dropout", 0.0)
    params.assert_empty(cls.__name__)
    return cls(embedding, encoder, dropout)
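# Illustrative sketch (not from the original source): the params shape this classmethod
# expects, assuming Params can wrap a plain dict as in AllenNLP.  The keys inside
# "embedding" and "encoder" (e.g. "embedding_dim", the "cnn" encoder type) are hypothetical
# and depend on the Embedding / Seq2VecEncoder configuration actually used.
#
#     char_params = Params({
#         "embedding": {"embedding_dim": 16},   # vocab_namespace defaults to "token_characters"
#         "encoder": {"type": "cnn", "embedding_dim": 16, "num_filters": 50},
#         "dropout": 0.2,
#     })
#     token_characters_encoder = TokenCharactersEncoder.from_params(vocab, char_params)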
def from_params(cls, vocab: Vocabulary, params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
    # pylint: disable=arguments-differ,bad-super-call
    # The original `from_params` for this class was designed in a way that didn't agree
    # with the constructor. The constructor wants a 'token_embedders' parameter that is a
    # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those
    # key-value pairs to be top-level in the params object.
    #
    # This breaks our 'configuration wizard' and configuration checks. Hence, going forward,
    # the params need a 'token_embedders' key so that they line up with what the constructor
    # wants. For now, the old behavior is still supported, but produces a DeprecationWarning.
    embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None)
    if embedder_to_indexer_map is not None:
        embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True)
    allow_unmatched_keys = bool(params.pop("allow_unmatched_keys", False))

    token_embedder_params = params.pop('token_embedders', None)
    if token_embedder_params is not None:
        # New way: explicitly specified, so use it.
        token_embedders = {
            name: Embedding.from_params(vocab=vocab, params=subparams)
            for name, subparams in token_embedder_params.items()
        }
    else:
        # Warn that the original behavior is deprecated.
        warnings.warn(
            DeprecationWarning(
                "the token embedders for BasicTextFieldEmbedder should now "
                "be specified as a dict under the 'token_embedders' key, "
                "not as top-level key-value pairs"))
        token_embedders = {}
        keys = list(params.keys())
        for key in keys:
            embedder_params = params.pop(key)
            token_embedders[key] = Embedding.from_params(vocab=vocab, params=embedder_params)
    # TODO(pitrack): replace this line?
    # params.assert_empty(cls.__name__)
    return cls(token_embedders, embedder_to_indexer_map, allow_unmatched_keys)
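# Illustrative config shapes (not from the original source): the deprecated top-level layout
# versus the preferred 'token_embedders' layout this method expects.  The "tokens" key and
# its embedding parameters are hypothetical.
#
#     # Deprecated: embedders as top-level keys (triggers the DeprecationWarning above).
#     Params({"tokens": {"embedding_dim": 100}})
#
#     # Preferred: embedders nested under 'token_embedders'.
#     Params({"token_embedders": {"tokens": {"embedding_dim": 100}}})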
def load_archive(archive_file: str, device=None, weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.

    Parameters
    ----------
    archive_file : ``str``
        The archive file to load the model from.
    device : ``None`` or PyTorch device object.
    weights_file : ``str``, optional (default = None)
        The weights file to use.  If unspecified, weights.th in the archive_file will be used.
    """
    # Redirect to the cache, if necessary.
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(f"loading archive file {archive_file} from cache at {resolved_archive_file}")

    tempdir = None
    if os.path.isdir(resolved_archive_file):
        serialization_dir = resolved_archive_file
    else:
        # Extract archive to temp dir.
        tempdir = tempfile.mkdtemp()
        logger.info(f"extracting archive file {resolved_archive_file} to temp dir {tempdir}")
        with tarfile.open(resolved_archive_file, 'r:gz') as archive:
            archive.extractall(tempdir)
        serialization_dir = tempdir

    # Load config.
    config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
    config.loading_from_archive = True

    if weights_file:
        weights_path = weights_file
    else:
        weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)

    # Instantiate model. Use a duplicate of the config, as it will get consumed.
    model = Model.load(config,
                       weights_file=weights_path,
                       serialization_dir=serialization_dir,
                       device=device)

    if tempdir:
        # Clean up temp dir.
        shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
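# Illustrative usage sketch (not from the original source): load a trained model back from a
# model archive.  The archive path is hypothetical; with no explicit weights_file, the weights
# stored inside the archive are used.
def _example_load_archive(archive_path: str = "experiments/run1/model.tar.gz"):
    archive = load_archive(archive_path)
    return archive.model, archive.config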
def create_serialization_dir(params: Params) -> None:
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists
    and is non-empty, then it verifies that we're recovering from a training run with an identical
    configuration.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.  The serialization directory and the
        recovery flag are read from ``params['environment']['serialization_dir']`` and
        ``params['environment']['recover']``.  If ``recover`` is ``True``, we will try to recover
        from an existing serialization directory, and crash if the directory doesn't exist or
        doesn't match the configuration we're given.
    """
    serialization_dir = params['environment']['serialization_dir']
    recover = params['environment']['recover']
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        if not recover:
            raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists and is "
                                     f"not empty. Specify --recover to recover training from existing output.")

        logger.info(f"Recovering from prior training at {serialization_dir}.")

        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(recovered_config_file):
            raise ConfigurationError("The serialization directory already exists but doesn't "
                                     "contain a config.json. You probably gave the wrong directory.")
        else:
            loaded_params = Params.from_file(recovered_config_file)
            if params != loaded_params:
                raise ConfigurationError("Training configuration does not match the configuration we're "
                                         "recovering from.")

            # In the recover mode, we don't need to reload the pre-trained embeddings.
            remove_pretrained_embedding_params(params)
    else:
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist. There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
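# Illustrative params fragment (not from the original source) showing the two environment
# entries this function reads; the directory path is a hypothetical placeholder.
#
#     Params({"environment": {"serialization_dir": "experiments/run1", "recover": False}, ...})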
def train_model(params: Params):
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.

    Returns
    -------
    best_model : ``Model``
        The model with the best epoch weights.
    """
    # Set up the environment.
    environment_params = params['environment']
    environment.set_seed(environment_params)
    create_serialization_dir(params)
    environment.prepare_global_logging(environment_params)
    environment.check_for_gpu(environment_params)
    if environment_params['gpu']:
        device = torch.device('cuda:{}'.format(environment_params['cuda_device']))
        environment.occupy_gpu(device)
    else:
        device = torch.device('cpu')
    params['trainer']['device'] = device

    # Load data.
    data_params = params['data']
    dataset = dataset_from_params(data_params,
                                  universal_postags=params["model"].get('universal_postags', False),
                                  generator_source_copy=data_params.get('source_copy', True),
                                  multilingual=params['model'].get('multilingual', False),
                                  extra_check=params['data'].get('extra_check', False))

    train_data = dataset['train']
    dev_data = dataset.get('dev')
    test_data = dataset.get('test')
    train_mappings = dataset.get('train_mappings', None)
    train_replacements = dataset.get('train_replacements', None)

    # Vocabulary and iterator are created here.
    vocab_params = params.get('vocab', {})
    if "fixed_vocab" in vocab_params and vocab_params["fixed_vocab"]:
        vocab = Vocabulary.from_files("data/vocabulary")
    else:
        vocab = Vocabulary.from_instances(instances=train_data, **vocab_params)
    # Initializing the model can have the side effect of expanding the vocabulary.
    vocab.save_to_files(os.path.join(environment_params['serialization_dir'], "vocabulary"))

    train_iterator, dev_iterator, test_iterator = iterator_from_params(vocab, data_params['iterator'])

    if train_mappings is not None and train_replacements is not None:
        with open(os.path.join(environment_params['serialization_dir'], "trns_lex_missing.json"), "w",
                  encoding='utf-8') as outfile:
            json.dump(train_mappings[-1], outfile, indent=4, default=serialize_sets)
        with open(os.path.join(environment_params['serialization_dir'], "trns_lexicalizations.json"), "w",
                  encoding='utf-8') as outfile:
            json.dump(train_mappings[-2], outfile, indent=4, default=serialize_sets)
        with open(os.path.join(environment_params['serialization_dir'], "trns_rep.json"), "w",
                  encoding='utf-8') as outfile:
            json.dump(train_replacements, outfile, indent=4, default=serialize_sets)

    # Build the model.
    model_params = params['model']
    model = getattr(Models, model_params['model_type']).from_params(vocab, model_params,
                                                                    environment_params['gpu'],
                                                                    train_mappings, train_replacements)
    logger.info(model)

    # Train.
    trainer_params = params['trainer']
    no_grad_regexes = trainer_params['no_grad']
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
        environment.get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)
    logger.info("Total number of tunable parameters (with gradient):")
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(pytorch_total_params)

    trainer = Trainer.from_params(model, train_data, dev_data, train_iterator, dev_iterator, trainer_params)

    serialization_dir = trainer_params['serialization_dir']
    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # If we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results.
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    if not isinstance(best_model, torch.nn.DataParallel):
        best_model_state = {re.sub(r'^module\.', '', k): v for k, v in best_model_state.items()}
    best_model.load_state_dict(best_model_state)

    return best_model
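# Illustrative skeleton (not from the original source) of the experiment configuration that
# train_model consumes, listing only keys read in the function above; every value shown,
# including the "STOG" model_type, is a hypothetical placeholder.
#
#     Params({
#         "environment": {"serialization_dir": "experiments/run1", "recover": False,
#                         "gpu": False, "cuda_device": 0},
#         "data": {"iterator": {...}, "source_copy": True, "extra_check": False},
#         "vocab": {},
#         "model": {"model_type": "STOG", "universal_postags": False, "multilingual": False},
#         "trainer": {"no_grad": [], "serialization_dir": "experiments/run1"},
#     })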
if __name__ == "__main__":
    parser = argparse.ArgumentParser('train.py')
    parser.add_argument('params', help='Parameters YAML file.')
    args = parser.parse_args()

    params = Params.from_file(args.params)
    logger.info(params)
    train_model(params)
def create_kwargs(cls: Type[T], params: Params, **extras) -> Dict[str, Any]:
    """
    Given some class, a `Params` object, and potentially other keyword arguments,
    create a dict of keyword args suitable for passing to the class's constructor.

    The function does this by finding the class's constructor, matching the constructor
    arguments to entries in the `params` object, and instantiating values for the parameters
    using the type annotation and possibly a from_params method.

    Any values that are provided in the `extras` will just be used as is.
    For instance, you might provide an existing `Vocabulary` this way.
    """
    # Get the signature of the constructor.
    signature = inspect.signature(cls.__init__)
    kwargs: Dict[str, Any] = {}

    # Iterate over all the constructor parameters and their annotations.
    for name, param in signature.parameters.items():
        # Skip "self". You're not *required* to call the first parameter "self",
        # so in theory this logic is fragile, but if you don't call the self parameter
        # "self" you kind of deserve what happens.
        if name == "self":
            continue

        # If the annotation is a compound type like typing.Dict[str, int],
        # it will have an __origin__ field indicating `typing.Dict`
        # and an __args__ field indicating `(str, int)`. We capture both.
        annotation = remove_optional(param.annotation)
        origin = getattr(annotation, '__origin__', None)
        args = getattr(annotation, '__args__', [])

        # The parameter is optional if its default value is not the "no default" sentinel.
        default = param.default
        optional = default != _NO_DEFAULT

        # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
        # We check the provided `extras` for these and just use them if they exist.
        if name in extras:
            kwargs[name] = extras[name]

        # The next case is when the parameter type is itself constructible from_params.
        elif hasattr(annotation, 'from_params'):
            if name in params:
                # Our params have an entry for this, so we use that.
                subparams = params.pop(name)

                if takes_arg(annotation.from_params, 'extras'):
                    # If annotation.from_params accepts **extras, we need to pass them all along.
                    # For example, `BasicTextFieldEmbedder.from_params` requires a Vocabulary
                    # object, but `TextFieldEmbedder.from_params` does not.
                    subextras = extras
                else:
                    # Otherwise, only supply the ones that are actual args; any additional ones
                    # will cause a TypeError.
                    subextras = {k: v for k, v in extras.items()
                                 if takes_arg(annotation.from_params, k)}

                # In some cases we allow a string instead of a param dict, so
                # we need to handle that case separately.
                if isinstance(subparams, str):
                    kwargs[name] = annotation.by_name(subparams)()
                else:
                    kwargs[name] = annotation.from_params(params=subparams, **subextras)
            elif not optional:
                # Not optional and not supplied, that's an error!
                raise ConfigurationError(f"expected key {name} for {cls.__name__}")
            else:
                kwargs[name] = default

        # If the parameter type is a Python primitive, just pop it off
        # using the correct casting pop_xyz operation.
        elif annotation == str:
            kwargs[name] = (params.pop(name, default)
                            if optional
                            else params.pop(name))
        elif annotation == int:
            kwargs[name] = (params.pop_int(name, default)
                            if optional
                            else params.pop_int(name))
        elif annotation == bool:
            kwargs[name] = (params.pop_bool(name, default)
                            if optional
                            else params.pop_bool(name))
        elif annotation == float:
            kwargs[name] = (params.pop_float(name, default)
                            if optional
                            else params.pop_float(name))

        # This is special logic for handling types like Dict[str, TokenIndexer], which it creates
        # by instantiating each value from_params and returning the resulting dict.
        elif origin in (Dict, dict) and len(args) == 2 and hasattr(args[-1], 'from_params'):
            value_cls = annotation.__args__[-1]

            value_dict = {}
            for key, value_params in params.pop(name, Params({})).items():
                value_dict[key] = value_cls.from_params(params=value_params, **extras)

            kwargs[name] = value_dict
        else:
            # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
            if optional:
                kwargs[name] = params.pop(name, default)
            else:
                kwargs[name] = params.pop(name)

    params.assert_empty(cls.__name__)
    return kwargs
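# Illustrative sketch (not from the original source): a hypothetical class whose constructor
# annotations drive create_kwargs.  "vocab" is taken from `extras`, "patience" is popped with
# pop_int, and "metric" falls back to its declared default.
class _ExampleStopper:
    def __init__(self, vocab: Vocabulary, patience: int, metric: str = "loss") -> None:
        self.vocab = vocab
        self.patience = patience
        self.metric = metric

# create_kwargs(_ExampleStopper, Params({"patience": 5}), vocab=vocab) would return
# {"vocab": vocab, "patience": 5, "metric": "loss"}.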
def from_params(cls: Type[T], params: Params, **extras) -> T:
    """
    This is the automatic implementation of `from_params`. Any class that subclasses `FromParams`
    (or `Registrable`, which itself subclasses `FromParams`) gets this implementation for free.
    If you want your class to be instantiated from params in the "obvious" way -- pop off
    parameters and hand them to your constructor with the same names -- this provides that
    functionality.

    If you need more complex logic in your `from_params` method, you'll have to implement
    your own method that overrides this one.
    """
    # pylint: disable=protected-access
    from xlamr_stog.utils.registrable import Registrable  # import here to avoid circular imports

    logger.info(f"instantiating class {cls} from params {getattr(params, 'params', params)} "
                f"and extras {extras}")

    if params is None:
        return None

    registered_subclasses = Registrable._registry.get(cls)

    if registered_subclasses is not None:
        # We know ``cls`` inherits from Registrable, so we'll use a cast to make mypy happy.
        # We have to use a disable to make pylint happy.
        # pylint: disable=no-member
        as_registrable = cast(Type[Registrable], cls)
        default_to_first_choice = as_registrable.default_implementation is not None
        choice = params.pop_choice("type",
                                   choices=as_registrable.list_available(),
                                   default_to_first_choice=default_to_first_choice)
        subclass = registered_subclasses[choice]

        # We want to call subclass.from_params. It's possible that it's just the "free"
        # implementation here, in which case it accepts `**extras` and we are not able
        # to make any assumptions about what extra parameters it needs.
        #
        # It's also possible that it has a custom `from_params` method. In that case it
        # won't accept any **extra parameters and we'll need to filter them out.
        if not takes_arg(subclass.from_params, 'extras'):
            # Necessarily subclass.from_params is a custom implementation, so we need to
            # pass it only the args it's expecting.
            extras = {k: v for k, v in extras.items() if takes_arg(subclass.from_params, k)}

        return subclass.from_params(params=params, **extras)
    else:
        # This is not a base class, so convert our params and extras into a dict of kwargs.

        if cls.__init__ == object.__init__:
            # This class does not have an explicit constructor, so don't give it any kwargs.
            # Without this logic, create_kwargs will look at object.__init__ and see that
            # it takes *args and **kwargs and look for those.
            kwargs: Dict[str, Any] = {}
        else:
            # This class has a constructor, so create kwargs for it.
            kwargs = create_kwargs(cls, params, **extras)

        return cls(**kwargs)  # type: ignore
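# Illustrative walkthrough (not from the original source): for a Registrable base class such
# as Seq2VecEncoder, pop_choice("type", ...) selects the registered subclass, and the remaining
# keys become that subclass's constructor kwargs via create_kwargs (assuming the subclass uses
# the automatic implementation above).  The "cnn" registration name and its parameters are
# hypothetical here.
#
#     encoder = Seq2VecEncoder.from_params(Params({
#         "type": "cnn",
#         "embedding_dim": 16,
#         "num_filters": 50,
#     }))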