コード例 #1
0
 def from_params(cls, vocab: Vocabulary, params: Params) -> 'TokenCharactersEncoder':  # type: ignore
     # pylint: disable=arguments-differ
     """
     Build the encoder from a ``Params`` object.

     Reads the ``"embedding"`` and ``"encoder"`` sub-configurations plus an optional
     ``"dropout"`` float (default ``0.0``), then checks that nothing is left over in
     ``params`` before calling the constructor.
     """
     char_embedding_config: Params = params.pop("embedding")
     # Embedding.from_params() falls back to the "tokens" namespace; for this class
     # the fallback has to be "token_characters" unless the config says otherwise.
     char_embedding_config.setdefault("vocab_namespace", "token_characters")
     char_embedder = Embedding.from_params(vocab, char_embedding_config)

     seq2vec_config: Params = params.pop("encoder")
     char_encoder = Seq2VecEncoder.from_params(seq2vec_config)

     dropout_rate = params.pop_float("dropout", 0.0)
     params.assert_empty(cls.__name__)
     return cls(char_embedder, char_encoder, dropout_rate)
コード例 #2
0
    def from_params(
            cls, vocab: Vocabulary,
            params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
        # pylint: disable=arguments-differ,bad-super-call
        """
        Build the text field embedder from a ``Params`` object.

        Two configuration layouts are accepted:

        * Preferred: a ``'token_embedders'`` key holding a mapping from embedder name to
          that embedder's params, which lines up with the constructor's
          ``token_embedders`` argument.
        * Legacy: the embedder entries given as top-level key-value pairs. This still
          works but emits a ``DeprecationWarning``.

        ``"embedder_to_indexer_map"`` (optional) and ``"allow_unmatched_keys"``
        (default ``False``) are popped first, so they are never mistaken for embedders
        in the legacy layout.
        """
        indexer_map = params.pop("embedder_to_indexer_map", None)
        if indexer_map is not None:
            indexer_map = indexer_map.as_dict(quiet=True)
        unmatched_keys_ok = bool(params.pop("allow_unmatched_keys", False))

        explicit_embedder_params = params.pop('token_embedders', None)

        if explicit_embedder_params is None:
            # Legacy layout: whatever is still left in `params` is treated as an embedder.
            warnings.warn(
                DeprecationWarning(
                    "the token embedders for BasicTextFieldEmbedder should now "
                    "be specified as a dict under the 'token_embedders' key, "
                    "not as top-level key-value pairs"))

            # Snapshot the keys up front because popping mutates `params` while we iterate.
            remaining_keys = list(params.keys())
            embedders = {
                legacy_name: Embedding.from_params(vocab=vocab,
                                                   params=params.pop(legacy_name))
                for legacy_name in remaining_keys
            }
        else:
            # Preferred layout: the embedders were listed explicitly.
            embedders = {
                embedder_name: Embedding.from_params(vocab=vocab, params=embedder_config)
                for embedder_name, embedder_config in explicit_embedder_params.items()
            }

        # TODO(pitrack): replace this line?
        # params.assert_empty(cls.__name__)
        return cls(embedders, indexer_map, unmatched_keys_ok)
コード例 #3
0
def load_archive(archive_file: str,
                 device=None,
                 weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file or an already-extracted directory.

    Parameters
    ----------
    archive_file: ``str``
        The archive to load the model from. It is first resolved through ``cached_path``;
        if the resolved path is a directory it is used directly as the serialization
        directory, otherwise it is opened as a gzip-compressed tar file and extracted
        into a temporary directory.
    device: ``None`` or PyTorch device object.
        Passed through unchanged to ``Model.load``.
    weights_file: ``str``, optional (default = None)
        The weights file to use.  If unspecified (or empty), the file named
        ``_WEIGHTS_NAME`` inside the serialization directory is used.

    Returns
    -------
    archive: ``Archive``
        Holds the loaded model together with the config it was loaded from.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(
            f"loading archive file {archive_file} from cache at {resolved_archive_file}"
        )

    tempdir = None
    try:
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info(
                f"extracting archive file {resolved_archive_file} to temp dir {tempdir}"
            )
            # SECURITY NOTE(review): extractall() trusts the member paths stored in the
            # archive. Only load archives from trusted sources, or validate the members
            # before extracting.
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)

            serialization_dir = tempdir

        # Load config
        config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
        config.loading_from_archive = True

        if weights_file:
            weights_path = weights_file
        else:
            weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)

        # Instantiate model.
        # NOTE(review): the same `config` object is handed to Model.load and then returned
        # inside the Archive (no duplicate is made) -- confirm Model.load does not
        # consume it.
        model = Model.load(config,
                           weights_file=weights_path,
                           serialization_dir=serialization_dir,
                           device=device)
    finally:
        # Clean up the temp dir even when extraction, config loading or model loading
        # fails; otherwise every failed load would leave an extracted copy on disk.
        if tempdir:
            shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
コード例 #4
0
ファイル: train.py プロジェクト: SapienzaNLP/xl-amr
def create_serialization_dir(params: Params) -> None:
    """
    Prepare the serialization directory for a training run.

    The directory path and the recover flag are both read from ``params['environment']``
    (keys ``'serialization_dir'`` and ``'recover'``).

    * If the directory is missing or empty, it is created and the configuration is written
      into it. Asking to recover in this situation is an error.
    * If the directory already has content, the recover flag must be set and the stored
      configuration must equal ``params``; the pre-trained embedding params are then
      removed from ``params``.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.

    Raises
    ------
    ConfigurationError
        If the directory is non-empty without the recover flag, if recovering is requested
        but there is nothing to recover from, if the stored config file is missing, or if
        the stored config differs from ``params``.
    """
    environment_section = params['environment']
    serialization_dir = environment_section['serialization_dir']
    recover = environment_section['recover']

    has_prior_output = os.path.exists(serialization_dir) and os.listdir(serialization_dir)

    # Fresh run: nothing on disk yet (or only an empty directory).
    if not has_prior_output:
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist.  There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)
        params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
        return

    # From here on the directory already holds files from an earlier run.
    if not recover:
        raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists and is "
                                 f"not empty. Specify --recover to recover training from existing output.")

    logger.info(f"Recovering from prior training at {serialization_dir}.")

    stored_config_path = os.path.join(serialization_dir, CONFIG_NAME)
    if not os.path.exists(stored_config_path):
        raise ConfigurationError("The serialization directory already exists but doesn't "
                                 "contain a config.json. You probably gave the wrong directory.")

    stored_params = Params.from_file(stored_config_path)
    if params != stored_params:
        raise ConfigurationError("Training configuration does not match the configuration we're "
                                 "recovering from.")

    # In the recover mode, we don't need to reload the pre-trained embeddings.
    remove_pretrained_embedding_params(params)
コード例 #5
0
ファイル: train.py プロジェクト: SapienzaNLP/xl-amr
def train_model(params: Params):
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results.

    The function reads the ``'environment'``, ``'data'``, ``'model'`` and ``'trainer'`` sections
    of ``params`` (plus an optional ``'vocab'`` section), writes the vocabulary and, when
    available, the training mappings/replacements as JSON into the serialization directory,
    runs the trainer, archives the result and finally loads the ``best.th`` weights back into
    the model.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment. It is modified in place:
        the selected torch device is stored under ``params['trainer']['device']``.
    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.
    Raises
    ------
    KeyboardInterrupt
        Re-raised after an interrupted training; before re-raising, a model archive is
        created if the ``_DEFAULT_WEIGHTS`` file already exists in the serialization directory.
    """
    # Set up the environment.
    environment_params = params['environment']
    environment.set_seed(environment_params)
    create_serialization_dir(params)
    environment.prepare_global_logging(environment_params)
    environment.check_for_gpu(environment_params)
    # Pick the torch device from the 'gpu' flag and 'cuda_device' index of the environment
    # section, then hand it to the trainer through its params.
    if environment_params['gpu']:
        device = torch.device('cuda:{}'.format(environment_params['cuda_device']))
        environment.occupy_gpu(device)
    else:
        device = torch.device('cpu')
    params['trainer']['device'] = device

    # Load data.
    data_params = params['data']
    dataset = dataset_from_params(data_params,
                                  universal_postags=params["model"].get('universal_postags',False),
                                  generator_source_copy=data_params.get('source_copy', True),
                                  multilingual=params['model'].get('multilingual',False),
                                  extra_check=params['data'].get('extra_check',False))
    # Only the 'train' split is mandatory; the other entries default to None when absent.
    train_data = dataset['train']
    dev_data = dataset.get('dev')
    # NOTE(review): test_data is assigned but not used anywhere else in this function.
    test_data = dataset.get('test')
    train_mappings = dataset.get('train_mappings',None)
    train_replacements = dataset.get('train_replacements',None)

    # Vocabulary and iterator are created here.
    vocab_params = params.get('vocab', {})
    if "fixed_vocab" in vocab_params and vocab_params["fixed_vocab"]:
        # NOTE(review): the fixed vocabulary is read from the hard-coded relative path
        # "data/vocabulary", so this depends on the current working directory.
        vocab = Vocabulary.from_files("data/vocabulary")
    else:
        vocab = Vocabulary.from_instances(instances=train_data, **vocab_params)

    # Initializing the model can have side effect of expanding the vocabulary
    # NOTE(review): the vocabulary is saved here, before the model is built below -- confirm
    # that any expansion done during model construction does not need to be persisted.
    vocab.save_to_files(os.path.join(environment_params['serialization_dir'], "vocabulary"))
    # NOTE(review): test_iterater is unpacked but not used anywhere else in this function.
    train_iterator, dev_iterater, test_iterater = iterator_from_params(vocab, data_params['iterator'])
    # Dump the last two entries of train_mappings and the train_replacements as JSON next to
    # the other training outputs; serialize_sets is used as the JSON `default` hook.
    if train_mappings is not None and train_replacements is not None:
        with open(os.path.join(environment_params['serialization_dir'],"trns_lex_missing.json"),"w", encoding='utf-8') as outfile:
            json.dump(train_mappings[-1], outfile, indent=4, default=serialize_sets)
        with open(os.path.join(environment_params['serialization_dir'],"trns_lexicalizations.json"),"w", encoding='utf-8') as outfile:
            json.dump(train_mappings[-2], outfile, indent=4, default=serialize_sets)
        with open(os.path.join(environment_params['serialization_dir'],"trns_rep.json"), "w", encoding='utf-8') as outfile:
            json.dump(train_replacements, outfile, indent=4, default=serialize_sets)
    # Build the model.
    # The model class is looked up by name (model_params['model_type']) on the Models module.
    model_params = params['model']
    model = getattr(Models, model_params['model_type']).from_params(vocab, model_params, environment_params['gpu'], train_mappings, train_replacements)
    logger.info(model)

    # Train
    trainer_params = params['trainer']
    # Freeze every parameter whose name matches at least one of the 'no_grad' regexes.
    no_grad_regexes = trainer_params['no_grad']
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
        environment.get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    logger.info("Total nr of parameters Tunable (with gradient):")
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(pytorch_total_params)

    trainer = Trainer.from_params(model, train_data, dev_data, train_iterator, dev_iterater, trainer_params)

    # NOTE(review): from here on the serialization directory is taken from the trainer
    # section, while everything above used environment_params['serialization_dir'] --
    # presumably both hold the same path; verify against the config.
    serialization_dir = trainer_params['serialization_dir']
    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    # NOTE(review): the file name is spelled out as 'best.th' here, whereas the interrupt
    # handler above checks _DEFAULT_WEIGHTS -- confirm they refer to the same file.
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    # When the model is not wrapped in DataParallel, drop a leading "module." from the
    # state-dict keys so they match the bare model's parameter names.
    if not isinstance(best_model, torch.nn.DataParallel):
        best_model_state = {re.sub(r'^module\.', '', k):v for k, v in best_model_state.items()}
    best_model.load_state_dict(best_model_state)

    return best_model
コード例 #6
0
ファイル: train.py プロジェクト: SapienzaNLP/xl-amr
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    if not isinstance(best_model, torch.nn.DataParallel):
        best_model_state = {re.sub(r'^module\.', '', k):v for k, v in best_model_state.items()}
    best_model.load_state_dict(best_model_state)

    return best_model


if __name__ == "__main__":
    # Script entry point: takes the path of the parameters file as its only positional
    # argument, loads it, logs the loaded parameters and starts training.
    cli_parser = argparse.ArgumentParser('train.py')
    cli_parser.add_argument('params', help='Parameters YAML file.')
    cli_args = cli_parser.parse_args()

    params = Params.from_file(cli_args.params)
    logger.info(params)

    train_model(params)
コード例 #7
0
ファイル: from_params.py プロジェクト: SapienzaNLP/xl-amr
def create_kwargs(cls: Type[T], params: Params, **extras) -> Dict[str, Any]:
    """
    Given some class, a `Params` object, and potentially other keyword arguments,
    create a dict of keyword args suitable for passing to the class's constructor.

    The function does this by finding the class's constructor, matching the constructor
    arguments to entries in the `params` object, and instantiating values for the parameters
    using the type annotation and possibly a from_params method.

    Any values that are provided in the `extras` will just be used as is.
    For instance, you might provide an existing `Vocabulary` this way.

    Parameters
    ----------
    cls : ``Type[T]``
        The class whose ``__init__`` signature drives the lookup.
    params : ``Params``
        The configuration to read from. Matching entries are popped, so the object is
        modified in place; ``params.assert_empty`` is called once all constructor
        parameters have been handled.
    extras
        Ready-made values, keyed by constructor parameter name. They take precedence over
        ``params`` and are also forwarded to nested ``from_params`` calls.

    Returns
    -------
    kwargs : ``Dict[str, Any]``
        One entry per constructor parameter other than ``self``.

    Raises
    ------
    ConfigurationError
        If a parameter whose annotation provides ``from_params`` has no default and no
        entry in ``params``.
    """
    # Get the signature of the constructor.
    signature = inspect.signature(cls.__init__)
    kwargs: Dict[str, Any] = {}

    # Iterate over all the constructor parameters and their annotations.
    for name, param in signature.parameters.items():
        # Skip "self". You're not *required* to call the first parameter "self",
        # so in theory this logic is fragile, but if you don't call the self parameter
        # "self" you kind of deserve what happens.
        if name == "self":
            continue

        # If the annotation is a compound type like typing.Dict[str, int],
        # it will have an __origin__ field indicating `typing.Dict`
        # and an __args__ field indicating `(str, int)`. We capture both.
        annotation = remove_optional(param.annotation)
        origin = getattr(annotation, '__origin__', None)
        args = getattr(annotation, '__args__', [])

        # The parameter is optional if its default value is not the "no default" sentinel.
        # The check is done by identity: `!=` would dispatch to the default value's own
        # comparison method, which for array-like defaults need not return a plain bool.
        # NOTE(review): this relies on _NO_DEFAULT being the very object that `inspect`
        # uses for "no default" (inspect.Parameter.empty) -- confirm at its definition.
        default = param.default
        optional = default is not _NO_DEFAULT

        # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
        # We check the provided `extras` for these and just use them if they exist.
        if name in extras:
            kwargs[name] = extras[name]

        # The next case is when the parameter type is itself constructible from_params.
        elif hasattr(annotation, 'from_params'):
            if name in params:
                # Our params have an entry for this, so we use that.
                subparams = params.pop(name)

                if takes_arg(annotation.from_params, 'extras'):
                    # If annotation.params accepts **extras, we need to pass them all along.
                    # For example, `BasicTextFieldEmbedder.from_params` requires a Vocabulary
                    # object, but `TextFieldEmbedder.from_params` does not.
                    subextras = extras
                else:
                    # Otherwise, only supply the ones that are actual args; any additional ones
                    # will cause a TypeError.
                    subextras = {
                        k: v
                        for k, v in extras.items()
                        if takes_arg(annotation.from_params, k)
                    }

                # In some cases we allow a string instead of a param dict, so
                # we need to handle that case separately.
                if isinstance(subparams, str):
                    kwargs[name] = annotation.by_name(subparams)()
                else:
                    kwargs[name] = annotation.from_params(params=subparams,
                                                          **subextras)
            elif not optional:
                # Not optional and not supplied, that's an error!
                raise ConfigurationError(
                    f"expected key {name} for {cls.__name__}")
            else:
                kwargs[name] = default

        # If the parameter type is a Python primitive, just pop it off
        # using the correct casting pop_xyz operation.
        elif annotation == str:
            kwargs[name] = (params.pop(name, default)
                            if optional else params.pop(name))
        elif annotation == int:
            kwargs[name] = (params.pop_int(name, default)
                            if optional else params.pop_int(name))
        elif annotation == bool:
            kwargs[name] = (params.pop_bool(name, default)
                            if optional else params.pop_bool(name))
        elif annotation == float:
            kwargs[name] = (params.pop_float(name, default)
                            if optional else params.pop_float(name))

        # This is special logic for handling types like Dict[str, TokenIndexer], which it creates by
        # instantiating each value from_params and returning the resulting dict.
        elif origin in (Dict, dict) and len(args) == 2 and hasattr(
                args[-1], 'from_params'):
            value_cls = annotation.__args__[-1]

            value_dict = {}

            # A missing entry is treated as an empty mapping here; the parameter's own
            # default is not consulted in this branch.
            for key, value_params in params.pop(name, Params({})).items():
                value_dict[key] = value_cls.from_params(params=value_params,
                                                        **extras)

            kwargs[name] = value_dict

        else:
            # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
            if optional:
                kwargs[name] = params.pop(name, default)
            else:
                kwargs[name] = params.pop(name)

    params.assert_empty(cls.__name__)
    return kwargs
コード例 #8
0
ファイル: from_params.py プロジェクト: SapienzaNLP/xl-amr
    def from_params(cls: Type[T], params: Params, **extras) -> T:
        """
        Default ``from_params`` shared by every subclass of `FromParams` (and therefore of
        `Registrable`, which subclasses `FromParams`).

        It covers the "obvious" construction scheme: pop the entries named like the
        constructor arguments off ``params`` and pass them to the constructor. When ``cls``
        has registered subclasses, the ``"type"`` entry of ``params`` selects one of them and
        construction is delegated to that subclass's ``from_params``.

        Classes that need anything more elaborate have to override this method.

        Returns ``None`` when ``params`` is ``None``.
        """
        # pylint: disable=protected-access
        from xlamr_stog.utils.registrable import Registrable  # import here to avoid circular imports

        logger.info(
            f"instantiating class {cls} from params {getattr(params, 'params', params)} "
            f"and extras {extras}")

        if params is None:
            return None

        known_implementations = Registrable._registry.get(cls)

        if known_implementations is None:
            # `cls` is not a registry base class, so it is built directly.
            # A class without an explicit constructor must not receive any kwargs: otherwise
            # create_kwargs would inspect object.__init__, see *args / **kwargs and go
            # looking for those in the params.
            constructor_kwargs: Dict[str, Any] = (
                {} if cls.__init__ == object.__init__
                else create_kwargs(cls, params, **extras))
            return cls(**constructor_kwargs)  # type: ignore

        # `cls` is known to inherit from Registrable at this point; the cast is for mypy
        # and the disable is for pylint.
        # pylint: disable=no-member
        registrable_cls = cast(Type[Registrable], cls)
        has_default_implementation = registrable_cls.default_implementation is not None
        chosen_type = params.pop_choice(
            "type",
            choices=registrable_cls.list_available(),
            default_to_first_choice=has_default_implementation)
        implementation = known_implementations[chosen_type]

        # The chosen subclass either inherits this generic implementation (which accepts
        # **extras, so nothing can be assumed about what it needs) or defines a custom
        # `from_params` without **extras, in which case it may only be given the extras
        # it names explicitly.
        if takes_arg(implementation.from_params, 'extras'):
            forwarded_extras = extras
        else:
            forwarded_extras = {
                extra_name: extra_value
                for extra_name, extra_value in extras.items()
                if takes_arg(implementation.from_params, extra_name)
            }

        return implementation.from_params(params=params, **forwarded_extras)