def _extract_activations(
    self,
    save_dir: str,
    corpus: Corpus,
    activation_names: ActivationNames,
    selection_func: SelectionFunc,
    activations_dir: Optional[str],
    test_activations_dir: Optional[str],
    test_corpus: Optional[Corpus],
    test_selection_func: Optional[SelectionFunc],
    model: Optional[LanguageModel],
) -> Tuple[str, Optional[str]]:
    """Run the extractions for which no activations directory was supplied.

    Parameters
    ----------
    save_dir : str
        Parent directory; new activations are placed in
        ``<save_dir>/activations`` and ``<save_dir>/test_activations``.
    corpus : Corpus
        Corpus used for the train extraction.
    activation_names : ActivationNames
        Passed through to ``simple_extract`` unchanged.
    selection_func : SelectionFunc
        Selection function for the train extraction.
    activations_dir : str, optional
        Existing train activations directory. When ``None`` a train
        extraction is run and the new directory is returned instead.
    test_activations_dir : str, optional
        Existing test activations directory.
    test_corpus : Corpus, optional
        Separate test corpus. When given without ``test_activations_dir``
        a second extraction is run on it.
    test_selection_func : SelectionFunc, optional
        Selection function for test items. Without a separate test corpus
        it is OR-ed with ``selection_func`` so a single pass covers both.
    model : LanguageModel, optional
        Model to extract from. Only required if an extraction is run.

    Returns
    -------
    Tuple[str, Optional[str]]
        The (possibly newly created) train and test activation directories.

    Raises
    ------
    ValueError
        If an extraction is required but ``model`` is ``None``.
    """
    extract_train = activations_dir is None
    extract_test = test_corpus is not None and test_activations_dir is None

    if not (extract_train or extract_test):
        # Everything was supplied by the caller; no model is needed.
        return activations_dir, test_activations_dir

    # Fail fast with a clear message instead of passing ``None`` into the
    # extractor and failing somewhere inside it.
    if model is None:
        raise ValueError(
            "A model is required to extract activations when no "
            "activations directory is provided."
        )

    if extract_train:
        # We combine the 2 selection funcs to extract train and test
        # activations simultaneously.
        if test_corpus is None and test_selection_func is not None:

            def new_selection_func(idx, pos, item):
                return selection_func(idx, pos, item) or test_selection_func(
                    idx, pos, item
                )

        else:
            new_selection_func = selection_func

        activations_dir = os.path.join(save_dir, "activations")
        # NOTE(review): other call sites unpack ``simple_extract``'s result as
        # a 2-tuple; confirm that storing the whole result here is intended.
        remove_callback = simple_extract(
            model,
            corpus,
            activation_names,
            activations_dir=activations_dir,
            selection_func=new_selection_func,
        )
        self.remove_callbacks.append(remove_callback)

    # If a separate test_corpus is provided we extract these activations
    # separately.
    if extract_test:
        test_activations_dir = os.path.join(save_dir, "test_activations")
        remove_callback = simple_extract(
            model,
            test_corpus,
            activation_names,
            activations_dir=test_activations_dir,
            # Without an explicit test selection every position is kept.
            selection_func=test_selection_func
            or (lambda sen_id, pos, example: True),
        )
        self.remove_callbacks.append(remove_callback)

    return activations_dir, test_activations_dir
def _calc_final_hidden(
    self,
    corpus: Corpus,
    selection_func: SelectionFunc,
) -> Tensor:
    """Extract the ``(top_layer, "hx")`` activations of ``self.model``.

    The extraction is delegated to ``simple_extract`` with a batch size
    equal to ``len(corpus)``.

    Parameters
    ----------
    corpus : Corpus
        Corpus to run the model on.
    selection_func : SelectionFunc
        Forwarded to ``simple_extract`` unchanged.

    Returns
    -------
    Tensor
        The entry stored under the activation name in the resulting
        reader's ``activation_dict``.
    """
    top_hx = (self.model.top_layer, "hx")

    # Only the reader is needed; the second element of the result is dropped.
    reader, _ = simple_extract(
        self.model,
        corpus,
        [top_hx],
        batch_size=len(corpus),
        selection_func=selection_func,
    )

    return reader.activation_dict[top_hx]
def _calc_final_hidden(
    self,
    corpus: Corpus,
    selection_func: SelectionFunc,
    sen_column: str = "sen",
) -> Tensor:
    """Extract and concatenate the ``(top_layer, "hx")`` activations.

    The extraction is delegated to ``simple_extract`` with a batch size
    equal to ``len(corpus)``.

    Parameters
    ----------
    corpus : Corpus
        Corpus to run the model on.
    selection_func : SelectionFunc
        Forwarded to ``simple_extract`` unchanged.
    sen_column : str, optional
        Forwarded to ``simple_extract`` as ``sen_column``. Defaults to
        ``"sen"``.

    Returns
    -------
    Tensor
        The reader's ``[:, activation_name]`` items concatenated along
        dimension 0.
    """
    top_hx = (self.model.top_layer, "hx")

    # Only the reader is needed; the second element of the result is dropped.
    reader, _ = simple_extract(
        self.model,
        corpus,
        [top_hx],
        batch_size=len(corpus),
        selection_func=selection_func,
        sen_column=sen_column,
    )

    per_item = reader[:, top_hx]
    return torch.cat(per_item, dim=0)
def _create_activation_reader(
    self,
    model: Optional[LanguageModel],
    corpus: Corpus,
    activations_dir: Optional[str],
    selection_func: Optional[SelectionFunc],
) -> ActivationReader:
    """Return an ``ActivationReader`` with ``cat_activations`` enabled.

    An existing ``activations_dir`` is read directly unless
    ``self.create_new_activations`` is set. Otherwise the activations named
    in ``self.activation_names`` are extracted with ``simple_extract``.

    Parameters
    ----------
    model : LanguageModel, optional
        Model to extract from. Required whenever an extraction is run.
    corpus : Corpus
        Corpus passed to ``simple_extract``.
    activations_dir : str, optional
        Directory of existing activations. When ``None`` an extraction is
        always run, and ``None`` is forwarded to ``simple_extract``.
    selection_func : SelectionFunc, optional
        Selection function for the extraction; ``return_all`` is used when
        this is falsy.

    Returns
    -------
    ActivationReader
        Reader over the existing or newly extracted activations.

    Raises
    ------
    ValueError
        If an extraction is required but ``model`` is ``None``.
    """
    if activations_dir is not None and not self.create_new_activations:
        return ActivationReader(activations_dir, cat_activations=True)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would pass ``None`` on to the extractor.
    if model is None:
        raise ValueError(
            "A model is required to extract new activations."
        )

    activation_reader, _ = simple_extract(
        model,
        corpus,
        self.activation_names,
        activations_dir=activations_dir,
        selection_func=selection_func or return_all,
    )
    activation_reader.cat_activations = True

    return activation_reader