 def split_by_doc(self) -> List["TransformerData"]:
     """Split a TransformerData that represents a batch into a list with
     one TransformerData per Doc.
     """
     flat_spans = []
     for doc_spans in self.spans:
         flat_spans.extend(doc_spans)
     token_positions = get_token_positions(flat_spans)
     outputs = []
     start = 0
     prev_tokens = 0
     for doc_spans in self.spans:
         # Docs with no spans (or an empty first span) get an empty placeholder.
         if len(doc_spans) == 0 or len(doc_spans[0]) == 0:
             outputs.append(TransformerData.empty())
             continue
         start_i = token_positions[doc_spans[0][0]]
         end_i = token_positions[doc_spans[-1][-1]] + 1
         end = start + len(doc_spans)
         doc_tokens = self.wordpieces[start:end]
         doc_align = self.align[start_i:end_i]
         # Re-base the alignment indices onto this doc's own wordpieces.
         doc_align.data = doc_align.data - prev_tokens
         model_output = ModelOutput()
         last_hidden_state = self.model_output.last_hidden_state
         # Slice tensor-valued outputs along the batch axis; tuples of
         # per-layer tensors (e.g. hidden states) are handled below.
         for key, output in self.model_output.items():
             if isinstance(output, torch.Tensor):
                 model_output[key] = torch2xp(output[start:end])
             elif (isinstance(output, tuple)
                   and all(isinstance(t, torch.Tensor) for t in output)
                   and all(t.shape[0] == last_hidden_state.shape[0]
                           for t in output)):
                 model_output[key] = [
                     torch2xp(t[start:end]) for t in output
                 ]
         outputs.append(
             TransformerData(
                 wordpieces=doc_tokens,
                 model_output=model_output,
                 align=doc_align,
             ))
         prev_tokens += doc_tokens.input_ids.size
         start += len(doc_spans)
     return outputs
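
split_by_doc leans on ModelOutput behaving like an ordered dict of tensors that can be sliced along the batch axis. Below is a minimal standalone sketch of the same slicing idea, assuming torch and a recent transformers release are installed; the tensor shapes and span counts are invented for illustration.

import torch
from transformers.utils import ModelOutput

batch = ModelOutput(last_hidden_state=torch.randn(5, 8, 16))  # 5 spans total
spans_per_doc = [2, 3]  # hypothetical: doc 0 owns 2 spans, doc 1 owns 3

start = 0
per_doc = []
for n_spans in spans_per_doc:
    sliced = ModelOutput()
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            # Slice along the batch axis, mirroring split_by_doc above.
            sliced[key] = value[start:start + n_spans]
    per_doc.append(sliced)
    start += n_spans

assert per_doc[0].last_hidden_state.shape == (2, 8, 16)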
Example #2
 def backprop_trf_to_tensor(
         d_outputs: List[Floats2d]) -> List[TransformerData]:
     # trf_datas, backprops, model and grad_factor are captured from the
     # enclosing forward function; this is the gradient callback for Thinc.
     d_trf_datas = []
     zipped = zip(trf_datas, d_outputs, backprops)
     for trf_data, d_output, (get_d_dst, get_d_src) in zipped:
         d_model_output = ModelOutput(last_hidden_state=model.ops.alloc(
             trf_data.model_output.last_hidden_state.shape,
             dtype=trf_data.model_output.last_hidden_state.dtype,
         ))
         d_dst = get_d_dst(d_output)
         d_src = get_d_src(d_dst)
         d_src *= grad_factor
         d_model_output["last_hidden_state"] = d_src.reshape(
             trf_data.model_output.last_hidden_state.shape)
         d_trf_datas.append(
             TransformerData(
                 model_output=d_model_output,
                 wordpieces=trf_data.wordpieces,
                 align=trf_data.align,
             ))
     return d_trf_datas
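
The backward pass above just chains two callbacks captured from the forward pass and scales the result in place. A toy numpy sketch of that chaining, where both callbacks and grad_factor are stand-ins rather than the real spacy-transformers internals:

import numpy as np

def get_d_dst(d_output):   # stand-in for the first backprop callback
    return d_output * 2.0

def get_d_src(d_dst):      # stand-in for the second backprop callback
    return d_dst + 1.0

grad_factor = 0.5          # hypothetical scaling hyperparameter
d_output = np.ones((2, 3))

d_dst = get_d_dst(d_output)
d_src = get_d_src(d_dst)
d_src *= grad_factor       # same in-place scaling as in the snippet
print(d_src)               # every entry is (1.0 * 2.0 + 1.0) * 0.5 = 1.5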
Example #3
    def _process_data(self, inputs, return_dict):
        inp_length = inputs[self.main_input_name].shape[1]

        # If the inputs are shorter than max_length, zero-pad them to max_length.
        if inp_length < self.max_length:
            for name in inputs:
                shape = inputs[name].shape
                if shape[1] != self.max_length:
                    # Pad spec that extends only the end of axis 1 (sequence).
                    pad = np.zeros([len(shape), 2], dtype=np.int32)
                    pad[1, 1] = self.max_length - shape[1]
                    inputs[name] = np.pad(inputs[name], pad)

        # OpenVINO >= 2022.1 (API 2.0) accepts dynamically shaped inputs;
        # the older API requires an explicit network reshape first.
        if not is_openvino_api_2:
            inputs_info = self.net.input_info
            input_ids = inputs[self.main_input_name]
            if inputs_info[self.main_input_name].input_data.shape[
                    1] != input_ids.shape[1]:
                # Use batch size 1 because we process the batch sequentially.
                shapes = {
                    key: [1] + list(inputs[key].shape[1:])
                    for key in inputs_info
                }
                logger.info(f"Reshape model to {shapes}")
                self.net.reshape(shapes)
                self.exec_net = None
        elif is_openvino_api_2 and not self.use_dynamic_shapes:
            # TODO
            pass

        if self.exec_net is None:
            self._load_network()

        if is_openvino_api_2:
            outs = self._process_data_api_2022(inputs)
        else:
            outs = self._process_data_api_2021(inputs)

        logits = outs["output"] if "output" in outs else next(
            iter(outs.values()))

        past_key_values = None
        if self.config.architectures[0].endswith(
                "ForConditionalGeneration") and self.config.use_cache:
            past_key_values = [[]]
            # Group the remaining outputs in fours (the cached key/value
            # tensors for each layer).
            for name in outs:
                if name == "output":
                    continue
                if len(past_key_values[-1]) == 4:
                    past_key_values.append([])
                past_key_values[-1].append(torch.tensor(outs[name]))

            past_key_values = tuple([tuple(val) for val in past_key_values])

        # Truncate the padded positions back to the original input length
        if inp_length != logits.shape[1]:
            logits = logits[:, :inp_length]

        if not return_dict:
            return [logits]

        arch = self.config.architectures[0]
        if arch.endswith("ForSequenceClassification"):
            return SequenceClassifierOutput(logits=logits)
        elif arch.endswith("ForQuestionAnswering"):
            return QuestionAnsweringModelOutput(start_logits=outs["output_s"],
                                                end_logits=outs["output_e"])
        else:
            return ModelOutput(logits=torch.tensor(logits),
                               past_key_values=past_key_values)
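
The zero-padding step builds an np.pad spec that extends only the sequence axis. A standalone sketch of just that part, with max_length and the token IDs invented for illustration:

import numpy as np

max_length = 8
input_ids = np.array([[101, 2009, 102]])  # shape (1, 3)

pad = np.zeros([input_ids.ndim, 2], dtype=np.int32)
pad[1, 1] = max_length - input_ids.shape[1]  # pad only the end of axis 1
padded = np.pad(input_ids, pad)

print(padded)  # [[ 101 2009  102    0    0    0    0    0]]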
Example #4
 def from_dict(self, msg: Dict[str, Any]) -> "TransformerData":
     self.wordpieces = WordpieceBatch.empty().from_dict(msg["wordpieces"])
     self.model_output = ModelOutput(msg["model_output"])
     self.align = Ragged(*msg["align"])
     return self
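
The align field comes back as a (data, lengths) pair that is unpacked into thinc's Ragged. A tiny sketch of that unpacking, with made-up alignment arrays standing in for msg["align"]:

import numpy
from thinc.types import Ragged

msg_align = (numpy.asarray([0, 1, 1, 2], dtype="i"),
             numpy.asarray([1, 2, 1], dtype="i"))
align = Ragged(*msg_align)
print(align.lengths)  # [1 2 1]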
Example #5
 @classmethod
 def empty(cls) -> "TransformerData":
     align = Ragged(numpy.zeros((0, ), dtype="i"),
                    numpy.zeros((0, ), dtype="i"))
     return cls(wordpieces=WordpieceBatch.empty(),
                model_output=ModelOutput(),
                align=align)
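
empty() is an alternate constructor, which is why it takes cls and needs the @classmethod decorator restored above. The same pattern on a throwaway dataclass, for illustration only:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Box:
    items: List[int] = field(default_factory=list)

    @classmethod
    def empty(cls) -> "Box":
        # cls (not Box) keeps the constructor working for subclasses too.
        return cls(items=[])

print(Box.empty())  # Box(items=[])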
Example #6
 def backprop(d_model_output: ModelOutput) -> ArgsKwargs:
     return ArgsKwargs(
         args=(model_output.last_hidden_state, ),
         kwargs={"grad_tensors": d_model_output.values()},
     )
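
Thinc's PyTorch shim unpacks the returned ArgsKwargs into torch.autograd.backward(*args, **kwargs). A standalone sketch of that call with a toy tensor, to make the grad_tensors keyword concrete:

import torch
from thinc.api import ArgsKwargs

x = torch.ones(2, 2, requires_grad=True)
y = x * 3.0

d_y = torch.full((2, 2), 0.5)  # upstream gradient for y
call = ArgsKwargs(args=(y,), kwargs={"grad_tensors": [d_y]})

torch.autograd.backward(*call.args, **call.kwargs)
print(x.grad)  # every entry is 3.0 * 0.5 = 1.5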
Example #7
    def forward(self, texts, alpha=1.0, inference=False):
        """
        Input: texts and labels (optional)
        Returns: lm_language modelling output, own output dict (clustering_loss, predicted_labels)
        """
        # Language Modeling Part:

        lm_outputs = ModelOutput(
            loss=torch.tensor(0.0, requires_grad=True).to(self.device))

        if not inference and self.do_language_modeling:
            inputs = self.tokenizer(texts,
                                    return_tensors='pt',
                                    padding=True,
                                    truncation=True)

            input_ids = inputs['input_ids'].clone()
            input_ids, true_ids = mask_tokens(input_ids, self.tokenizer)
            inputs['input_ids'] = input_ids

            inputs = inputs.to(self.device)
            true_ids = true_ids.to(self.device)
            lm_outputs = self.lm_model(labels=true_ids, **inputs)

        # Clustering Part:
        inputs = self.tokenizer(texts,
                                return_tensors='pt',
                                padding=True,
                                truncation=True)

        inputs = inputs.to(self.device)

        # 0. Obtain embeddings for each input
        input_embeddings = self.embedding_extractor(
            self.lm_model.base_model(**inputs))

        # 1. Compute the distance from each input embedding to each centroid
        distances = torch.stack([
            self.metric(embedding.unsqueeze(0), self.centroids)
            for embedding in input_embeddings
        ])
        nearest_centroids = torch.argmin(distances.cpu().clone().detach(),
                                         dim=1)
        # Transpose to shape (n_centroids, n_samples).
        distances = torch.transpose(distances, 0, 1)

        # 2. Compute the parameterized softmin over the samples' distances
        # for each centroid.
        # Find the min distance per centroid.
        min_distances = torch.min(distances, dim=1).values
        # Compute the exponentials, shifted by the per-centroid minimum for
        # numerical stability.
        exponentials = torch.exp(-alpha *
                                 (distances - min_distances.unsqueeze(1)))
        # Compute the softmin.
        softmin = exponentials / torch.sum(exponentials, dim=1).unsqueeze(1)

        # 3. Weight the distance between each sample and each centroid
        weighted_distances = distances * softmin

        # 4. Sum the weighted distances to obtain the loss
        clustering_loss = weighted_distances.sum(dim=1).mean()

        # Create clustering output dictionary
        cluster_outputs = ClusterOutput(
            loss=clustering_loss,
            predicted_labels=nearest_centroids.long(),
            embeddings=input_embeddings.cpu().detach())

        return lm_outputs, cluster_outputs
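
The softmin weighting in steps 2-4 is easy to sanity-check in isolation. A toy sketch with an invented 2x3 distance matrix (two centroids, three samples) and alpha=1.0:

import torch

alpha = 1.0
distances = torch.tensor([[1.0, 2.0, 3.0],
                          [4.0, 0.5, 2.5]])  # (n_centroids, n_samples)

min_distances = torch.min(distances, dim=1).values
exponentials = torch.exp(-alpha * (distances - min_distances.unsqueeze(1)))
softmin = exponentials / torch.sum(exponentials, dim=1).unsqueeze(1)

print(softmin.sum(dim=1))          # each row sums to 1
clustering_loss = (distances * softmin).sum(dim=1).mean()
print(clustering_loss)             # scalar loss, as in step 4 above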