def split_by_doc(self) -> List[TransformerData]:
    """Split a TransformerData that represents a batch into a list with one
    TransformerData per Doc.
    """
    # Flatten the per-doc span lists so token positions can be computed
    # over the whole batch at once.
    flat_spans = []
    for doc_spans in self.spans:
        flat_spans.extend(doc_spans)
    token_positions = get_token_positions(flat_spans)
    outputs = []
    start = 0  # running offset into the batch's wordpiece rows
    prev_tokens = 0  # total wordpiece tokens consumed by earlier docs
    for doc_spans in self.spans:
        # A doc with no spans (or an empty first span) contributed no
        # wordpieces to the batch: emit an empty TransformerData and do
        # not advance the offsets.
        if len(doc_spans) == 0 or len(doc_spans[0]) == 0:
            outputs.append(TransformerData.empty())
            continue
        start_i = token_positions[doc_spans[0][0]]
        end_i = token_positions[doc_spans[-1][-1]] + 1
        end = start + len(doc_spans)
        doc_tokens = self.wordpieces[start:end]
        doc_align = self.align[start_i:end_i]
        # Re-base alignment indices so they refer to this doc's own
        # wordpieces rather than positions in the whole batch.
        doc_align.data = doc_align.data - prev_tokens
        model_output = ModelOutput()
        last_hidden_state = self.model_output.last_hidden_state
        for key, output in self.model_output.items():
            if isinstance(output, torch.Tensor):
                # Slice this doc's rows out of the batch tensor and
                # convert to the framework-agnostic array type.
                model_output[key] = torch2xp(output[start:end])
            elif (isinstance(output, tuple)
                    and all(isinstance(t, torch.Tensor) for t in output)
                    and all(t.shape[0] == last_hidden_state.shape[0]
                            for t in output)):
                # Tuples of tensors (e.g. per-layer outputs) are sliced
                # element-wise — but only when every tensor's first dim
                # matches the batch dim of last_hidden_state; anything
                # else is silently dropped from the per-doc output.
                model_output[key] = [
                    torch2xp(t[start:end]) for t in output
                ]
        outputs.append(
            TransformerData(
                wordpieces=doc_tokens,
                model_output=model_output,
                align=doc_align,
            ))
        prev_tokens += doc_tokens.input_ids.size
        start += len(doc_spans)
    return outputs
def backprop_trf_to_tensor(
        d_outputs: List[Floats2d]) -> List[TransformerData]:
    """Map output gradients back onto per-doc TransformerData objects.

    For each doc, chains the two stored backprop callbacks, scales the
    resulting gradient by grad_factor, and wraps it in a TransformerData
    whose last_hidden_state gradient matches the forward output's shape.
    """
    gradients = []
    for trf_data, d_output, callbacks in zip(trf_datas, d_outputs, backprops):
        get_d_dst, get_d_src = callbacks
        hidden = trf_data.model_output.last_hidden_state
        # Start from a zero-filled gradient with the forward pass's
        # shape and dtype.
        d_model_output = ModelOutput(last_hidden_state=model.ops.alloc(
            hidden.shape,
            dtype=hidden.dtype,
        ))
        # Chain the stored callbacks: output-grad -> dst-grad -> src-grad.
        d_src = get_d_src(get_d_dst(d_output))
        d_src *= grad_factor
        d_model_output["last_hidden_state"] = d_src.reshape(hidden.shape)
        gradients.append(
            TransformerData(
                model_output=d_model_output,
                wordpieces=trf_data.wordpieces,
                align=trf_data.align,
            ))
    return gradients
def _process_data(self, inputs, return_dict): inp_length = inputs[self.main_input_name].shape[1] # If <max_length> specified, pad inputs by zeros if inp_length < self.max_length: for name in inputs: shape = inputs[name].shape if shape[1] != self.max_length: pad = np.zeros([len(shape), 2], dtype=np.int32) pad[1, 1] = self.max_length - shape[1] inputs[name] = np.pad(inputs[name], pad) # OpenVINO >= 2022.1 supports dynamic shapes input. if not is_openvino_api_2: inputs_info = self.net.input_info input_ids = inputs[self.main_input_name] if inputs_info[self.main_input_name].input_data.shape[ 1] != input_ids.shape[1]: # Use batch size 1 because we process batch sequently. shapes = { key: [1] + list(inputs[key].shape[1:]) for key in inputs_info } logger.info(f"Reshape model to {shapes}") self.net.reshape(shapes) self.exec_net = None elif is_openvino_api_2 and not self.use_dynamic_shapes: # TODO pass if self.exec_net is None: self._load_network() if is_openvino_api_2: outs = self._process_data_api_2022(inputs) else: outs = self._process_data_api_2021(inputs) logits = outs["output"] if "output" in outs else next( iter(outs.values())) past_key_values = None if self.config.architectures[0].endswith( "ForConditionalGeneration") and self.config.use_cache: past_key_values = [[]] for name in outs: if name == "output": continue if len(past_key_values[-1]) == 4: past_key_values.append([]) past_key_values[-1].append(torch.tensor(outs[name])) past_key_values = tuple([tuple(val) for val in past_key_values]) # Trunc padded values if inp_length != logits.shape[1]: logits = logits[:, :inp_length] if not return_dict: return [logits] arch = self.config.architectures[0] if arch.endswith("ForSequenceClassification"): return SequenceClassifierOutput(logits=logits) elif arch.endswith("ForQuestionAnswering"): return QuestionAnsweringModelOutput(start_logits=outs["output_s"], end_logits=outs["output_e"]) else: return ModelOutput(logits=torch.tensor(logits), past_key_values=past_key_values)
def from_dict(self, msg: Dict[str, Any]) -> "TransformerData":
    """Restore this object's fields from a serialized dict and return self.

    Expects msg to carry "wordpieces", "model_output" and "align" entries.
    """
    self.align = Ragged(*msg["align"])
    self.model_output = ModelOutput(msg["model_output"])
    wordpieces = WordpieceBatch.empty()
    self.wordpieces = wordpieces.from_dict(msg["wordpieces"])
    return self
def empty(cls) -> "TransformerData":
    """Return a TransformerData with empty wordpieces, an empty model
    output, and a zero-length alignment."""
    data = numpy.zeros((0, ), dtype="i")
    lengths = numpy.zeros((0, ), dtype="i")
    return cls(
        wordpieces=WordpieceBatch.empty(),
        model_output=ModelOutput(),
        align=Ragged(data, lengths),
    )
def backprop(d_model_output: ModelOutput) -> ArgsKwargs:
    """Bundle the forward output tensor with its incoming gradients for
    the wrapped framework's backward pass."""
    output_tensors = (model_output.last_hidden_state, )
    grad_kwargs = {"grad_tensors": d_model_output.values()}
    return ArgsKwargs(args=output_tensors, kwargs=grad_kwargs)
def forward(self, texts, alpha=1.0, inference=False):
    """
    Input: texts and labels (optional)
    Returns: lm_language modelling output, own output dict
    (clustering_loss, predicted_labels)

    alpha: temperature for the softmin weighting of centroid distances.
    inference: when True, the masked language-modeling pass is skipped.
    """
    # Language Modeling Part:
    # Default LM output carries a zero loss so the return signature is
    # the same even when language modelling is skipped.
    lm_outputs = ModelOutput(
        loss=torch.tensor(0.0, requires_grad=True).to(self.device))
    if not inference and self.do_language_modeling:
        inputs = self.tokenizer(texts,
                                return_tensors='pt',
                                padding=True,
                                truncation=True)
        # Clone before masking so the tokenizer's original ids are not
        # mutated in place.
        input_ids = inputs['input_ids'].clone()
        input_ids, true_ids = mask_tokens(input_ids, self.tokenizer)
        inputs['input_ids'] = input_ids
        inputs = inputs.to(self.device)
        true_ids = true_ids.to(self.device)
        lm_outputs = self.lm_model(labels=true_ids, **inputs)
    # Clustering Part:
    inputs = self.tokenizer(texts,
                            return_tensors='pt',
                            padding=True,
                            truncation=True)
    # NOTE(review): return value of .to() is discarded here — the LM
    # branch above reassigns instead. Relies on BatchEncoding.to moving
    # tensors in place; confirm both paths end up on self.device.
    inputs.to(self.device)
    # 0. Obtain embeddings for each input
    input_embeddings = self.embedding_extractor(
        self.lm_model.base_model(**inputs))
    # 1. Compute distances from each input embedding to each centroid
    distances = torch.stack([
        self.metric(embedding.unsqueeze(0), self.centroids)
        for embedding in input_embeddings
    ])
    # Hard assignment: index of the closest centroid per sample
    # (computed on a detached CPU copy so it carries no gradient).
    nearest_centroids = torch.argmin(distances.cpu().clone().detach(),
                                     dim=1)
    distances = torch.transpose(distances, 0, 1)
    # => shape (n_centroids, n_samples)
    # 2. Compute the parameterized softmin for each centroid of each
    # distance to each centroid per input sample
    # Find min distances for each centroid
    min_distances = torch.min(distances, dim=1).values
    # Compute exponentials (min is subtracted first for numerical
    # stability before applying the -alpha scaling)
    exponentials = torch.exp(-alpha *
                             (distances - min_distances.unsqueeze(1)))
    # Compute softmin
    softmin = exponentials / torch.sum(exponentials, dim=1).unsqueeze(1)
    # 3. Weight the distance between each sample and each centroid
    weighted_distances = distances * softmin
    # 4. Sum over weighted_distances to obtain loss
    clustering_loss = weighted_distances.sum(dim=1).mean()
    # Create clustering output dictionary
    cluster_outputs = ClusterOutput(
        loss=clustering_loss,
        predicted_labels=nearest_centroids.long(),
        embeddings=input_embeddings.cpu().detach())
    return lm_outputs, cluster_outputs