class PnetTagger(Model):
    """
    The ``PnetTagger`` implements the tagger described in the paper
    "Few-shot classification in Named Entity Recognition task".

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the tokens ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder that we will use in between embedding tokens and predicting output tags.
    label_namespace : ``str``, optional (default=``labels``)
        This is needed to compute the SpanBasedF1Measure metric.
        Unless you did something unusual, the default value should be what you want.
    dropout : ``float``, optional (default=``None``)
        Dropout probability applied to the embedded and encoded text.
    constraint_type : ``str``, optional (default=``None``)
        If provided, the CRF will be constrained at decoding time to
        produce valid labels based on the specified type (e.g. "BIO", or "BIOUL").
    include_start_end_transitions : ``bool``, optional (default=``True``)
        Whether to include start and end transition parameters in the CRF.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    cuda_device : ``int``, optional (default=``-1``)
        If non-negative, the embedder, encoder and projection layer are moved to this GPU.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        label_namespace: str = "labels",
        constraint_type: str = None,
        include_start_end_transitions: bool = True,
        dropout: float = None,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        cuda_device: int = -1,
    ) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        # This is our trainable parameter that is used as the logit of the 'O' tag.
        self.bias_outside = torch.nn.Parameter(torch.zeros(1) - 4.0, requires_grad=True)
        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        # We also train per-class scales in the embedding space, assuming that they may differ.
        self.scale_classes = torch.nn.Parameter(torch.ones(self.num_tags), requires_grad=True)
        self.encoder = encoder
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None
        self.last_layer = TimeDistributed(Linear(self.encoder.get_output_dim(), 64))

        if constraint_type is not None:
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(constraint_type, labels)
        else:
            constraints = None

        self.crf = ConditionalRandomField(
            self.num_tags,
            constraints,
            include_start_end_transitions=include_start_end_transitions,
        )
        self.loss = torch.nn.CrossEntropyLoss()

        self.cuda_device = cuda_device
        if self.cuda_device >= 0:
            self.text_field_embedder = self.text_field_embedder.cuda(self.cuda_device)
            self.encoder = self.encoder.cuda(self.cuda_device)
            self.last_layer = self.last_layer.cuda(self.cuda_device)
            self.elmo_weight = torch.nn.Parameter(
                torch.ones(1).cuda(self.cuda_device), requires_grad=True)

        self.span_metric = SpanBasedF1Measure(
            vocab,
            tag_namespace=label_namespace,
            label_encoding=constraint_type or "BIO",
        )

        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        initializer(self)

        self.hash = 0
        self.number_epoch = 0

    @overrides
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        tags: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is:
            ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys
            as were used for the ``TokenIndexers`` when you created the ``TextField`` representing
            your sequence. The dictionary is designed to be passed directly to a
            ``TextFieldEmbedder``, which knows how to combine different word representations into a
            single vector per token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The per-tag logits computed from the scaled, negated squared distances between the
            query token embeddings and the class prototypes.
        mask : ``torch.LongTensor``
            The text field mask for the input tokens.
        tags : ``List[List[str]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
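
        Note that each batch is expected to be a few-shot episode: the first sentences form the
        support set (used only to build class prototypes) and the remaining sentences form the
        query set, so the predicted tags and the loss cover the query sentences only.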
""" if self.cuda_device >= 0: # Here we create some permanent GPU Variables because it's inefficient to create new GPU Variable every time if self.number_epoch == 0: self.tokens = tokens["tokens"].clone().cuda(self.cuda_device) self.token_characters = ( tokens["token_characters"].clone().cuda(self.cuda_device)) self.mask = (util.get_text_field_mask(tokens).clone().cuda( self.cuda_device)) self.elmo_tokens = tokens["elmo"].clone().cuda( self.cuda_device) else: self.tokens.data = tokens["tokens"].data.cuda(self.cuda_device) self.token_characters.data = tokens[ "token_characters"].data.cuda(self.cuda_device) self.mask.data = util.get_text_field_mask(tokens).data.cuda( self.cuda_device) self.elmo_tokens.data = tokens["elmo"].data.cuda( self.cuda_device) else: self.tokens = tokens["tokens"].clone() self.token_characters = tokens["token_characters"].clone() self.mask = util.get_text_field_mask(tokens).clone() self.elmo_tokens = tokens["elmo"].clone() # To prevent memory overflow we compute embeddings using internal minibatches number = 25 elmo_parts_input = [ self.elmo_tokens[(i * number):min(( (i + 1) * number), tokens["elmo"].data.shape[0])] for i in range(int(np.ceil(tokens["elmo"].data.shape[0] / number))) ] tokens_parts_input = [ self.tokens[(i * number):min(((i + 1) * number), tokens["elmo"].data.shape[0])] for i in range( int(np.ceil(tokens["tokens"].data.shape[0] / number))) ] chars_parts_input = [ self.token_characters[(i * number):min(( (i + 1) * number), tokens["elmo"].data.shape[0])] for i in range( int(np.ceil(tokens["token_characters"].data.shape[0] / number))) ] results = [ self.text_field_embedder({ "elmo": elmo_part, "tokens": tokens_part, "token_characters": chars_part, }) for elmo_part, tokens_part, chars_part in zip( elmo_parts_input, tokens_parts_input, chars_parts_input) ] # Clean memory del elmo_parts_input[:] del tokens_parts_input[:] del chars_parts_input[:] embedded_text_input = torch.cat(results, dim=0) del results[:] mask = util.get_text_field_mask(tokens) # Here we apply dropout to embeddings if self.dropout: dropped = self.dropout(embedded_text_input) # We again split our data to compute new hidden layer dropped_parts = [ dropped[(i * number):min(((i + 1) * number), tokens["elmo"].data.shape[0])] for i in range(int(np.ceil(tokens["elmo"].data.shape[0] / number))) ] del dropped mask_parts = [ self.mask[(i * number):min(((i + 1) * number), tokens["elmo"].data.shape[0])] for i in range(int(np.ceil(tokens["elmo"].data.shape[0] / number))) ] results = [ self.encoder(dropped_part, mask_part) for dropped_part, mask_part in zip(dropped_parts, mask_parts) ] del dropped_parts[:] del mask_parts[:] encoded_text = torch.cat(results, dim=0) # Again we apply dropout if self.dropout: encoded_text = self.dropout(encoded_text) # Apply the last layer embeddings = self.last_layer(encoded_text) # Here we split our batch into support and query sentences. # This division depends on what we do now: training or testing. # This happens because we generate train and test datasets using separate procedures. 
        if embeddings.requires_grad:
            split_i = 40
        else:
            split_i = 20

        # Split all the data into the support and query parts.
        tags_support = tags[:split_i]
        tags_query = tags[split_i:]
        uniq_support = np.unique(tags_support.cpu().data.numpy())
        support = embeddings[:split_i]
        query = embeddings[split_i:]
        support_mask = mask[:split_i]
        query_mask = mask[split_i:]
        # We also need numpy copies of the masks.
        mask_query = query_mask.data.cpu().numpy()
        mask_support = support_mask.data.cpu().numpy()

        # Map tag indices from the global vocabulary to indices inside this batch, and back.
        tag_to_index = dict(zip(uniq_support, np.arange(uniq_support.shape[0])))
        index_to_tag = dict(zip(np.arange(uniq_support.shape[0]), uniq_support))

        # Group the support embeddings by their tag labels.
        embeds_per_class = [[] for _ in range(uniq_support.shape[0])]
        tags_numpy = tags_support.data.cpu().numpy()
        for i_sen, sentence in enumerate(support):
            for i_word, word in enumerate(sentence):
                if mask_support[i_sen, i_word] == 1:
                    tag = tags_numpy[i_sen][i_word]
                    if tag > 0:
                        embeds_per_class[tag_to_index[tag]].append(word)

        # Compute the prototypes as the mean support embedding of every class
        # (this assumes the first entity class has at least one support token).
        prototypes = [
            torch.zeros_like(embeds_per_class[1][0])
            for _ in range(len(embeds_per_class))
        ]
        for i in range(len(embeds_per_class)):
            for embed in embeds_per_class[i]:
                prototypes[i] += embed / len(embeds_per_class[i])

        # We compute logits for every class in the data because we use a constant-size CRF layer.
        # Logits default to -100 so that classes not present in this batch get near-zero
        # probability.
        logits = (
            Variable(torch.zeros((tags_query.shape[0], tags_query.shape[1], self.num_tags)))
            - 100.0
        )
        for i_sen, sentence in enumerate(query):
            for i_word, word in enumerate(sentence):
                if mask_query[i_sen, i_word] == 1:
                    logits[i_sen, i_word, 0] = self.bias_outside
                    for i_class in range(1, len(embeds_per_class)):
                        distance = torch.sum(torch.pow(word - prototypes[i_class], 2))
                        logits[i_sen, i_word, index_to_tag[i_class]] = (
                            -distance * self.scale_classes[index_to_tag[i_class]])

        # Compute the prediction.
        best_paths = self.crf.viterbi_tags(logits, query_mask)
        # Just get the tags and ignore the score.
        query_tags = [x for x, y in best_paths]

        output = {"logits": logits, "mask": mask, "tags": query_tags}

        # Use the negative log-likelihood of the true tag sequence as the loss,
        # exactly as in a basic CRF tagger.
        if self.cuda_device >= 0:
            log_likelihood = self.crf(logits.cuda(self.cuda_device), tags_query, query_mask)
        else:
            log_likelihood = self.crf(logits, tags_query, query_mask)
        if not embeddings.requires_grad:
            log_likelihood = log_likelihood.detach()
        output["loss"] = -log_likelihood

        # Compute one-hot answers in order to compute the F1 metric of the prediction.
        class_probabilities = logits * 0.0
        for i, instance_tags in enumerate(query_tags):
            for j, tag_id in enumerate(instance_tags):
                class_probabilities[i, j, tag_id] = 1
        self.span_metric(class_probabilities, tags_query, query_mask)

        self.number_epoch += 1
        return output

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids,
        so we use an ugly nested list comprehension.
""" output_dict["tags"] = [[ self.vocab.get_token_from_index(tag, namespace="labels") for tag in instance_tags ] for instance_tags in output_dict["tags"]] return output_dict @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: metric_dict = self.span_metric.get_metric(reset=reset) return {x: y for x, y in metric_dict.items() if "overall" in x} @classmethod def from_params(cls, vocab: Vocabulary, params: Params) -> "PnetTagger": cuda_device = params.pop("cuda_device") embedder_params = params.pop("text_field_embedder") text_field_embedder = TextFieldEmbedder.from_params(embedder_params, vocab=vocab) encoder = Seq2SeqEncoder.from_params(params.pop("encoder")) label_namespace = params.pop("label_namespace", "labels") constraint_type = params.pop("constraint_type", None) dropout = params.pop("dropout", None) include_start_end_transitions = params.pop( "include_start_end_transitions", True) initializer = InitializerApplicator.from_params( params.pop("initializer", [])) regularizer = RegularizerApplicator.from_params( params.pop("regularizer", [])) params.assert_empty(cls.__name__) return cls( vocab=vocab, text_field_embedder=text_field_embedder, encoder=encoder, label_namespace=label_namespace, constraint_type=constraint_type, dropout=dropout, include_start_end_transitions=include_start_end_transitions, initializer=initializer, regularizer=regularizer, cuda_device=cuda_device, ) @overrides def load_state_dict(self, state_dict, strict=True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Arguments: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` """ # Here we reset some parameters that we don't need to load from warming state_dict.pop("text_field_embedder.token_embedder_tokens.weight", None) state_dict.pop( "text_field_embedder.token_embedder_token_characters._embedding._module.weight", None, ) state_dict.pop("tag_projection_layer._module.weight", None) state_dict.pop("tag_projection_layer._module.bias", None) state_dict.pop("crf.transitions", None) state_dict.pop("crf._constraint_mask", None) missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=""): local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") load(self) if len(error_msgs) > 0: raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( self.__class__.__name__, "\n\t".join(error_msgs)))