Example #1
    def test_train_read(self):
        self.reader = Flickr30kReader(
            image_dir=FIXTURES_ROOT / "vision" / "images" / "flickr30k",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            data_dir=FIXTURES_ROOT / "vision" / "flickr30k" / "sentences",
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
            featurize_captions=False,
            num_potential_hard_negatives=4,
        )

        instances = list(self.reader.read("test_fixtures/vision/flickr30k/test.txt"))
        assert len(instances) == 25

        instance = instances[5]
        assert len(instance.fields) == 5
        assert len(instance["caption"]) == 4
        assert len(instance["caption"][0]) == 12  # 16
        assert instance["caption"][0] != instance["caption"][1]
        assert instance["caption"][0] == instance["caption"][2]
        assert instance["caption"][0] == instance["caption"][3]
        question_tokens = [t.text for t in instance["caption"][0]]
        assert question_tokens == [
            "girl",
            "with",
            "brown",
            "hair",
            "sits",
            "on",
            "edge",
            "of",
            "concrete",
            "area",
            "overlooking",
            "water",
        ]

        batch = Batch(instances)
        batch.index_instances(Vocabulary())
        tensors = batch.as_tensor_dict()

        # (batch size, num images (3 hard negatives + gold image), num boxes (fake), num features (fake))
        assert tensors["box_features"].size() == (25, 4, 2, 10)

        # (batch size, num images (3 hard negatives + gold image), num boxes (fake), 4 coords)
        assert tensors["box_coordinates"].size() == (25, 4, 2, 4)

        # (batch size, num images (3 hard negatives + gold image), num boxes (fake),)
        assert tensors["box_mask"].size() == (25, 4, 2)

        # (batch size)
        assert tensors["label"].size() == (25,)
Example #2
    def setup_method(self):
        from allennlp_models.vision.dataset_readers.gqa import GQAReader

        super().setup_method()
        self.reader = GQAReader(
            image_dir=FIXTURES_ROOT / "vision" / "images" / "gqa",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
        )
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str,
                 beam_search: Lazy[BeamSearch] = Lazy(BeamSearch,
                                                      beam_size=3,
                                                      max_steps=50),
                 checkpoint_wrapper: Optional[CheckpointWrapper] = None,
                 weights_path: Optional[Union[str, PathLike]] = None,
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._model_name = model_name
        # We only instantiate this when we need it.
        self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
        self.t5 = T5Module.from_pretrained_module(
            model_name,
            beam_search=beam_search,
            ddp_accelerator=self.ddp_accelerator,
            checkpoint_wrapper=checkpoint_wrapper,
            weights_path=weights_path,
        )

        exclude_indices = {
            self.t5.pad_token_id,
            self.t5.decoder_start_token_id,
            self.t5.eos_token_id,
        }
        self._metrics = [
            ROUGE(exclude_indices=exclude_indices),
            BLEU(exclude_indices=exclude_indices),
        ]
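
A hedged usage sketch for the constructor above; the class name T5 and the concrete vocab/model_name values are assumptions for illustration:

# Override the default beam search (beam_size=3, max_steps=50) by passing a
# different Lazy[BeamSearch] at construction time.
model = T5(
    vocab=Vocabulary(),
    model_name="t5-small",
    beam_search=Lazy(BeamSearch, beam_size=5, max_steps=100),
)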
Example #4
def construct_from_params(value_cls: Type[T], value_params: Params,
                          extras: Dict[str, Any]) -> T:
    """
    At this point we know that we need to use `from_params` to construct an object that will be
    (part of) an argument to a constructor.  This does the logic of actually calling that
    `from_params` method.

    This is normally as simple as just `value_cls.from_params`, but we first call `create_extras` to
    pass along any **kwargs that we got as input, and we also have some special handling for `Lazy`
    annotations here - we don't want to recurse on `Lazy.from_params`, we want to bypass that.
    """
    origin = getattr(value_cls, "__origin__", None)
    if origin == Lazy:
        value_cls = value_cls.__args__[0]  # type: ignore
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            return value_cls.from_params(params=value_params,
                                         **kwargs,
                                         **subextras)

        return Lazy(constructor)  # type: ignore
    else:
        subextras = create_extras(value_cls, extras)
        return value_cls.from_params(params=value_params, **subextras)
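
To make the Lazy bypass concrete, a minimal sketch under the assumption that MyModule is a hypothetical FromParams class:

from allennlp.common import FromParams, Params
from allennlp.common.lazy import Lazy

class MyModule(FromParams):  # hypothetical class, for illustration only
    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim

# construct_from_params does not recurse into Lazy.from_params; it returns a
# Lazy whose constructor closes over the raw Params, so MyModule.from_params
# only runs once .construct() is finally called.
lazy_module = construct_from_params(Lazy[MyModule], Params({"hidden_dim": 8}), {})
module = lazy_module.construct()  # from_params happens here, not earlier
assert module.hidden_dim == 8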
Example #5
    def test_read(self):
        from allennlp_models.vision.dataset_readers.vqav2 import VQAv2Reader

        reader = VQAv2Reader(
            image_dir=FIXTURES_ROOT / "vision" / "images" / "vqav2",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
        )
        instances = list(reader.read("unittest"))
        assert len(instances) == 3

        instance = instances[0]
        assert len(instance.fields) == 6
        assert len(instance["question"]) == 7
        question_tokens = [t.text for t in instance["question"]]
        assert question_tokens == [
            "What", "is", "this", "photo", "taken", "looking", "through?"
        ]
        assert len(instance["labels"]) == 5
        labels = [field.label for field in instance["labels"].field_list]
        assert labels == ["net", "netting", "mesh", "pitcher", "orange"]
        assert torch.allclose(
            instance["label_weights"].tensor,
            torch.tensor([1.0, 1.0 / 3, 1.0 / 3, 1.0 / 3, 1.0 / 3]),
        )

        batch = Batch(instances)
        batch.index_instances(Vocabulary())
        tensors = batch.as_tensor_dict()

        # (batch size, num boxes (fake), num features (fake))
        assert tensors["box_features"].size() == (3, 2, 10)

        # (batch size, num boxes (fake), 4 coords)
        assert tensors["box_coordinates"].size() == (3, 2, 4)

        # (batch size, num boxes (fake),)
        assert tensors["box_mask"].size() == (3, 2)

        # Nothing should be masked out since the number of fake boxes is the same
        # for each item in the batch.
        assert tensors["box_mask"].all()
Example #6
def pop_and_construct_arg(
    class_name: str, argument_name: str, annotation: Type, default: Any, params: Params, **extras
) -> Any:
    """
    Does the work of actually constructing an individual argument for
    [`create_kwargs`](./from_params#create_kwargs).

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
    this method, trying the different annotation types in the union in turn.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `argument_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back
    # to using `name`.
    name = argument_name

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # The next case is when the argument should be loaded from a pretrained archive.
    elif (
        name in params
        and isinstance(params.get(name), Params)
        and "_pretrained" in params.get(name)
    ):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result

    popped_params = params.pop(name, default) if default != _NO_DEFAULT else params.pop(name)
    if popped_params is None:
        origin = getattr(annotation, "__origin__", None)
        if origin == Lazy:
            return Lazy(lambda **kwargs: None)
        return None

    return construct_arg(class_name, name, popped_params, annotation, default, **extras)
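
For reference, the "_pretrained" branch above corresponds to a config entry like the following hedged sketch (the parameter name, archive path, and module path are illustrative):

params = Params({
    "text_field_embedder": {
        "_pretrained": {
            "archive_file": "/path/to/model.tar.gz",  # trained archive to load
            "module_path": "_text_field_embedder",    # submodule to extract
            "freeze": True,                           # the default if omitted
        }
    }
})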
Example #7
    def test_read(self):
        from allennlp_models.vision.dataset_readers.vgqa import VGQAReader

        reader = VGQAReader(
            image_dir=FIXTURES_ROOT / "vision" / "images" / "vgqa",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
        )
        instances = list(
            reader.read("test_fixtures/vision/vgqa/question_answers.json"))
        assert len(instances) == 8

        instance = instances[0]
        assert len(instance.fields) == 6
        assert len(instance["question"]) == 5
        question_tokens = [t.text for t in instance["question"]]
        assert question_tokens == ["What", "is", "on", "the", "curtains?"]
        assert len(instance["labels"]) == 1
        labels = [field.label for field in instance["labels"].field_list]
        assert labels == ["sailboats"]

        batch = Batch(instances)
        batch.index_instances(Vocabulary())
        tensors = batch.as_tensor_dict()

        # (batch size, num boxes (fake), num features (fake))
        assert tensors["box_features"].size() == (8, 2, 10)

        # (batch size, num boxes (fake), 4 coords)
        assert tensors["box_coordinates"].size() == (8, 2, 4)

        # (batch size, num boxes (fake))
        assert tensors["box_mask"].size() == (8, 2)

        # Nothing should be masked out since the number of fake boxes is the same
        # for each item in the batch.
        assert tensors["box_mask"].all()
Example #8
    def test_read(self):
        from allennlp_models.vision.dataset_readers.nlvr2 import Nlvr2Reader

        reader = Nlvr2Reader(
            image_dir=FIXTURES_ROOT / "vision" / "images" / "nlvr2",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
        )
        instances = list(
            reader.read("test_fixtures/vision/nlvr2/tiny-dev.json"))
        assert len(instances) == 8

        instance = instances[0]
        assert len(instance.fields) == 6
        assert instance["hypothesis"][0] == instance["hypothesis"][1]
        assert len(instance["hypothesis"][0]) == 18
        hypothesis_tokens = [t.text for t in instance["hypothesis"][0]]
        assert hypothesis_tokens[:6] == [
            "The", "right", "image", "shows", "a", "curving"
        ]
        assert instance["label"].label == 0
        assert instances[1]["label"].label == 1
        assert instance["identifier"].metadata == "dev-850-0-0"

        batch = Batch(instances)
        batch.index_instances(Vocabulary())
        tensors = batch.as_tensor_dict()

        # (batch size, 2 images per instance, num boxes (fake), num features (fake))
        assert tensors["box_features"].size() == (8, 2, 2, 10)

        # (batch size, 2 images per instance, num boxes (fake), 4 coords)
        assert tensors["box_coordinates"].size() == (8, 2, 2, 4)

        # (batch size, 2 images per instance, num boxes (fake))
        assert tensors["box_mask"].size() == (8, 2, 2)
Example #9
    def test_read(self):
        from allennlp_models.vision.dataset_readers.visual_entailment import VisualEntailmentReader

        reader = VisualEntailmentReader(
            image_dir=FIXTURES_ROOT / "vision" / "images" /
            "visual_entailment",
            image_loader=TorchImageLoader(),
            image_featurizer=Lazy(NullGridEmbedder),
            region_detector=Lazy(RandomRegionDetector),
            tokenizer=WhitespaceTokenizer(),
            token_indexers={"tokens": SingleIdTokenIndexer()},
        )
        instances = list(
            reader.read(
                "test_fixtures/vision/visual_entailment/sample_pairs.jsonl"))
        assert len(instances) == 16

        instance = instances[0]
        assert len(instance.fields) == 5
        assert len(instance["hypothesis"]) == 4
        sentence_tokens = [t.text for t in instance["hypothesis"]]
        assert sentence_tokens == ["A", "toddler", "sleeps", "outside."]
        assert instance["labels"].label == "contradiction"

        batch = Batch(instances)
        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(
            ["entailment", "contradiction", "neutral"], "labels")
        batch.index_instances(vocab)
        tensors = batch.as_tensor_dict()

        # (batch size, num boxes (fake), num features (fake))
        assert tensors["box_features"].size() == (16, 2, 10)

        # (batch size, num boxes (fake), 4 coords)
        assert tensors["box_coordinates"].size() == (16, 2, 4)

        # (batch size, num boxes (fake),)
        assert tensors["box_mask"].size() == (16, 2)
Example #10
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check whether beam_size or max_decoding_steps were passed
        # in as kwargs. If so, fold them into the BeamSearch arguments before constructing it,
        # and issue a DeprecationWarning.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace the BART encoder with the given encoder. We need to extract the two embedding
        # layers so that we can use them in the encoder wrapper.
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
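
A hedged sketch of the backwards-compatibility path above; Bart is the class being defined, and the model name and vocab are illustrative:

# Both calls configure the same beam search; the first goes through the
# deprecated-kwargs path and emits a DeprecationWarning.
old_style = Bart(model_name="facebook/bart-base", vocab=vocab, beam_size=5)
new_style = Bart(model_name="facebook/bart-base", vocab=vocab,
                 beam_search=Lazy(BeamSearch, beam_size=5))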
Example #11
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
                 attention: Attention = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.0,
                 use_bleu: bool = True,
                 bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
                 target_pretrain_file: str = None,
                 target_decoder_layers: int = 1,
                 **kwargs) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._target_decoder_layers = target_decoder_layers
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)
            self._bleu = BLEU(
                bleu_ngram_weights,
                exclude_indices={
                    pad_index, self._end_index, self._start_index
                },
            )
        else:
            self._bleu = None

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check whether beam_size or max_decoding_steps were passed
        # in as kwargs. If so, fold them into the BeamSearch arguments before constructing it,
        # and issue a DeprecationWarning.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead.")
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"),
                          DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"),
                          DeprecationWarning)
        self._beam_search = beam_search.construct(end_index=self._end_index,
                                                  vocab=self.vocab,
                                                  **beam_search_extras)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        if not target_pretrain_file:
            self._target_embedder = Embedding(
                num_embeddings=num_classes, embedding_dim=target_embedding_dim)
        else:
            self._target_embedder = Embedding(
                embedding_dim=target_embedding_dim,
                pretrained_file=target_pretrain_file,
                vocab_namespace=self._target_namespace,
                vocab=self.vocab,
            )

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        if self._target_decoder_layers > 1:
            self._decoder_cell = LSTM(
                self._decoder_input_dim,
                self._decoder_output_dim,
                self._target_decoder_layers,
            )
        else:
            self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                          self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)
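
To make the dimension bookkeeping above concrete, a small worked example (the numbers are arbitrary):

encoder_output_dim = 300   # the decoder output dim is set equal to this
target_embedding_dim = 100

# With attention, the decoder input concatenates the attended encoder context
# with the previous target embedding:
decoder_input_dim_with_attention = encoder_output_dim + target_embedding_dim  # 400

# Without attention, it is just the previous target embedding:
decoder_input_dim_without_attention = target_embedding_dim  # 100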
Example #12
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                return annotation.by_name(popped_params)()
            else:
                if isinstance(popped_params, dict):
                    popped_params = Params(popped_params)
                return annotation.from_params(params=popped_params,
                                              **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(
                f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )

        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_list = []

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for i, (value_cls, value_params) in enumerate(
                zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return Lazy(lambda **kwargs: default)
        value_cls = args[0]
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            # If there are duplicate keys between subextras and kwargs, this will overwrite the ones
            # in subextras with what's in kwargs.  If an argument shows up twice, we should take it
            # from what's passed to Lazy.construct() instead of what we got from create_extras().
            # Almost certainly these will be identical objects, anyway.
            # We do this by constructing a new dictionary, instead of mutating subextras, just in
            # case this constructor is called multiple times.
            constructor_extras = {**subextras, **kwargs}
            return value_cls.from_params(params=deepcopy(popped_params),
                                         **constructor_extras)

        return Lazy(constructor)  # type: ignore
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict(quiet=True)
        return popped_params
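
The merge rule in the Lazy branch above is plain dict-unpacking semantics, shown here in isolation:

subextras = {"vocab": "from_create_extras"}
kwargs = {"vocab": "from_construct_call"}

# Later keys win, so arguments passed to Lazy.construct() override the extras
# captured when the config was parsed.
constructor_extras = {**subextras, **kwargs}
assert constructor_extras["vocab"] == "from_construct_call"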
Example #13
def construct_arg(class_name: str, param_name: str, annotation: Type,
                  default: Any, params: Params, **extras) -> Any:
    """
    Does the work of actually constructing an individual argument for :func:`create_kwargs`.

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
    this method, trying the different annotation types in the union in turn.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `param_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back
    # to using `name`.
    name = param_name
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # The next case is when the argument should be loaded from a pretrained archive.
    elif (name in params and isinstance(params.get(name), Params)
          and "_pretrained" in params.get(name)):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result
    # The next case is when the parameter type is itself constructible from_params.
    elif hasattr(annotation, "from_params"):
        if name in params:
            # Our params have an entry for this, so we use that.
            subparams = params.pop(name)

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(subparams, str):
                return annotation.by_name(subparams)()
            else:
                return annotation.from_params(params=subparams, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return params.pop(name, default) if optional else params.pop(name)
    elif annotation == int:
        return params.pop_int(name, default) if optional else params.pop_int(name)
    elif annotation == bool:
        return params.pop_bool(name, default) if optional else params.pop_bool(name)
    elif annotation == float:
        return params.pop_float(name, default) if optional else params.pop_float(name)

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in params.pop(name, Params({})).items():
            value_dict[key] = construct_from_params(value_cls, value_params,
                                                    extras)

        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_list = []

        for value_params in params.pop(name, Params({})):
            value_list.append(
                construct_from_params(value_cls, value_params, extras))

        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for value_cls, value_params in zip(annotation.__args__,
                                           params.pop(name, Params({}))):
            value_list.append(
                construct_from_params(value_cls, value_params, extras))

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for value_params in params.pop(name, Params({})):
            value_set.add(
                construct_from_params(value_cls, value_params, extras))

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        param_value = params.pop(name, default=default, keep_as_dict=True)
        params[name] = deepcopy(param_value)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg in args:
            try:
                return construct_arg(class_name, name, arg, default, params,
                                     **extras)
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have popped `params[name]`, so we
                # restore it here.
                params[name] = param_value
                param_value = deepcopy(param_value)
                continue

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {name} with type {annotation}")
    elif origin == Lazy:
        if name not in params and optional:
            return Lazy(lambda **kwargs: default)
        value_params = params.pop(name, Params({}))
        return construct_from_params(annotation, value_params, extras)
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if optional:
            value = params.pop(name, default, keep_as_dict=True)
        else:
            value = params.pop(name, keep_as_dict=True)
        return value
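
The Union strategy above (try each member type, restoring the params from a copy between attempts) can be sketched standalone, with hypothetical parser callables standing in for the recursive construct_arg calls:

from copy import deepcopy

def try_union(value, member_parsers):
    backup = deepcopy(value)
    for parse in member_parsers:
        try:
            return parse(deepcopy(backup))  # each attempt gets a fresh copy
        except (ValueError, TypeError):
            continue  # restore-and-retry, as construct_arg does with params[name]
    raise ValueError("no union member matched")

assert try_union("3", [int, float]) == 3  # the first member that succeeds wins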
Example #14
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(
                f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif (origin in {collections.abc.Mapping, Mapping, Dict, dict}
          and len(args) == 2 and can_construct_from_params(args[-1])):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )

        return value_dict

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for i, (value_cls, value_params) in enumerate(
                zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)
        return Lazy(value_cls,
                    params=deepcopy(popped_params),
                    constructor_extras=subextras)  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (origin in {collections.abc.Iterable, Iterable, List, list}
          and len(args) == 1 and can_construct_from_params(args[0])):
        value_cls = annotation.__args__[0]

        value_list = []

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return value_list

    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
Example #15
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                return annotation.by_name(popped_params)()
            else:
                if isinstance(popped_params, dict):
                    popped_params = Params(popped_params)
                return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return popped_params
    elif annotation == int:
        return int(popped_params)  # type: ignore
    elif annotation == bool:
        return bool(popped_params)
    elif annotation == float:
        return float(popped_params)  # type: ignore

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )

        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_list = []

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return Lazy(lambda **kwargs: default)
        value_cls = args[0]
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            return value_cls.from_params(params=popped_params, **kwargs, **subextras)

        return Lazy(constructor)  # type: ignore
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict(quiet=True)
        return popped_params
Example #16
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    could_be_step: bool = True,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    from allennlp.tango.step import Step, _RefStep

    if could_be_step:
        # We try parsing as a step _first_. Parsing as a non-step always succeeds, because
        # it will fall back to returning a dict. So we can't try parsing as a non-step first.
        backup_params = deepcopy(popped_params)
        try:
            return construct_arg(
                class_name,
                argument_name,
                popped_params,
                Step[annotation],  # type: ignore
                default,
                could_be_step=False,
                **extras,
            )
        except (ValueError, TypeError, ConfigurationError, AttributeError):
            popped_params = backup_params

    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                if origin != Step:
                    # We don't allow single strings to be upgraded to steps.
                    # Since we try everything as a step first, upgrading strings to
                    # steps automatically would cause confusion every time a step
                    # name conflicts with any string anywhere in a config.
                    popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            result = annotation.from_params(params=popped_params, **subextras)

            if isinstance(result, Step):
                if isinstance(result, _RefStep):
                    existing_steps: Dict[str, Step] = extras.get(
                        "existing_steps", {})
                    try:
                        result = existing_steps[result.ref()]
                    except KeyError:
                        raise _RefStep.MissingStepError(result.ref())

                expected_return_type = args[0]
                return_type = inspect.signature(result.run).return_annotation
                if return_type == inspect.Signature.empty:
                    logger.warning(
                        "Step %s has no return type annotation. Those are really helpful when "
                        "debugging, so we recommend them highly.",
                        result.__class__.__name__,
                    )
                elif not issubclass(return_type, expected_return_type):
                    raise ConfigurationError(
                        f"Step {result.name} returns {return_type}, but "
                        f"we expected {expected_return_type}.")

            return result
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(
                f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif (origin in {collections.abc.Mapping, Mapping, Dict, dict}
          and len(args) == 2 and can_construct_from_params(args[-1])):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )

        return value_dict

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for i, (value_cls, value_params) in enumerate(
                zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        error_chain: Optional[Exception] = None
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError,
                    AttributeError) as e:
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)
                e.args = (
                    f"While constructing an argument of type {arg_annotation}",
                ) + e.args
                e.__cause__ = error_chain
                error_chain = e

        # If none of them succeeded, we crash.
        config_error = ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}."
        )
        config_error.__cause__ = error_chain
        raise config_error
    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)
        return Lazy(value_cls,
                    params=deepcopy(popped_params),
                    constructor_extras=subextras)  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (origin in {collections.abc.Iterable, Iterable, List, list}
          and len(args) == 1 and can_construct_from_params(args[0])):
        value_cls = annotation.__args__[0]

        value_list = []

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return value_list

    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
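
Finally, the return-type validation for steps can be reproduced outside AllenNLP; a minimal sketch using only the standard library, with AddOneStep as a hypothetical stand-in for a tango Step:

import inspect

class AddOneStep:
    def run(self) -> int:  # the annotation is what construct_arg inspects
        return 1

return_type = inspect.signature(AddOneStep.run).return_annotation
assert return_type is not inspect.Signature.empty  # annotation is present
assert issubclass(return_type, int)                # and matches the expectation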