Example #1
 def merge(preferred_value: Any, fallback_value: Any) -> Any:
     if isinstance(preferred_value, dict) and isinstance(
             fallback_value, dict):
         return with_fallback(preferred_value, fallback_value)
     elif isinstance(preferred_value, dict) and isinstance(
             fallback_value, list):
         # treat preferred_value as a sparse list, where each key is an index to be overridden
         # Copy so that we do not mutate the caller's fallback list in place.
         merged_list = copy.deepcopy(fallback_value)
         for elem_key, preferred_element in preferred_value.items():
             try:
                 index = int(elem_key)
                 merged_list[index] = merge(preferred_element,
                                            fallback_value[index])
             except ValueError:
                 raise ConfigurationError(
                     "could not merge dicts - the preferred dict contains "
                     f"invalid keys (key {elem_key} is not a valid list index)"
                 )
             except IndexError:
                 raise ConfigurationError(
                     "could not merge dicts - the preferred dict contains "
                     f"invalid keys (key {index} is out of bounds)")
         return merged_list
     else:
         return copy.deepcopy(preferred_value)
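
An illustrative sketch of the sparse-list branch above (values are hypothetical; with_fallback is the enclosing function referenced in the first branch):

    fallback = {"layers": [{"dim": 100}, {"dim": 200}, {"dim": 300}]}
    preferred = {"layers": {"1": {"dim": 512}}}
    # with_fallback(preferred, fallback) is expected to override only index 1:
    # {"layers": [{"dim": 100}, {"dim": 512}, {"dim": 300}]}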
Example #2
    def __init__(self,
                 patience: Optional[int] = None,
                 metric_name: str = None,
                 should_decrease: bool = None) -> None:
        self._best_so_far: float = None
        self._patience = patience
        self._epochs_with_no_improvement = 0
        self._is_best_so_far = True
        self.best_epoch_metrics: Dict[str, float] = {}
        self._epoch_number = 0
        self.best_epoch: int = None

        # If the metric name starts with "+", we want it to increase.
        # If the metric name starts with "-", we want it to decrease.
        # We also allow you to not specify a metric name and just set `should_decrease` directly.
        if should_decrease is None and metric_name is None:
            raise ConfigurationError(
                "must specify either `should_decrease` or `metric_name`"
            )
        elif should_decrease is not None and metric_name is not None:
            raise ConfigurationError(
                "must specify either `should_decrease` or `metric_name` (but not both)"
            )
        elif metric_name is not None:
            if metric_name[0] == "-":
                self._should_decrease = True
            elif metric_name[0] == "+":
                self._should_decrease = False
            else:
                raise ConfigurationError("metric_name must start with + or -")
        else:
            self._should_decrease = should_decrease
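
A brief usage sketch of the sign convention above (the class name MetricTracker is assumed from its use later in this file; values are illustrative):

    tracker = MetricTracker(metric_name="-loss")       # lower loss is better
    tracker = MetricTracker(metric_name="+accuracy")   # higher accuracy is better
    tracker = MetricTracker(should_decrease=True)      # same effect, without naming a metric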
Example #3
def unflatten(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a "flattened" dict with compound keys, e.g.
        {"a.b": 0}
    unflatten it:
        {"a": {"b": 0}}
    """
    unflat: Dict[str, Any] = {}

    for compound_key, value in flat_dict.items():
        curr_dict = unflat
        parts = compound_key.split(".")
        for key in parts[:-1]:
            curr_value = curr_dict.get(key)
            if key not in curr_dict:
                curr_dict[key] = {}
                curr_dict = curr_dict[key]
            elif isinstance(curr_value, dict):
                curr_dict = curr_value
            else:
                raise ConfigurationError("flattened dictionary is invalid")
        if not isinstance(curr_dict, dict) or parts[-1] in curr_dict:
            raise ConfigurationError("flattened dictionary is invalid")
        else:
            curr_dict[parts[-1]] = value

    return unflat
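
A minimal usage sketch, assuming the unflatten function above is in scope:

    nested = unflatten({"model.encoder.hidden_size": 512, "model.dropout": 0.1})
    assert nested == {"model": {"encoder": {"hidden_size": 512}, "dropout": 0.1}}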
Example #4
    def __init__(self,
                 beta: float = 1.0,
                 average: str = None,
                 labels: List[int] = None) -> None:
        average_options = (None, 'micro', 'macro')
        if average not in average_options:
            raise ConfigurationError(
                f"`average` has to be one of {average_options}.")
        if beta <= 0:
            raise ConfigurationError(
                "`beta` should be >0 in the F-beta score.")

        self._beta = beta
        self._average = average
        self._labels = labels

        # statistics
        # the total number of true positive instances under each class
        # Shape: (num_classes, )
        self._true_positive_sum: Union[None, torch.Tensor] = None
        # the total number of instances
        # Shape: (num_classes, )
        self._total_sum: Union[None, torch.Tensor] = None
        # the total number of instances under each _predicted_ class,
        # including true positives and false positives
        # Shape: (num_classes, )
        self._pred_sum: Union[None, torch.Tensor] = None
        # the total number of instances under each _true_ class,
        # including true positives and false negatives
        # Shape: (num_classes, )
        self._true_sum: Union[None, torch.Tensor] = None
Example #5
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A one-dimensional tensor of prediction scores of shape (batch_size).
        gold_labels : ``torch.Tensor``, required.
            A one-dimensional label tensor of shape (batch_size), with {1, 0}
            entries for positive and negative class. If it's not binary,
            `positive_label` should be passed in the initialization.
        mask: ``torch.Tensor``, optional (default = None).
            A one-dimensional label tensor of shape (batch_size).
        """

        predictions, gold_labels, mask = self.unwrap_to_tensors(
            predictions, gold_labels, mask)

        # Sanity checks.
        if gold_labels.dim() != 1:
            raise ConfigurationError("gold_labels must be one-dimensional, "
                                     "but found tensor of shape: {}".format(
                                         gold_labels.size()))
        if predictions.dim() != 1:
            raise ConfigurationError("predictions must be one-dimensional, "
                                     "but found tensor of shape: {}".format(
                                         predictions.size()))

        unique_gold_labels = torch.unique(gold_labels)
        if unique_gold_labels.numel() > 2:
            raise ConfigurationError(
                "AUC can be used for binary tasks only. gold_labels has {} unique labels, "
                "expected at maximum 2.".format(unique_gold_labels.numel()))

        gold_labels_is_binary = set(unique_gold_labels.tolist()) <= {0, 1}
        if not gold_labels_is_binary and self._positive_label not in unique_gold_labels:
            raise ConfigurationError(
                "gold_labels should be binary with 0 and 1 or initialized positive_label "
                "{} should be present in gold_labels".format(
                    self._positive_label))

        if mask is None:
            batch_size = gold_labels.shape[0]
            mask = torch.ones(batch_size)
        mask = mask.byte()

        self._all_predictions = torch.cat(
            [self._all_predictions,
             torch.masked_select(predictions, mask).float()],
            dim=0)
        self._all_gold_labels = torch.cat(
            [self._all_gold_labels,
             torch.masked_select(gold_labels, mask).long()],
            dim=0)
Example #6
 def _get_instance_data(self) -> Iterator[Instance]:
     if self._input_file == "-":
         raise ConfigurationError(
             "stdin is not an option when using a DatasetReader.")
     elif self._dataset_reader is None:
         raise ConfigurationError(
             "To generate instances directly, pass a DatasetReader.")
     else:
         yield from self._dataset_reader.read(self._input_file)
Example #7
    def extract_module(self, path: str, freeze: bool = True) -> Module:
        """
        This method can be used to load a module from the pretrained model archive.

        It is also used implicitly in FromParams-based construction, so instead of using standard
        params to construct a module, you can load a pretrained module from the model
        archive directly. For example, instead of using params like {"type": "module_type", ...}, you
        can use the following template::

            {
                "_pretrained": {
                    "archive_file": "../path/to/model.tar.gz",
                    "path": "path.to.module.in.model",
                    "freeze": False
                }
            }

        If you use this feature with FromParams, be aware of the following caveat: a call to
        initializer(self) at the end of the model's initializer can potentially wipe the transferred
        parameters by reinitializing them. This can happen if you have set up an initializer regex
        that also matches the parameters of the transferred module. To safeguard against this, you
        can either update your initializer regex to avoid the conflicting match or add an extra
        "prevent" initializer::

            [
                [".*transferred_module_name.*", "prevent"]
            ]

        Parameters
        ----------
        path : ``str``, required
            Path of target module to be loaded from the model.
            Eg. "_textfield_embedder.token_embedder_tokens"
        freeze : ``bool``, optional (default=True)
            Whether to freeze the module parameters or not.

        """
        modules_dict = {path: module for path, module in self.model.named_modules()}
        module = modules_dict.get(path, None)

        if module is None:
            raise ConfigurationError(
                f"You asked to transfer module at path {path} from "
                f"the model {type(self.model)}. But it's not present."
            )
        if not isinstance(module, Module):
            raise ConfigurationError(
                f"The transferred object from model {type(self.model)} at path "
                f"{path} is not a PyTorch Module."
            )

        for parameter in module.parameters():  # type: ignore
            parameter.requires_grad_(not freeze)
        return module
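
A hedged usage sketch (the `archive` object below is hypothetical; `extract_module` is assumed to live on a loaded archive wrapper, as `self.model` above suggests):

    embedder = archive.extract_module("_textfield_embedder.token_embedder_tokens", freeze=True)
    for parameter in embedder.parameters():
        assert not parameter.requires_grad  # freeze=True disables gradients on the transferred module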
Example #8
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None,
                 end_index: int = sys.maxsize):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, k, sequence_length).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, sequence_length).
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)

        # Some sanity checks.
        if gold_labels.dim() != predictions.dim() - 1:
            raise ConfigurationError("gold_labels must have dimension == predictions.dim() - 1 but "
                                     "found tensor of shape: {}".format(gold_labels.size()))
        if mask is not None and mask.size() != gold_labels.size():
            raise ConfigurationError("mask must have the same size as gold_labels but "
                                     "found tensor of shape: {}".format(mask.size()))

        batch_size = predictions.size()[0]
        correct = 0.0
        for i in range(batch_size):
            beams = predictions[i]
            cur_gold = gold_labels[i]

            if mask is not None:
                masked_gold = cur_gold * mask[i]
            else:
                masked_gold = cur_gold
            cleaned_gold = [x for x in masked_gold if x != 0 and x != end_index]

            retval = 0.
            for word in cleaned_gold:
                stillsearch = True
                for beam in beams:
                    # word is from cleaned gold which doesn't have 0 or
                    # end_index, so we don't need to explicitly remove those
                    # from beam.
                    if stillsearch and (word in beam):
                        retval += 1. / float(len(cleaned_gold))
                        stillsearch = False
            correct += retval

        self.correct_count += correct
        self.total_count += predictions.size()[0]
Example #9
    def pop(self,
            key: str,
            default: Any = DEFAULT,
            keep_as_dict: bool = False) -> Any:
        # pylint: disable=arguments-differ
        """
        Performs the functionality associated with dict.pop(key), along with checking for
        returned dictionaries, replacing them with Params objects with an updated history
        (unless ``keep_as_dict`` is True, in which case we leave them as dictionaries).

        If ``key`` is not present in the dictionary, and no default was specified, we raise a
        ``ConfigurationError``, instead of the typical ``KeyError``.
        """
        if default is self.DEFAULT:
            try:
                value = self.params.pop(key)
            except KeyError:
                raise ConfigurationError(
                    "key \"{}\" is required at location \"{}\"".format(
                        key, self.history))
        else:
            value = self.params.pop(key, default)

        if keep_as_dict or _is_dict_free(value):
            logger.info(self.history + key + " = " +
                        str(value))  # type: ignore
            return value
        else:
            return self._check_is_dict(key, value)
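
A minimal usage sketch, assuming Params wraps a plain dict (as its use elsewhere in this file suggests):

    params = Params({"type": "lstm", "encoder": {"hidden_size": 256}})
    encoder_params = params.pop("encoder")             # nested dict comes back as a Params object
    hidden_size = encoder_params.pop("hidden_size")    # 256
    # params.pop("missing_key") with no default would raise ConfigurationError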
Example #10
    def print_statistics(self) -> None:
        # Make sure the dataset has been indexed first
        sequence_field_lengths: Dict[str, List] = defaultdict(list)
        for instance in self.instances:
            if not instance.indexed:
                raise ConfigurationError(
                    "Instances must be indexed with vocabulary "
                    "before asking to print dataset statistics."
                )
            for field, field_padding_lengths in instance.get_padding_lengths().items():
                for key, value in field_padding_lengths.items():
                    sequence_field_lengths[f"{field}.{key}"].append(value)

        print("\n\n----Dataset Statistics----\n")
        for name, lengths in sequence_field_lengths.items():
            print(f"Statistics for {name}:")
            print(
                f"\tLengths: Mean: {numpy.mean(lengths)}, Standard Dev: {numpy.std(lengths)}, "
                f"Max: {numpy.max(lengths)}, Min: {numpy.min(lengths)}"
            )

        print("\n10 Random instances: ")
        for i in list(numpy.random.randint(len(self.instances), size=10)):
            print(f"Instance {i}:")
            print(f"\t{self.instances[i]}")
Example #11
    def pop_choice(self,
                   key: str,
                   choices: List[Any],
                   default_to_first_choice: bool = False) -> Any:
        """
        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
        the given choices. Note that this `pops` the key from params, modifying the dictionary,
        consistent with how parameters are processed in this codebase.

        Parameters
        ----------
        key: str
            Key to get the value from in the param dictionary
        choices: List[Any]
            A list of valid options for values corresponding to ``key``.  For example, if you're
            specifying the type of encoder to use for some part of your model, the choices might be
            the list of encoder classes we know about and can instantiate.  If the value we find in
            the param dictionary is not in ``choices``, we raise a ``ConfigurationError``, because
            the user specified an invalid value in their parameter file.
        default_to_first_choice: bool, optional (default=False)
            If this is ``True``, we allow the ``key`` to not be present in the parameter
            dictionary.  If the key is not present, we return the first choice in the
            ``choices`` list as the value.  If this is ``False``, we raise a
            ``ConfigurationError``, because specifying the ``key`` is required (e.g., you `have` to
            specify your model class when running an experiment, but you can feel free to use
            default settings for encoders if you want).
        """
        default = choices[0] if default_to_first_choice else self.DEFAULT
        value = self.pop(key, default)
        if value not in choices:
            key_str = self.history + key
            message = '%s not in acceptable choices for %s: %s' % (
                value, key_str, str(choices))
            raise ConfigurationError(message)
        return value
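
A minimal usage sketch (again assuming Params wraps a plain dict):

    params = Params({"activation": "relu"})
    activation = params.pop_choice("activation", ["relu", "tanh", "sigmoid"])
    # A value outside the list, e.g. "gelu" here, would raise ConfigurationError;
    # with default_to_first_choice=True a missing key would fall back to "relu".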
Example #12
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, k, sequence_length).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, sequence_length).
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(
            predictions, gold_labels, mask)

        # Some sanity checks.
        if gold_labels.dim() != predictions.dim() - 1:
            raise ConfigurationError(
                "gold_labels must have dimension == predictions.dim() - 1 but "
                "found tensor of shape: {}".format(gold_labels.size()))
        if mask is not None and mask.size() != gold_labels.size():
            raise ConfigurationError(
                "mask must have the same size as gold_labels but "
                "found tensor of shape: {}".format(mask.size()))

        k = predictions.size()[1]
        expanded_size = list(gold_labels.size())
        expanded_size.insert(1, k)
        expanded_gold = gold_labels.unsqueeze(1).expand(expanded_size)

        if mask is not None:
            expanded_mask = mask.unsqueeze(1).expand(expanded_size)
            masked_gold = expanded_mask * expanded_gold
            masked_predictions = expanded_mask * predictions
        else:
            masked_gold = expanded_gold
            masked_predictions = predictions

        eqs = masked_gold.eq(masked_predictions)
        matches_per_question = eqs.min(dim=2)[0]
        some_match = matches_per_question.max(dim=1)[0]
        correct = some_match.sum().item()

        self.total_count += predictions.size()[0]
        self.correct_count += correct
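
A tiny worked example of the expand/min/max logic above (shapes and values are hypothetical):

    import torch
    predictions = torch.tensor([[[1, 2], [3, 4]]])      # (batch_size=1, k=2, sequence_length=2)
    gold_labels = torch.tensor([[3, 4]])                 # (1, 2)
    expanded_gold = gold_labels.unsqueeze(1).expand(1, 2, 2)
    eqs = expanded_gold.eq(predictions)
    matches_per_question = eqs.min(dim=2)[0]             # did a beam match gold at every position?
    assert matches_per_question.max(dim=1)[0].item() == 1   # at least one of the k beams matched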
Example #13
    def __call__(self, predicted_trees: List[Tree],
                 gold_trees: List[Tree]) -> None:  # type: ignore
        """
        Parameters
        ----------
        predicted_trees : ``List[Tree]``
            A list of predicted NLTK Trees to compute score for.
        gold_trees : ``List[Tree]``
            A list of gold NLTK Trees to use as a reference.
        """
        if not os.path.exists(self._evalb_program_path):
            logger.warning(
                f"EVALB not found at {self._evalb_program_path}.  Attempting to compile it."
            )
            EvalbBracketingScorer.compile_evalb(self._evalb_directory_path)

            # If EVALB executable still doesn't exist, raise an error.
            if not os.path.exists(self._evalb_program_path):
                compile_command = (
                    f"python -c 'from reclib.training.metrics import EvalbBracketingScorer; "
                    f"EvalbBracketingScorer.compile_evalb(\"{self._evalb_directory_path}\")'"
                )
                raise ConfigurationError(
                    f"EVALB still not found at {self._evalb_program_path}. "
                    "You must compile the EVALB scorer before using it. "
                    f"Run 'make' in the '{self._evalb_directory_path}' directory "
                    f"or run: {compile_command}")
        tempdir = tempfile.mkdtemp()
        gold_path = os.path.join(tempdir, "gold.txt")
        predicted_path = os.path.join(tempdir, "predicted.txt")
        with open(gold_path, "w") as gold_file:
            for tree in gold_trees:
                gold_file.write(f"{tree.pformat(margin=1000000)}\n")

        with open(predicted_path, "w") as predicted_file:
            for tree in predicted_trees:
                predicted_file.write(f"{tree.pformat(margin=1000000)}\n")

        command = [
            self._evalb_program_path, "-p", self._evalb_param_path, gold_path,
            predicted_path
        ]
        completed_process = subprocess.run(command,
                                           stdout=subprocess.PIPE,
                                           universal_newlines=True,
                                           check=True)

        for line in completed_process.stdout.split("\n"):
            stripped = line.strip().split()
            if len(stripped) == 12 and stripped != self._header_line:
                # This line contains results for a single tree.
                numeric_line = [float(x) for x in stripped]
                self._correct_predicted_brackets += numeric_line[5]
                self._gold_brackets += numeric_line[6]
                self._predicted_brackets += numeric_line[7]

        shutil.rmtree(tempdir)
Example #14
 def _check_types(self) -> None:
     """
     Check that all the instances have the same types.
     """
     all_instance_fields_and_types: List[Dict[str, str]] = [
         {k: v.__class__.__name__ for k, v in x.fields.items()} for x in self.instances
     ]
     # Check all the field names and Field types are the same for every instance.
     if not all([all_instance_fields_and_types[0] == x for x in all_instance_fields_and_types]):
         raise ConfigurationError("You cannot construct a Batch with non-homogeneous Instances.")
Example #15
 def assert_empty(self, class_name: str):
     """
     Raises a ``ConfigurationError`` if ``self.params`` is not empty.  We take ``class_name`` as
     an argument so that the error message gives some idea of where an error happened, if there
     was one.  ``class_name`` should be the name of the `calling` class, the one that got extra
     parameters (if there are any).
     """
     if self.params:
         raise ConfigurationError(
             "Extra parameters passed to {}: {}".format(
                 class_name, self.params))
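
A short sketch of the intended flow, assuming this method lives on the same Params object as the other snippets in this file:

    params = Params({"lr": 0.001})
    learning_rate = params.pop("lr")
    params.assert_empty("MyTrainer")  # fine: every key was consumed
    # If "lr" had not been popped, this would raise ConfigurationError listing the extras.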
Example #16
    def list_available(cls) -> List[str]:
        """List default first if it exists"""
        keys = list(Registrable._registry[cls].keys())
        default = cls.default_implementation

        if default is None:
            return keys
        elif default not in keys:
            message = "Default implementation %s is not registered" % default
            raise ConfigurationError(message)
        else:
            return [default] + [k for k in keys if k != default]
Example #17
def takes_arg(obj, arg: str) -> bool:
    """
    Checks whether the provided obj takes a certain arg.
    If it's a class, we're really checking whether its constructor does.
    If it's a function or method, we're checking the object itself.
    Otherwise, we raise an error.
    """
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return arg in signature.parameters
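
A minimal sketch, assuming the takes_arg function above is in scope:

    class Linear:
        def __init__(self, in_features: int, out_features: int) -> None:
            self.in_features = in_features
            self.out_features = out_features

    assert takes_arg(Linear, "in_features")   # classes are checked via their __init__
    assert not takes_arg(Linear, "bias")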
Example #18
 def add_subclass_to_registry(subclass: Type[T]):
     # Add to registry, raise an error if key has already been used.
     if name in registry:
         if exist_ok:
             message = (
                 f"{name} has already been registered as {registry[name].__name__}, but "
                 f"exist_ok=True, so overwriting with {cls.__name__}")
             logger.info(message)
         else:
             message = (
                 f"Cannot register {name} as {cls.__name__}; "
                 f"name already in use for {registry[name].__name__}")
             raise ConfigurationError(message)
     registry[name] = subclass
     return subclass
Example #19
def takes_kwargs(obj) -> bool:
    """
    Checks whether a provided object takes in any positional arguments.
    Similar to takes_arg, we do this for both the __init__ function of
    the class or a function / method
    Otherwise, we raise an error
    """
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return any(p.kind == inspect.Parameter.VAR_KEYWORD  # type: ignore
               for p in signature.parameters.values())
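
A companion sketch for takes_kwargs, under the same assumption that the function above is in scope:

    def strict(a, b):
        return a + b

    def flexible(a, **extras):
        return a

    assert not takes_kwargs(strict)
    assert takes_kwargs(flexible)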
Example #20
 def get(self, key: str, default: Any = DEFAULT):
     """
     Performs the functionality associated with dict.get(key) but also checks for returned
     dicts and returns a Params object in their place with an updated history.
     """
     if default is self.DEFAULT:
         try:
             # Use item access (not .get) so a genuinely missing key raises KeyError here.
             value = self.params[key]
         except KeyError:
             raise ConfigurationError(
                 "key \"{}\" is required at location \"{}\"".format(
                     key, self.history))
     else:
         value = self.params.get(key, default)
     return self._check_is_dict(key, value)
Example #21
    def __init__(self,
                 patience: int = None,
                 validation_metric: str = "-loss") -> None:
        if patience is not None and (not isinstance(patience, int)
                                     or patience <= 0):
            raise ConfigurationError(
                f"patience must be a positive integer, but got {patience}. "
                f"To disable early stopping, don't specify it.")

        self.patience = patience
        self.validation_metric = validation_metric[1:]
        self.metric_tracker = MetricTracker(patience, validation_metric)
        self.starting_epoch = 0

        self.peak_cpu_usage = 0.0
        # Track pairs (gpu_id, memory usage)
        self.gpu_usage: List[Tuple[int, int]] = []
Example #22
    def restore_checkpoint(self, trainer: 'CallbackTrainer'):
        """
        Restores the model and training state from the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from model parameters. This function should only be used to continue training -
        if you wish to load a model for inference/load parts of a model into a new
        computation graph, you should use the native Pytorch functions:
        ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``

        If ``self._serialization_dir`` does not exist or does not contain any checkpointed
        weights, this will do nothing.
        """
        try:
            model_state, training_state = self.checkpointer.restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  "
                "Did you mean to output to a different serialization directory "
                "or delete the existing serialization directory?")

        if not training_state:
            # No checkpoint to restore, start at 0
            trainer.epoch_number = 0
            return

        trainer.model.load_state_dict(model_state)

        # Restore state_dict attrs
        for attr in self.state_dict_attrs:
            state_attr = getattr(trainer, attr)
            if state_attr is not None:
                state_attr.load_state_dict(training_state[attr])

        # Restore other attrs
        for attr in self.other_attrs:
            setattr(trainer, attr, training_state[attr])

        # Restore callback attrs
        for callback in trainer.handler.callbacks():
            callback.restore_training_state(training_state)

        if isinstance(training_state["epoch"], int):
            trainer.epoch_number = training_state["epoch"] + 1
        else:
            trainer.epoch_number = int(
                training_state["epoch"].split('.')[0]) + 1
Example #23
    def __init__(self,
                 serialization_dir: str,
                 cuda_device: Union[int, List] = -1) -> None:
        check_for_gpu(cuda_device)

        self._serialization_dir = serialization_dir

        # Configure GPUs:
        if not isinstance(cuda_device, int) and not isinstance(
                cuda_device, list):
            raise ConfigurationError(
                "Expected an int or list for cuda_device, got {}".format(
                    cuda_device))

        if isinstance(cuda_device, list):
            self._multiple_gpu = True
            self._cuda_devices = cuda_device
        else:
            self._multiple_gpu = False
            self._cuda_devices = [cuda_device]
Example #24
 def by_name(cls: Type[T], name: str) -> Type[T]:
     logger.info(f"instantiating registered subclass {name} of {cls}")
     if name not in Registrable._registry[cls]:
         raise ConfigurationError("%s is not a registered name for %s" %
                                  (name, cls.__name__))
     return Registrable._registry[cls].get(name)
Example #25
    def from_params(params: Params,
                    serialization_dir: str,
                    recover: bool = False,
                    cache_directory: str = None,
                    cache_prefix: str = None) -> 'TrainerPieces':
        all_datasets = training_util.datasets_from_params(
            params, cache_directory, cache_prefix)
        datasets_for_vocab_creation = set(
            params.pop("datasets_for_vocab_creation", all_datasets))

        for dataset in datasets_for_vocab_creation:
            if dataset not in all_datasets:
                raise ConfigurationError(
                    f"invalid dataset {dataset} in 'datasets_for_vocab_creation'")

        logger.info(
            "From dataset instances, %s will be considered for vocabulary creation.",
            ", ".join(datasets_for_vocab_creation))

        if recover and os.path.exists(
                os.path.join(serialization_dir, "vocabulary")):
            vocab = Vocabulary.from_files(
                os.path.join(serialization_dir, "vocabulary"))
            params.pop("vocabulary", {})
        else:
            vocab = Vocabulary.from_params(
                params.pop("vocabulary", {}),
                (instance
                 for key, dataset in all_datasets.items()
                 if key in datasets_for_vocab_creation
                 for instance in dataset))

        model = Model.from_params(vocab=vocab, params=params.pop('model'))

        # If vocab extension is ON for training, embedding extension should also be
        # done. If vocab and embeddings are already in sync, it would be a no-op.
        model.extend_embedder_vocab()

        # Initializing the model can have the side effect of expanding the vocabulary
        vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

        iterator = DataIterator.from_params(params.pop("iterator"))
        iterator.index_with(model.vocab)
        validation_iterator_params = params.pop("validation_iterator", None)
        if validation_iterator_params:
            validation_iterator = DataIterator.from_params(
                validation_iterator_params)
            validation_iterator.index_with(model.vocab)
        else:
            validation_iterator = None

        train_data = all_datasets['train']
        validation_data = all_datasets.get('validation')
        test_data = all_datasets.get('test')

        trainer_params = params.pop("trainer")
        no_grad_regexes = trainer_params.pop("no_grad", ())
        for name, parameter in model.named_parameters():
            if any(re.search(regex, name) for regex in no_grad_regexes):
                parameter.requires_grad_(False)

        frozen_parameter_names, tunable_parameter_names = \
            get_frozen_and_tunable_parameter_names(model)
        logger.info("Following parameters are Frozen  (without gradient):")
        for name in frozen_parameter_names:
            logger.info(name)
        logger.info("Following parameters are Tunable (with gradient):")
        for name in tunable_parameter_names:
            logger.info(name)

        return TrainerPieces(model, iterator, train_data, validation_data,
                             test_data, validation_iterator, trainer_params)
Example #26
    def from_params(
            cls,  # type: ignore
            model: Model,
            serialization_dir: str,
            iterator: DataIterator,
            train_data: Iterable[Instance],
            validation_data: Optional[Iterable[Instance]],
            params: Params,
            validation_iterator: DataIterator = None) -> 'Trainer':
        # pylint: disable=arguments-differ
        patience = params.pop_int("patience", None)
        validation_metric = params.pop("validation_metric", "-loss")
        shuffle = params.pop_bool("shuffle", True)
        num_epochs = params.pop_int("num_epochs", 20)
        cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
        grad_norm = params.pop_float("grad_norm", None)
        grad_clipping = params.pop_float("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)
        momentum_scheduler_params = params.pop("momentum_scheduler", None)

        if isinstance(cuda_device, list):
            model_device = cuda_device[0]
        else:
            model_device = cuda_device
        if model_device >= 0:
            # Moving model to GPU here so that the optimizer state gets constructed on
            # the right device.
            model = model.cuda(model_device)

        parameters = [[n, p] for n, p in model.named_parameters()
                      if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
        if "moving_average" in params:
            moving_average = MovingAverage.from_params(
                params.pop("moving_average"), parameters=parameters)
        else:
            moving_average = None

        if lr_scheduler_params:
            lr_scheduler = LearningRateScheduler.from_params(
                optimizer, lr_scheduler_params)
        else:
            lr_scheduler = None
        if momentum_scheduler_params:
            momentum_scheduler = MomentumScheduler.from_params(
                optimizer, momentum_scheduler_params)
        else:
            momentum_scheduler = None

        if 'checkpointer' in params:
            if 'keep_serialized_model_every_num_seconds' in params or \
                    'num_serialized_models_to_keep' in params:
                raise ConfigurationError(
                    "Checkpointer may be initialized either from the 'checkpointer' key or from the "
                    "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
                    " but the passed config uses both methods.")
            checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
        else:
            num_serialized_models_to_keep = params.pop_int(
                "num_serialized_models_to_keep", 20)
            keep_serialized_model_every_num_seconds = params.pop_int(
                "keep_serialized_model_every_num_seconds", None)
            checkpointer = Checkpointer(
                serialization_dir=serialization_dir,
                num_serialized_models_to_keep=num_serialized_models_to_keep,
                keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds)
        model_save_interval = params.pop_float("model_save_interval", None)
        summary_interval = params.pop_int("summary_interval", 100)
        histogram_interval = params.pop_int("histogram_interval", None)
        should_log_parameter_statistics = params.pop_bool(
            "should_log_parameter_statistics", True)
        should_log_learning_rate = params.pop_bool("should_log_learning_rate",
                                                   False)
        log_batch_size_period = params.pop_int("log_batch_size_period", None)

        params.assert_empty(cls.__name__)
        return cls(
            model,
            optimizer,
            iterator,
            train_data,
            validation_data,
            patience=patience,
            validation_metric=validation_metric,
            validation_iterator=validation_iterator,
            shuffle=shuffle,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            cuda_device=cuda_device,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=lr_scheduler,
            momentum_scheduler=momentum_scheduler,
            checkpointer=checkpointer,
            model_save_interval=model_save_interval,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
            log_batch_size_period=log_batch_size_period,
            moving_average=moving_average)
Example #27
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?")

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(
                    metrics.get('peak_cpu_memory_MB', 0),
                    train_metrics['cpu_memory_MB'])
            for key, value in train_metrics.items():
                if key.startswith('gpu_'):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0),
                                                 value)

            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = training_util.get_metrics(self.model,
                                                            val_loss,
                                                            num_batches,
                                                            reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[
                        self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            self._tensorboard.log_metrics(
                train_metrics,
                val_metrics=val_metrics,
                log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(
                datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(
                    os.path.join(self._serialization_dir,
                                 f'metrics_epoch_{epoch}.json'), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric,
                                                   epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s",
                        datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                                           ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
Example #28
    def __init__(
            self,
            model: Model,
            optimizer: torch.optim.Optimizer,
            iterator: DataIterator,
            train_dataset: Iterable[Instance],
            validation_dataset: Optional[Iterable[Instance]] = None,
            patience: Optional[int] = None,
            validation_metric: str = "-loss",
            validation_iterator: DataIterator = None,
            shuffle: bool = True,
            num_epochs: int = 20,
            serialization_dir: Optional[str] = None,
            num_serialized_models_to_keep: int = 20,
            keep_serialized_model_every_num_seconds: int = None,
            checkpointer: Checkpointer = None,
            model_save_interval: float = None,
            cuda_device: Union[int, List] = -1,
            grad_norm: Optional[float] = None,
            grad_clipping: Optional[float] = None,
            learning_rate_scheduler: Optional[LearningRateScheduler] = None,
            momentum_scheduler: Optional[MomentumScheduler] = None,
            summary_interval: int = 100,
            histogram_interval: int = None,
            should_log_parameter_statistics: bool = True,
            should_log_learning_rate: bool = False,
            log_batch_size_period: Optional[int] = None,
            moving_average: Optional[MovingAverage] = None) -> None:
        """
        A trainer for doing supervised learning. It just takes a labeled dataset
        and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
        for your model over some fixed number of epochs. You can also pass in a validation
        dataset and enable early stopping. There are many other bells and whistles as well.

        Parameters
        ----------
        model : ``Model``, required.
            A reclib model to be optimized. Pytorch Modules can also be optimized if
            their ``forward`` method returns a dictionary with a "loss" key, containing a
            scalar tensor representing the loss function to be optimized.

            If you are training your model using GPUs, your model should already be
            on the correct device. (If you use `Trainer.from_params` this will be
            handled for you.)
        optimizer : ``torch.optim.Optimizer``, required.
            An instance of a Pytorch Optimizer, instantiated with the parameters of the
            model to be optimized.
        iterator : ``DataIterator``, required.
            A method for iterating over a ``Dataset``, yielding padded indexed batches.
        train_dataset : ``Dataset``, required.
            A ``Dataset`` to train on. The dataset should have already been indexed.
        validation_dataset : ``Dataset``, optional, (default = None).
            A ``Dataset`` to evaluate on. The dataset should have already been indexed.
        patience : Optional[int] > 0, optional (default=None)
            Number of epochs to be patient before early stopping: the training is stopped
            after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
            If None, early stopping is disabled.
        validation_metric : str, optional (default="-loss")
            Validation metric to measure for whether to stop training using patience
            and whether to serialize an ``is_best`` model each epoch. The metric name
            must be prepended with either "+" or "-", which specifies whether the metric
            is an increasing or decreasing function.
        validation_iterator : ``DataIterator``, optional (default=None)
            An iterator to use for the validation set.  If ``None``, then
            use the training `iterator`.
        shuffle: ``bool``, optional (default=True)
            Whether to shuffle the instances in the iterator or not.
        num_epochs : int, optional (default = 20)
            Number of training epochs.
        serialization_dir : str, optional (default=None)
            Path to directory for saving and loading model files. Models will not be saved if
            this parameter is not passed.
        num_serialized_models_to_keep : ``int``, optional (default=20)
            Number of previous model checkpoints to retain.  Default is to keep 20 checkpoints.
            A value of None or -1 means all checkpoints will be kept.
        keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
            If num_serialized_models_to_keep is not None, then occasionally it's useful to
            save models at a given interval in addition to the last num_serialized_models_to_keep.
            To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
            between permanently saved checkpoints.  Note that this option is only used if
            num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
        checkpointer : ``Checkpointer``, optional (default=None)
            An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
            the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
            not be specified. The caller is responsible for initializing the checkpointer so that it is
            consistent with serialization_dir.
        model_save_interval : ``float``, optional (default=None)
            If provided, then serialize models every ``model_save_interval``
            seconds within single epochs.  In all cases, models are also saved
            at the end of every epoch if ``serialization_dir`` is provided.
        cuda_device : ``Union[int, List[int]]``, optional (default = -1)
            An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used.
        grad_norm : ``float``, optional, (default = None).
            If provided, gradient norms will be rescaled to have a maximum of this value.
        grad_clipping : ``float``, optional (default = ``None``).
            If provided, gradients will be clipped `during the backward pass` to have an (absolute)
            maximum of this value.  If you are getting ``NaNs`` in your gradients during training
            that are not solved by using ``grad_norm``, you may need this.
        learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None)
            If specified, the learning rate will be decayed with respect to
            this schedule at the end of each epoch (or batch, if the scheduler implements
            the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`,
            this will use the ``validation_metric`` provided to determine if learning has plateaued.
            To support updating the learning rate on every batch, this can optionally implement
            ``step_batch(batch_num_total)`` which updates the learning rate given the batch number.
        momentum_scheduler : ``MomentumScheduler``, optional (default = None)
            If specified, the momentum will be updated at the end of each batch or epoch
            according to the schedule.
        summary_interval: ``int``, optional, (default = 100)
            Number of batches between logging scalars to tensorboard
        histogram_interval : ``int``, optional, (default = ``None``)
            If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
            When this parameter is specified, the following additional logging is enabled:
                * Histograms of model parameters
                * The ratio of parameter update norm to parameter norm
                * Histogram of layer activations
            We log histograms of the parameters returned by
            ``model.get_parameters_for_histogram_tensorboard_logging``.
            The layer activations are logged for any modules in the ``Model`` that have
            the attribute ``should_log_activations`` set to ``True``.  Logging
            histograms requires a number of GPU-CPU copies during training and is typically
            slow, so we recommend logging histograms relatively infrequently.
            Note: only Modules that return tensors, tuples of tensors or dicts
            with tensors as values currently support activation logging.
        should_log_parameter_statistics : ``bool``, optional, (default = True)
            Whether to send parameter statistics (mean and standard deviation
            of parameters and gradients) to tensorboard.
        should_log_learning_rate : ``bool``, optional, (default = False)
            Whether to send parameter specific learning rate to tensorboard.
        log_batch_size_period : ``int``, optional, (default = ``None``)
            If defined, how often to log the average batch size.
        moving_average: ``MovingAverage``, optional, (default = None)
            If provided, we will maintain moving averages for all parameters. During training, we
            employ a shadow variable for each parameter, which maintains the moving average. During
            evaluation, we backup the original parameters and assign the moving averages to corresponding
            parameters. Be careful that when saving the checkpoint, we will save the moving averages of
            parameters. This is necessary because we want the saved model to perform as well as the validated
            model if we load it later. But this may cause problems if you restart the training from checkpoint.
        """
        super().__init__(serialization_dir, cuda_device)

        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.model = model

        self.iterator = iterator
        self._validation_iterator = validation_iterator
        self.shuffle = shuffle
        self.optimizer = optimizer
        self.train_data = train_dataset
        self._validation_data = validation_dataset

        if patience is None:  # no early stopping
            if validation_dataset:
                logger.warning(
                    'You provided a validation dataset but patience was set to None, '
                    'meaning that early stopping is disabled')
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                'or None (if you want to disable early stopping)'.format(
                    patience))

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs

        if checkpointer is not None:
            # We can't easily check if these parameters were passed in, so check against their default values.
            # We don't check against serialization_dir since it is also used by the parent class.
            if num_serialized_models_to_keep != 20 or \
                    keep_serialized_model_every_num_seconds is not None:
                raise ConfigurationError(
                    "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
                    "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
                )
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(
                serialization_dir, keep_serialized_model_every_num_seconds,
                num_serialized_models_to_keep)

        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping

        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        # We keep the total batch number as an instance variable because it
        # is used inside a closure for the hook which logs activations in
        # ``_enable_activation_logging``.
        self._batch_num_total = 0

        self._tensorboard = TensorboardWriter(
            get_batch_num_total=lambda: self._batch_num_total,
            serialization_dir=serialization_dir,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate)

        self._log_batch_size_period = log_batch_size_period

        self._last_log = 0.0  # time of last logging

        # Enable activation logging.
        if histogram_interval is not None:
            self._tensorboard.enable_activation_logging(self.model)
Example #29
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None,
                 prediction_map: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, sequence_length, num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class labels of shape (batch_size, sequence_length). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        prediction_map: ``torch.Tensor``, optional (default = None).
            A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions
            to the indices of the label vocabulary. If provided, the output label at each timestep will be
            ``vocabulary.get_index_to_token_vocabulary(prediction_map[batch, argmax(predictions[batch, t])])``,
            rather than simply ``vocabulary.get_index_to_token_vocabulary(argmax(predictions[batch, t]))``.
            This is useful in cases where each Instance in the dataset is associated with a different possible
            subset of labels from a large label-space (e.g. FrameNet, where each frame has a different set of
            possible roles associated with it).
        """
        if mask is None:
            mask = torch.ones_like(gold_labels)

        predictions, gold_labels, mask, prediction_map = self.unwrap_to_tensors(
            predictions, gold_labels, mask, prediction_map)

        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to SpanBasedF1Measure contains an "
                "id >= {}, the number of classes.".format(num_classes))

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        argmax_predictions = predictions.max(-1)[1]

        if prediction_map is not None:
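            # prediction_map has shape (batch_size, num_classes) and maps each
            # instance's local class index to an index in the label vocabulary,
            # so gathering along dim 1 translates both the argmax predictions
            # and the gold labels into the shared vocabulary label space.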
            argmax_predictions = torch.gather(prediction_map, 1,
                                              argmax_predictions)
            gold_labels = torch.gather(prediction_map, 1, gold_labels.long())

        argmax_predictions = argmax_predictions.float()

        # Iterate over timesteps in batch.
        batch_size = gold_labels.size(0)
        for i in range(batch_size):
            sequence_prediction = argmax_predictions[i, :]
            sequence_gold_label = gold_labels[i, :]
            length = sequence_lengths[i]

            if length == 0:
                # It is possible to call this metric with sequences which are
                # completely padded. These contribute nothing, so we skip these rows.
                continue

            predicted_string_labels = [
                self._label_vocabulary[label_id]
                for label_id in sequence_prediction[:length].tolist()
            ]
            gold_string_labels = [
                self._label_vocabulary[label_id]
                for label_id in sequence_gold_label[:length].tolist()
            ]

            tags_to_spans_function = None
            # `label_encoding` is None and a custom `tags_to_spans_function` was provided, so use it.
            if self._label_encoding is None and self._tags_to_spans_function:
                tags_to_spans_function = self._tags_to_spans_function
            # Search by `label_encoding`.
            elif self._label_encoding == "BIO":
                tags_to_spans_function = bio_tags_to_spans
            elif self._label_encoding == "IOB1":
                tags_to_spans_function = iob1_tags_to_spans
            elif self._label_encoding == "BIOUL":
                tags_to_spans_function = bioul_tags_to_spans
            elif self._label_encoding == "BMES":
                tags_to_spans_function = bmes_tags_to_spans

            predicted_spans = tags_to_spans_function(predicted_string_labels,
                                                     self._ignore_classes)
            gold_spans = tags_to_spans_function(gold_string_labels,
                                                self._ignore_classes)

            predicted_spans = self._handle_continued_spans(predicted_spans)
            gold_spans = self._handle_continued_spans(gold_spans)

            for span in predicted_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            # These spans weren't predicted.
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
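
The per-label counts accumulated above are what a ``get_metric``-style method turns into precision, recall and F1. A minimal sketch of those standard formulas (not the metric's actual implementation; the counts below are made up):

def precision_recall_f1(true_positives: int, false_positives: int,
                        false_negatives: int):
    # A small epsilon guards against division by zero for labels with no spans.
    precision = true_positives / (true_positives + false_positives + 1e-13)
    recall = true_positives / (true_positives + false_negatives + 1e-13)
    f1 = 2.0 * precision * recall / (precision + recall + 1e-13)
    return precision, recall, f1

# e.g. 8 correctly predicted spans, 2 spurious predictions, 4 missed gold spans
print(precision_recall_f1(8, 2, 4))  # roughly (0.8, 0.667, 0.727)
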
    def __init__(
        self,
        vocabulary: Vocabulary,
        tag_namespace: str = "tags",
        ignore_classes: Optional[List[str]] = None,
        label_encoding: Optional[str] = "BIO",
        tags_to_spans_function: Optional[TAGS_TO_SPANS_FUNCTION_TYPE] = None
    ) -> None:
        """
        Parameters
        ----------
        vocabulary : ``Vocabulary``, required.
            A vocabulary containing the tag namespace.
        tag_namespace : str, optional (default = "tags")
            The vocabulary namespace from which the tag labels are read.
            This metric assumes that a BIO-style format is used in which the
            labels are of the format: ["B-LABEL", "I-LABEL"].
        ignore_classes : List[str], optional.
            Span labels which will be ignored when computing span metrics.
            A "span label" is the part that comes after the BIO prefix, so it
            would be "ARG1" for the tag "B-ARG1". For example, by passing
            ``ignore_classes=["V"]``, the "V" span at indices (2, 3) in the
            following sequence would not be considered when computing
            precision, recall and F1:

            ["O", "O", "B-V", "I-V", "B-ARG1", "I-ARG1"]

            This is helpful, for instance, to avoid computing metrics for "V"
            spans in a BIO tagging scheme, which are typically not included.
        label_encoding : ``str``, optional (default = "BIO")
            The encoding used to specify label span endpoints in the sequence.
            Valid options are "BIO", "IOB1", "BIOUL" or "BMES".
        tags_to_spans_function: ``Callable``, optional (default = ``None``)
            If ``label_encoding`` is ``None``, ``tags_to_spans_function`` will be
            used to generate spans.
        """
        if label_encoding and tags_to_spans_function:
            raise ConfigurationError(
                'Both label_encoding and tags_to_spans_function are provided. '
                'Set "label_encoding=None" explicitly to enable tags_to_spans_function.'
            )
        if label_encoding:
            if label_encoding not in ["BIO", "IOB1", "BIOUL", "BMES"]:
                raise ConfigurationError(
                    "Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL', 'BMES'."
                )
        elif tags_to_spans_function is None:
            raise ConfigurationError(
                'At least one of `label_encoding` or `tags_to_spans_function` must be provided.'
            )

        self._label_encoding = label_encoding
        self._tags_to_spans_function = tags_to_spans_function
        self._label_vocabulary = vocabulary.get_index_to_token_vocabulary(
            tag_namespace)
        self._ignore_classes: List[str] = ignore_classes or []

        # These will hold per label span counts.
        self._true_positives: Dict[str, int] = defaultdict(int)
        self._false_positives: Dict[str, int] = defaultdict(int)
        self._false_negatives: Dict[str, int] = defaultdict(int)
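
To make the span representation concrete, here is an illustrative toy converter (not the library's ``bio_tags_to_spans``; a simplified sketch that ignores malformed "I-" tags) turning a BIO tag sequence into the (label, (start, end)) spans that the metric compares between predictions and gold labels:

from typing import List, Tuple

def toy_bio_to_spans(tags: List[str]) -> List[Tuple[str, Tuple[int, int]]]:
    spans = []
    start, label = None, None
    for i, tag in enumerate(tags + ["O"]):  # trailing "O" closes an open span
        if tag == "O" or tag.startswith("B-"):
            if label is not None:
                spans.append((label, (start, i - 1)))
                start, label = None, None
            if tag.startswith("B-"):
                start, label = i, tag[2:]
    return spans

print(toy_bio_to_spans(["O", "O", "B-V", "I-V", "B-ARG1", "I-ARG1"]))
# [('V', (2, 3)), ('ARG1', (4, 5))]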