def _create_output_classes(self, sen_ids: ActivationIndex) -> Tensor:
    """Build the target-class tensor for the selected sentences.

    For every selected sentence, the tokens after the first one (the
    initial token is skipped, presumably a start-of-sentence marker —
    TODO confirm) are mapped to vocabulary ids via ``corpus.vocab.stoi``.

    Parameters
    ----------
    sen_ids : ActivationIndex
        Index (int / list / slice / ...) selecting sentences from
        ``self.corpus``.

    Returns
    -------
    Tensor
        Long tensor of shape ``(num_sens, sen_len)``. All selected
        sentences must share the same length; unequal lengths raise an
        ``AssertionError``.
    """
    token_ids: List[List[int]] = []

    for idx, sen_id in enumerate(activation_index_to_iterable(sen_ids)):
        # Drop the first token of the sentence before the vocab lookup.
        sentence = self.corpus[sen_id].sen[1:]
        ids = [self.corpus.vocab.stoi[token] for token in sentence]

        if idx > 0:
            # Ragged batches cannot be stacked into a single tensor.
            assert len(ids) == len(
                token_ids[0]
            ), "Unequal sentence lengths are not supported yet"

        token_ids.append(ids)

    return torch.tensor(token_ids)
def _extract_activations(
    self, activations_dir: str, sen_ids: ActivationIndex, corpus: Corpus
) -> None:
    """Extract activations for a subset of corpus sentences to disk.

    Parameters
    ----------
    activations_dir : str
        Directory that ``simple_extract`` writes the activations to.
    sen_ids : ActivationIndex
        Selection of sentence ids. NOTE(review): ``sen_ids.stop`` is read
        unconditionally, so this method appears to expect a ``slice``
        here — confirm against callers.
    corpus : Corpus
        Corpus the activations are extracted from.
    """
    activation_names = self._get_activation_names()

    # Bound an open-ended slice by the corpus size so the iterable
    # below is finite.
    if sen_ids.stop is None:
        sen_ids = slice(sen_ids.start, len(corpus), sen_ids.step)

    # Materialize the selection as a frozenset once: selection_func is
    # called for every corpus item, and membership on a set is O(1)
    # instead of O(n) on a list.
    sen_id_set = frozenset(activation_index_to_iterable(sen_ids))

    def selection_func(sen_id, _pos, _item):
        # Keep only items belonging to the requested sentences.
        return sen_id in sen_id_set

    simple_extract(
        self.model,
        activations_dir,
        corpus,
        activation_names,
        selection_func=selection_func,
    )
def plot_by_sen_id(
    self,
    sen_ids: ActivationIndex,
    activations_dir: Optional[str] = None,
    avg_decs: bool = False,
    extra_classes: Optional[List[int]] = None,
    arr_pickle: Optional[str] = None,
    save_arr_as: Optional[str] = None,
    save_plot_as: Optional[str] = None,
) -> Tensor:
    """Plot attention-style decompositions for the selected sentences.

    The decomposition array is either loaded from ``arr_pickle`` or
    computed via ``calc_by_sen_id``. With ``avg_decs`` a single plot of
    the batch average is made; otherwise one plot per sentence is made,
    with the sentence tokens set as axis labels via ``plot_config``.

    Parameters
    ----------
    sen_ids : ActivationIndex
        Selection of sentence ids to plot.
    activations_dir : str, optional
        Passed through to ``calc_by_sen_id``.
    avg_decs : bool, optional
        If True, plot the mean over the batch dimension instead of
        per-sentence plots. Defaults to False.
    extra_classes : List[int], optional
        Extra decoder classes; their count shortens the y-axis labels.
    arr_pickle : str, optional
        Path to a pickled array; skips recomputation when given.
    save_arr_as : str, optional
        Passed through to ``calc_by_sen_id``.
    save_plot_as : str, optional
        Output path for the averaged plot.

    Returns
    -------
    Tensor
        The (loaded or computed) decomposition array.
    """
    if arr_pickle is None:
        arr: Tensor = self.calc_by_sen_id(
            sen_ids,
            activations_dir=activations_dir,
            extra_classes=extra_classes,
            save_arr_as=save_arr_as,
        )
    else:
        arr = torch.load(arr_pickle)

    if avg_decs:
        # Single plot of the batch-averaged decomposition.
        self.plot_attention(torch.mean(arr, dim=0), save_plot_as=save_plot_as)
        return arr

    sen_id_iter = activation_index_to_iterable(sen_ids)
    # Extra classes occupy the tail of the y axis; trim them (plus the
    # final token) from the y-axis labels.
    n_extra = len(extra_classes) if extra_classes is not None else 0
    ytext_end = -n_extra - 1

    for batch_idx in range(arr.size(0)):
        sen = self.corpus[sen_id_iter[batch_idx]].sen
        self.plot_config.update(
            {"xtext": sen[1:], "ytext": sen[:ytext_end]}
        )
        self.plot_attention(arr[batch_idx])

    return arr
def _read_decoder(
    self,
    classes: Optional[ActivationIndex],
    batch_size: int,
    decoder_path: Optional[str] = None,
) -> LinearDecoder:
    """Load decoder weights and bias, restricted to the requested classes.

    Parameters
    ----------
    classes : ActivationIndex, optional
        Decoder output classes to keep. Accepts an int, list, ndarray,
        slice, a 2-D tensor with one row per batch item, or None
        (None yields an empty class selection).
    batch_size : int
        Number of items in the batch; the class selection is repeated
        (or validated) against this.
    decoder_path : str, optional
        Path to a joblib-pickled classifier (e.g. sklearn); when absent
        the decoder is taken from ``self.model``.

    Returns
    -------
    LinearDecoder
        ``(decoder_w, decoder_b)`` — presumably a (weights, bias) tuple
        alias; weights are permuted to shape (batch, hidden, classes).
        NOTE(review): on the joblib path ``coef_``/``intercept_`` are
        numpy arrays but are indexed with a torch tensor and ``.permute``
        is called on the result — confirm this path is exercised/valid.
    """
    if decoder_path is not None:
        classifier = joblib.load(decoder_path)
        decoder_w = classifier.coef_
        decoder_b = classifier.intercept_
    else:
        decoder_w, decoder_b = import_decoder_from_model(self.model)

    # Create tensor of relevant decoder classes.
    if isinstance(classes, int):
        classes = torch.tensor([classes])
    elif isinstance(classes, list):
        classes = torch.tensor(classes)
    elif isinstance(classes, ndarray):
        classes = torch.from_numpy(classes)
    elif isinstance(classes, slice):
        classes = torch.tensor(activation_index_to_iterable(classes))
    elif classes is None:
        # Empty long tensor: selects no classes at all.
        classes = torch.tensor([]).to(torch.long)

    if len(classes.shape) == 1:
        # Same class selection for every batch item.
        classes = classes.repeat(batch_size, 1)
    else:
        # Per-item class rows must line up with the batch.
        assert classes.size(0) == batch_size, (
            f"First dimension of classes not equal to batch_size:"
            f" {classes.size(0)} != {batch_size} (bsz)")

    decoder_w = decoder_w[classes].permute(0, 2, 1)
    decoder_b = decoder_b[classes]

    return decoder_w, decoder_b
def __getitem__(self, key: ActivationKey) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Concise and efficient indexing of extracted activations.

    ``key`` is either an ``ActivationIndex`` (an iterable that can index
    a tensor) or an ``(index, activation_name)`` tuple, where an
    ``activation_name`` is a ``(layer, name)`` pair. When more than one
    activation name has been extracted the name must be given
    explicitly; with a single extracted name it can be left out.

    Example usage:

    .. code-block:: python

        activation_reader = ActivationReader(
            dir, activation_names=[(0, "hx"), (1, "hx")], **kwargs
        )

        # activation_name must be passed because ActivationReader
        # contains two activation_names.
        activations_first_sen = activation_reader[0, (1, "hx")]
        all_activations = activation_reader[:, (1, "hx")]

        activation_reader2 = ActivationReader(
            dir, activation_names=[(1, "hx")], **kwargs
        )

        # activation_name can be left implicit.
        activations_first_10_sens = activation_reader2[:10]

    Parameters
    ----------
    key : ActivationKey
        ``ActivationIndex`` or ``(index, activation_name)``, as
        explained above.

    Returns
    -------
    split_activations : Tensor | Tuple[Tensor, ...]
        A single concatenated tensor if ``self.cat_activations`` is
        True; otherwise a tuple with one ``(sen_len, nhid)`` tensor per
        selected sentence.
    """
    if isinstance(key, tuple):
        index, activation_name = key
    else:
        # With a single extracted name the activation_name is implicit.
        assert (
            len(self.activation_names) == 1
        ), "Activation name must be provided if multiple activations have been extracted"
        index = key
        activation_name = self.activation_names[0]

    idx_iter = activation_index_to_iterable(
        index, len(self.activation_ranges))
    ranges = [self.activation_ranges[i] for i in idx_iter]

    # Flatten the per-sentence (start, stop) ranges into one index tensor.
    row_indices = torch.cat(
        [torch.arange(*rng) for rng in ranges]
    ).to(torch.long)
    activations = self.activations(activation_name)[row_indices]

    if self.cat_activations:
        return activations

    sen_lens = [rng[1] - rng[0] for rng in ranges]
    split_activations: Tuple[Tensor, ...] = torch.split(activations, sen_lens)
    return split_activations