def test_train_read(self):
    self.reader = Flickr30kReader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "flickr30k",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        data_dir=FIXTURES_ROOT / "vision" / "flickr30k" / "sentences",
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
        featurize_captions=False,
        num_potential_hard_negatives=4,
    )

    instances = list(self.reader.read("test_fixtures/vision/flickr30k/test.txt"))
    assert len(instances) == 25

    instance = instances[5]
    assert len(instance.fields) == 5
    assert len(instance["caption"]) == 4
    assert len(instance["caption"][0]) == 12  # 16
    assert instance["caption"][0] != instance["caption"][1]
    assert instance["caption"][0] == instance["caption"][2]
    assert instance["caption"][0] == instance["caption"][3]

    question_tokens = [t.text for t in instance["caption"][0]]
    assert question_tokens == [
        "girl",
        "with",
        "brown",
        "hair",
        "sits",
        "on",
        "edge",
        "of",
        "concrete",
        "area",
        "overlooking",
        "water",
    ]

    batch = Batch(instances)
    batch.index_instances(Vocabulary())
    tensors = batch.as_tensor_dict()

    # (batch size, num images (3 hard negatives + gold image), num boxes (fake), num features (fake))
    assert tensors["box_features"].size() == (25, 4, 2, 10)

    # (batch size, num images (3 hard negatives + gold image), num boxes (fake), 4 coords)
    assert tensors["box_coordinates"].size() == (25, 4, 2, 4)

    # (batch size, num images (3 hard negatives + gold image), num boxes (fake),)
    assert tensors["box_mask"].size() == (25, 4, 2)

    # (batch size)
    assert tensors["label"].size() == (25,)
def setup_method(self):
    from allennlp_models.vision.dataset_readers.gqa import GQAReader

    super().setup_method()
    self.reader = GQAReader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "gqa",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
    )
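The reader tests above all pass Lazy(NullGridEmbedder) and Lazy(RandomRegionDetector) rather than constructed objects, so the reader only builds these components when it first needs them. A minimal sketch of that pattern follows; DummyEmbedder is a hypothetical stand-in, and Lazy.construct(**kwargs) is assumed to forward keyword arguments to the wrapped callable (which matches how the models later in this section call beam_search.construct(...)).

# Sketch only: Lazy wraps a constructor without calling it.
from allennlp.common import Lazy

class DummyEmbedder:  # hypothetical stand-in for NullGridEmbedder, for illustration only
    def __init__(self, dim: int = 8):
        self.dim = dim

lazy_embedder = Lazy(DummyEmbedder)         # nothing is constructed yet
embedder = lazy_embedder.construct(dim=16)  # built on first use, with extra arguments
assert embedder.dim == 16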
def __init__(
    self,
    vocab: Vocabulary,
    model_name: str,
    beam_search: Lazy[BeamSearch] = Lazy(BeamSearch, beam_size=3, max_steps=50),
    checkpoint_wrapper: Optional[CheckpointWrapper] = None,
    weights_path: Optional[Union[str, PathLike]] = None,
    **kwargs,
) -> None:
    super().__init__(vocab, **kwargs)
    self._model_name = model_name
    # We only instantiate this when we need it.
    self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
    self.t5 = T5Module.from_pretrained_module(
        model_name,
        beam_search=beam_search,
        ddp_accelerator=self.ddp_accelerator,
        checkpoint_wrapper=checkpoint_wrapper,
        weights_path=weights_path,
    )

    exclude_indices = {
        self.t5.pad_token_id,
        self.t5.decoder_start_token_id,
        self.t5.eos_token_id,
    }
    self._metrics = [
        ROUGE(exclude_indices=exclude_indices),
        BLEU(exclude_indices=exclude_indices),
    ]
def construct_from_params(value_cls: Type[T], value_params: Params, extras: Dict[str, Any]) -> T:
    """
    At this point we know that we need to use `from_params` to construct an object that will be
    (part of) an argument to a constructor. This does the logic of actually calling that
    `from_params` method. This is normally as simple as just `value_cls.from_params`, but we first
    call `create_extras` to pass along any **kwargs that we got as input, and we also have some
    special handling for `Lazy` annotations here - we don't want to recurse on `Lazy.from_params`,
    we want to bypass that.
    """
    origin = getattr(value_cls, "__origin__", None)
    if origin == Lazy:
        value_cls = value_cls.__args__[0]  # type: ignore
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            return value_cls.from_params(params=value_params, **kwargs, **subextras)

        return Lazy(constructor)  # type: ignore
    else:
        subextras = create_extras(value_cls, extras)
        return value_cls.from_params(params=value_params, **subextras)
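The Lazy branch above only pays off at the call site, where the constructor can fill in arguments the config cannot know. A minimal sketch with two hypothetical FromParams classes (Gaussian and Scaled, not part of the library) shows the deferred from_params call in action:

# Sketch only, assuming allennlp's FromParams/Lazy/Params as used throughout this section.
from allennlp.common import FromParams, Lazy, Params

class Gaussian(FromParams):
    def __init__(self, mean: float, variance: float):
        self.mean = mean
        self.variance = variance

class Scaled(FromParams):
    # `variance` is only known here, so the config cannot specify it; the Lazy object
    # produced by the machinery above lets us supply it at construction time.
    def __init__(self, scale: float, prior: Lazy[Gaussian]):
        self.prior = prior.construct(variance=scale * 2.0)

scaled = Scaled.from_params(Params({"scale": 1.5, "prior": {"mean": 0.0}}))
assert scaled.prior.variance == 3.0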
def test_read(self):
    from allennlp_models.vision.dataset_readers.vqav2 import VQAv2Reader

    reader = VQAv2Reader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "vqav2",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
    )
    instances = list(reader.read("unittest"))
    assert len(instances) == 3

    instance = instances[0]
    assert len(instance.fields) == 6
    assert len(instance["question"]) == 7
    question_tokens = [t.text for t in instance["question"]]
    assert question_tokens == [
        "What",
        "is",
        "this",
        "photo",
        "taken",
        "looking",
        "through?",
    ]
    assert len(instance["labels"]) == 5
    labels = [field.label for field in instance["labels"].field_list]
    assert labels == ["net", "netting", "mesh", "pitcher", "orange"]
    assert torch.allclose(
        instance["label_weights"].tensor,
        torch.tensor([1.0, 1.0 / 3, 1.0 / 3, 1.0 / 3, 1.0 / 3]),
    )

    batch = Batch(instances)
    batch.index_instances(Vocabulary())
    tensors = batch.as_tensor_dict()

    # (batch size, num boxes (fake), num features (fake))
    assert tensors["box_features"].size() == (3, 2, 10)

    # (batch size, num boxes (fake), 4 coords)
    assert tensors["box_coordinates"].size() == (3, 2, 4)

    # (batch size, num boxes (fake),)
    assert tensors["box_mask"].size() == (3, 2)

    # Nothing should be masked out since the number of fake boxes is the same
    # for each item in the batch.
    assert tensors["box_mask"].all()
def pop_and_construct_arg(
    class_name: str, argument_name: str, annotation: Type, default: Any, params: Params, **extras
) -> Any:
    """
    Does the work of actually constructing an individual argument for
    [`create_kwargs`](./from_params#create_kwargs).

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them. The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
    this method, trying the different annotation types in the union in turn.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `argument_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely. Now that we are inside the method, we can switch back
    # to using `name`.
    name = argument_name

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # Next case is when argument should be loaded from pretrained archive.
    elif (
        name in params
        and isinstance(params.get(name), Params)
        and "_pretrained" in params.get(name)
    ):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result

    popped_params = params.pop(name, default) if default != _NO_DEFAULT else params.pop(name)
    if popped_params is None:
        origin = getattr(annotation, "__origin__", None)
        if origin == Lazy:
            return Lazy(lambda **kwargs: None)
        return None

    return construct_arg(class_name, name, popped_params, annotation, default, **extras)
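For reference, the "_pretrained" branch above expects a params entry shaped like the following sketch. The archive and module paths are placeholders, not real files; the keys and the default for "freeze" are read directly off the code above.

# Sketch of the params shape handled by the `_pretrained` branch above.
from allennlp.common import Params

encoder_params = Params({
    "_pretrained": {
        "archive_file": "/path/to/model.tar.gz",  # placeholder; loaded with load_archive()
        "module_path": "_encoder",                # placeholder; submodule to extract from the archived model
        "freeze": True,                           # freeze the extracted module's weights (defaults to True)
    }
})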
def test_read(self):
    from allennlp_models.vision.dataset_readers.vgqa import VGQAReader

    reader = VGQAReader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "vgqa",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
    )
    instances = list(reader.read("test_fixtures/vision/vgqa/question_answers.json"))
    assert len(instances) == 8

    instance = instances[0]
    assert len(instance.fields) == 6
    assert len(instance["question"]) == 5
    question_tokens = [t.text for t in instance["question"]]
    assert question_tokens == ["What", "is", "on", "the", "curtains?"]
    assert len(instance["labels"]) == 1
    labels = [field.label for field in instance["labels"].field_list]
    assert labels == ["sailboats"]

    batch = Batch(instances)
    batch.index_instances(Vocabulary())
    tensors = batch.as_tensor_dict()

    # (batch size, num boxes (fake), num features (fake))
    assert tensors["box_features"].size() == (8, 2, 10)

    # (batch size, num boxes (fake), 4 coords)
    assert tensors["box_coordinates"].size() == (8, 2, 4)

    # (batch size, num boxes (fake))
    assert tensors["box_mask"].size() == (8, 2)

    # Nothing should be masked out since the number of fake boxes is the same
    # for each item in the batch.
    assert tensors["box_mask"].all()
def test_read(self):
    from allennlp_models.vision.dataset_readers.nlvr2 import Nlvr2Reader

    reader = Nlvr2Reader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "nlvr2",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
    )
    instances = list(reader.read("test_fixtures/vision/nlvr2/tiny-dev.json"))
    assert len(instances) == 8

    instance = instances[0]
    assert len(instance.fields) == 6
    assert instance["hypothesis"][0] == instance["hypothesis"][1]
    assert len(instance["hypothesis"][0]) == 18
    hypothesis_tokens = [t.text for t in instance["hypothesis"][0]]
    assert hypothesis_tokens[:6] == ["The", "right", "image", "shows", "a", "curving"]
    assert instance["label"].label == 0
    assert instances[1]["label"].label == 1
    assert instance["identifier"].metadata == "dev-850-0-0"

    batch = Batch(instances)
    batch.index_instances(Vocabulary())
    tensors = batch.as_tensor_dict()

    # (batch size, 2 images per instance, num boxes (fake), num features (fake))
    assert tensors["box_features"].size() == (8, 2, 2, 10)

    # (batch size, 2 images per instance, num boxes (fake), 4 coords)
    assert tensors["box_coordinates"].size() == (8, 2, 2, 4)

    # (batch size, 2 images per instance, num boxes (fake))
    assert tensors["box_mask"].size() == (8, 2, 2)
def test_read(self):
    from allennlp_models.vision.dataset_readers.visual_entailment import VisualEntailmentReader

    reader = VisualEntailmentReader(
        image_dir=FIXTURES_ROOT / "vision" / "images" / "visual_entailment",
        image_loader=TorchImageLoader(),
        image_featurizer=Lazy(NullGridEmbedder),
        region_detector=Lazy(RandomRegionDetector),
        tokenizer=WhitespaceTokenizer(),
        token_indexers={"tokens": SingleIdTokenIndexer()},
    )
    instances = list(reader.read("test_fixtures/vision/visual_entailment/sample_pairs.jsonl"))
    assert len(instances) == 16

    instance = instances[0]
    assert len(instance.fields) == 5
    assert len(instance["hypothesis"]) == 4
    sentence_tokens = [t.text for t in instance["hypothesis"]]
    assert sentence_tokens == ["A", "toddler", "sleeps", "outside."]
    assert instance["labels"].label == "contradiction"

    batch = Batch(instances)
    vocab = Vocabulary()
    vocab.add_tokens_to_namespace(["entailment", "contradiction", "neutral"], "labels")
    batch.index_instances(vocab)
    tensors = batch.as_tensor_dict()

    # (batch size, num boxes (fake), num features (fake))
    assert tensors["box_features"].size() == (16, 2, 10)

    # (batch size, num boxes (fake), 4 coords)
    assert tensors["box_coordinates"].size() == (16, 2, 4)

    # (batch_size, num boxes (fake),)
    assert tensors["box_mask"].size() == (16, 2)
def __init__(
    self,
    model_name: str,
    vocab: Vocabulary,
    beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
    indexer: PretrainedTransformerIndexer = None,
    encoder: Seq2SeqEncoder = None,
    **kwargs,
):
    super().__init__(vocab)
    self.bart = BartForConditionalGeneration.from_pretrained(model_name)
    self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

    self._start_id = self.bart.config.bos_token_id  # CLS
    self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
    self._end_id = self.bart.config.eos_token_id  # SEP
    self._pad_id = self.bart.config.pad_token_id  # PAD

    # At prediction time, we'll use a beam search to find the best target sequence.
    # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
    # kwargs. If so, update the BeamSearch object before constructing and raise a DeprecationWarning
    deprecation_warning = (
        "The parameter {} has been deprecated."
        " Provide this parameter as argument to beam_search instead."
    )
    beam_search_extras = {}
    if "beam_size" in kwargs:
        beam_search_extras["beam_size"] = kwargs["beam_size"]
        warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
    if "max_decoding_steps" in kwargs:
        beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
        warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
    self._beam_search = beam_search.construct(
        end_index=self._end_id, vocab=self.vocab, **beam_search_extras
    )

    self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
    self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

    # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
    # we can use them in the encoder wrapper
    if encoder is not None:
        assert (
            encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
        )
        self.bart.model.encoder = _BartEncoderWrapper(
            encoder,
            self.bart.model.encoder.embed_tokens,
            self.bart.model.encoder.embed_positions,
        )
def __init__(
    self,
    vocab: Vocabulary,
    source_embedder: TextFieldEmbedder,
    encoder: Seq2SeqEncoder,
    beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
    attention: Attention = None,
    target_namespace: str = "tokens",
    target_embedding_dim: int = None,
    scheduled_sampling_ratio: float = 0.0,
    use_bleu: bool = True,
    bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
    target_pretrain_file: str = None,
    target_decoder_layers: int = 1,
    **kwargs,
) -> None:
    super().__init__(vocab)
    self._target_namespace = target_namespace
    self._target_decoder_layers = target_decoder_layers
    self._scheduled_sampling_ratio = scheduled_sampling_ratio

    # We need the start symbol to provide as the input at the first timestep of decoding, and
    # end symbol as a way to indicate the end of the decoded sequence.
    self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
    self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

    if use_bleu:
        pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)
        self._bleu = BLEU(
            bleu_ngram_weights,
            exclude_indices={pad_index, self._end_index, self._start_index},
        )
    else:
        self._bleu = None

    # At prediction time, we'll use a beam search to find the best target sequence.
    # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
    # kwargs. If so, update the BeamSearch object before constructing and raise a DeprecationWarning
    deprecation_warning = (
        "The parameter {} has been deprecated."
        " Provide this parameter as argument to beam_search instead."
    )
    beam_search_extras = {}
    if "beam_size" in kwargs:
        beam_search_extras["beam_size"] = kwargs["beam_size"]
        warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
    if "max_decoding_steps" in kwargs:
        beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
        warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
    self._beam_search = beam_search.construct(
        end_index=self._end_index, vocab=self.vocab, **beam_search_extras
    )

    # Dense embedding of source vocab tokens.
    self._source_embedder = source_embedder

    # Encodes the sequence of source embeddings into a sequence of hidden states.
    self._encoder = encoder

    num_classes = self.vocab.get_vocab_size(self._target_namespace)

    # Attention mechanism applied to the encoder output for each step.
    self._attention = attention

    # Dense embedding of vocab words in the target space.
    target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
    if not target_pretrain_file:
        self._target_embedder = Embedding(
            num_embeddings=num_classes, embedding_dim=target_embedding_dim
        )
    else:
        self._target_embedder = Embedding(
            embedding_dim=target_embedding_dim,
            pretrained_file=target_pretrain_file,
            vocab_namespace=self._target_namespace,
            vocab=self.vocab,
        )

    # Decoder output dim needs to be the same as the encoder output dim since we initialize the
    # hidden state of the decoder with the final hidden state of the encoder.
    self._encoder_output_dim = self._encoder.get_output_dim()
    self._decoder_output_dim = self._encoder_output_dim

    if self._attention:
        # If using attention, a weighted average over encoder outputs will be concatenated
        # to the previous target embedding to form the input to the decoder at each
        # time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
    else:
        # Otherwise, the input to the decoder is just the previous target embedding.
        self._decoder_input_dim = target_embedding_dim

    # We'll use an LSTM cell as the recurrent cell that produces a hidden state
    # for the decoder at each time step.
    # TODO (pradeep): Do not hardcode decoder cell type.
    if self._target_decoder_layers > 1:
        self._decoder_cell = LSTM(
            self._decoder_input_dim,
            self._decoder_output_dim,
            self._target_decoder_layers,
        )
    else:
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

    # We project the hidden state from the decoder into the output vocabulary space
    # in order to get log probabilities of each target token, at each time step.
    self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
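The deprecation shim in the two model constructors above maps the old flat kwargs onto the nested beam_search entry. Roughly, the two config fragments in this sketch are equivalent (all non-beam-search keys omitted; key names taken from the code above):

# Sketch: deprecated flat keys vs. the nested beam_search entry they now map to.
old_style = {"beam_size": 5, "max_decoding_steps": 50}          # triggers DeprecationWarning
new_style = {"beam_search": {"beam_size": 5, "max_steps": 50}}  # built lazily, then completed
# in __init__ via beam_search.construct(end_index=..., vocab=self.vocab, **beam_search_extras)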
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                return annotation.by_name(popped_params)()
            else:
                if isinstance(popped_params, dict):
                    popped_params = Params(popped_params)
                return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []
        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)
        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return Lazy(lambda **kwargs: default)

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            # If there are duplicate keys between subextras and kwargs, this will overwrite the ones
            # in subextras with what's in kwargs. If an argument shows up twice, we should take it
            # from what's passed to Lazy.construct() instead of what we got from create_extras().
            # Almost certainly these will be identical objects, anyway.
            # We do this by constructing a new dictionary, instead of mutating subextras, just in
            # case this constructor is called multiple times.
            constructor_extras = {**subextras, **kwargs}
            return value_cls.from_params(params=deepcopy(popped_params), **constructor_extras)

        return Lazy(constructor)  # type: ignore

    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict(quiet=True)
        return popped_params
def construct_arg(
    class_name: str, param_name: str, annotation: Type, default: Any, params: Params, **extras
) -> Any:
    """
    Does the work of actually constructing an individual argument for :func:`create_kwargs`.

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them. The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
    this method, trying the different annotation types in the union in turn.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `param_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely. Now that we are inside the method, we can switch back
    # to using `name`.
    name = param_name
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # Next case is when argument should be loaded from pretrained archive.
    elif (
        name in params
        and isinstance(params.get(name), Params)
        and "_pretrained" in params.get(name)
    ):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result
    # The next case is when the parameter type is itself constructible from_params.
    elif hasattr(annotation, "from_params"):
        if name in params:
            # Our params have an entry for this, so we use that.
            subparams = params.pop(name)
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(subparams, str):
                return annotation.by_name(subparams)()
            else:
                return annotation.from_params(params=subparams, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return params.pop(name, default) if optional else params.pop(name)
    elif annotation == int:
        return params.pop_int(name, default) if optional else params.pop_int(name)
    elif annotation == bool:
        return params.pop_bool(name, default) if optional else params.pop_bool(name)
    elif annotation == float:
        return params.pop_float(name, default) if optional else params.pop_float(name)

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        for key, value_params in params.pop(name, Params({})).items():
            value_dict[key] = construct_from_params(value_cls, value_params, extras)
        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_list = []
        for value_params in params.pop(name, Params({})):
            value_list.append(construct_from_params(value_cls, value_params, extras))
        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []
        for value_cls, value_params in zip(annotation.__args__, params.pop(name, Params({}))):
            value_list.append(construct_from_params(value_cls, value_params, extras))
        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_set = set()
        for value_params in params.pop(name, Params({})):
            value_set.add(construct_from_params(value_cls, value_params, extras))
        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        param_value = params.pop(name, default=default, keep_as_dict=True)
        params[name] = deepcopy(param_value)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg in args:
            try:
                return construct_arg(class_name, name, arg, default, params, **extras)
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have popped `params[name]`, so we
                # restore it here.
                params[name] = param_value
                param_value = deepcopy(param_value)
                continue

        # If none of them succeeded, we crash.
        raise ConfigurationError(f"Failed to construct argument {name} with type {annotation}")
    elif origin == Lazy:
        if name not in params and optional:
            return Lazy(lambda **kwargs: default)
        value_params = params.pop(name, Params({}))
        return construct_from_params(annotation, value_params, extras)
    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if optional:
            value = params.pop(name, default, keep_as_dict=True)
        else:
            value = params.pop(name, keep_as_dict=True)
        return value
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif (
        origin in {collections.abc.Mapping, Mapping, Dict, dict}
        and len(args) == 2
        and can_construct_from_params(args[-1])
    ):
        value_cls = annotation.__args__[-1]

        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
        return value_dict

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []
        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)
        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)
        return Lazy(value_cls, params=deepcopy(popped_params), contructor_extras=subextras)  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (
        origin in {collections.abc.Iterable, Iterable, List, list}
        and len(args) == 1
        and can_construct_from_params(args[0])
    ):
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return value_list

    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                return annotation.by_name(popped_params)()
            else:
                if isinstance(popped_params, dict):
                    popped_params = Params(popped_params)
                return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return popped_params
    elif annotation == int:
        return int(popped_params)  # type: ignore
    elif annotation == bool:
        return bool(popped_params)
    elif annotation == float:
        return float(popped_params)  # type: ignore

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and can_construct_from_params(args[-1]):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
        return value_dict

    elif origin in (List, list) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return value_list

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []
        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)
        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return Lazy(lambda **kwargs: default)

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)

        def constructor(**kwargs):
            return value_cls.from_params(params=popped_params, **kwargs, **subextras)

        return Lazy(constructor)  # type: ignore

    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict(quiet=True)
        return popped_params
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    could_be_step: bool = True,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    from allennlp.tango.step import Step, _RefStep

    if could_be_step:
        # We try parsing as a step _first_. Parsing as a non-step always succeeds, because
        # it will fall back to returning a dict. So we can't try parsing as a non-step first.
        backup_params = deepcopy(popped_params)
        try:
            return construct_arg(
                class_name,
                argument_name,
                popped_params,
                Step[annotation],  # type: ignore
                default,
                could_be_step=False,
                **extras,
            )
        except (ValueError, TypeError, ConfigurationError, AttributeError):
            popped_params = backup_params

    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                if origin != Step:
                    # We don't allow single strings to be upgraded to steps.
                    # Since we try everything as a step first, upgrading strings to
                    # steps automatically would cause confusion every time a step
                    # name conflicts with any string anywhere in a config.
                    popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            result = annotation.from_params(params=popped_params, **subextras)

            if isinstance(result, Step):
                if isinstance(result, _RefStep):
                    existing_steps: Dict[str, Step] = extras.get("existing_steps", {})
                    try:
                        result = existing_steps[result.ref()]
                    except KeyError:
                        raise _RefStep.MissingStepError(result.ref())

                expected_return_type = args[0]
                return_type = inspect.signature(result.run).return_annotation
                if return_type == inspect.Signature.empty:
                    logger.warning(
                        "Step %s has no return type annotation. Those are really helpful when "
                        "debugging, so we recommend them highly.",
                        result.__class__.__name__,
                    )
                elif not issubclass(return_type, expected_return_type):
                    raise ConfigurationError(
                        f"Step {result.name} returns {return_type}, but "
                        f"we expected {expected_return_type}."
                    )

            return result
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif (
        origin in {collections.abc.Mapping, Mapping, Dict, dict}
        and len(args) == 2
        and can_construct_from_params(args[-1])
    ):
        value_cls = annotation.__args__[-1]

        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
        return value_dict

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []
        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)
        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        error_chain: Optional[Exception] = None
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError) as e:
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)
                e.args = (f"While constructing an argument of type {arg_annotation}",) + e.args
                e.__cause__ = error_chain
                error_chain = e

        # If none of them succeeded, we crash.
        config_error = ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}."
        )
        config_error.__cause__ = error_chain
        raise config_error
    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)
        return Lazy(value_cls, params=deepcopy(popped_params), constructor_extras=subextras)  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (
        origin in {collections.abc.Iterable, Iterable, List, list}
        and len(args) == 1
        and can_construct_from_params(args[0])
    ):
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)
        return value_list

    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
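The Union branch in this last version differs from the earlier ones by chaining the failure from each union member into the final error via __cause__. The following sketch reproduces that chaining pattern in plain Python (no allennlp; a RuntimeError stands in for ConfigurationError), so the resulting traceback shows every attempt, oldest first:

# Sketch only: the exception-chaining pattern used by the Union branch above.
error_chain = None
for attempt in ("as int", "as float"):
    try:
        raise TypeError(f"could not construct the argument {attempt}")
    except TypeError as e:
        e.__cause__ = error_chain  # link this failure to the previous one
        error_chain = e

final = RuntimeError("Failed to construct argument with any union member.")
final.__cause__ = error_chain
# Raising `final` now prints both TypeErrors in the traceback before the final error.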