def _create_init_states_from_corpus(
    self,
    init_states_corpus: str,
    tokenizer: PreTrainedTokenizer,
    save_init_states_to: Optional[str] = None,
) -> ActivationDict:
    """Build initial states by running this model over a corpus.

    Parameters
    ----------
    init_states_corpus : str
        Corpus specification passed to ``Corpus.create``.
    tokenizer : PreTrainedTokenizer
        Tokenizer used to create the corpus. Must not be None.
    save_init_states_to : str, optional
        Passed to ``Extractor`` as ``activations_dir``; presumably the
        extracted states are also written there when given (TODO confirm).

    Returns
    -------
    init_states : ActivationDict
        The ``activation_dict`` of the extraction result, holding the
        ``"hx"`` and ``"cx"`` activations of every layer, selected with
        ``final_sen_token``.
    """
    assert (
        tokenizer is not None
    ), "Tokenizer must be provided when creating init states from corpus"

    corpus: Corpus = Corpus.create(init_states_corpus, tokenizer=tokenizer)

    # One (layer, name) pair per state per layer, ordered layer-major:
    # (0, "hx"), (0, "cx"), (1, "hx"), ...
    # "hx"/"cx" are presumably the hidden and cell states — TODO confirm.
    activation_names: ActivationNames = []
    for layer_idx in range(self.num_layers):
        activation_names.extend((layer_idx, state) for state in ("hx", "cx"))

    # The model itself is the one being extracted from.
    extractor = Extractor(
        self,
        corpus,
        activation_names,
        activations_dir=save_init_states_to,
        selection_func=final_sen_token,
    )
    return extractor.extract().activation_dict
def simple_extract(
    model: LanguageModel,
    corpus: Corpus,
    activation_names: ActivationNames,
    activations_dir: Optional[str] = None,
    batch_size: int = BATCH_SIZE,
    selection_func: SelectionFunc = return_all,
    sen_column: Optional[str] = None,
) -> Tuple[ActivationReader, RemoveCallback]:
    """Basic extraction method.

    Parameters
    ----------
    model : LanguageModel
        Language model that inherits from LanguageModel.
    corpus : Corpus
        Corpus containing sentences to be extracted.
    activation_names : List[tuple[int, str]]
        List of (layer, activation_name) tuples.
    activations_dir : str, optional
        Directory to which activations will be written. If not provided
        the `extract()` method will only return the activations without
        writing them to disk.
    batch_size : int, optional
        Number of sentences processed per forward step. A higher batch
        size increases extraction speed, but should be chosen according
        to the amount of available RAM. Defaults to ``BATCH_SIZE``.
    selection_func : SelectionFunc, optional
        Function which determines if activations for a token should be
        extracted or not. Defaults to ``return_all``.
    sen_column : str, optional
        Corpus column that will be tokenized and extracted. Defaults to
        ``corpus.sen_column`` (also used when an empty string is passed).

    Returns
    -------
    activation_reader : ActivationReader
        ActivationReader for the activations that have been extracted.
    remove_activations : RemoveCallback
        Callback function that can be executed at the end of a procedure
        that depends on the extracted activations. Takes no arguments.
        When ``activations_dir`` was given, it deletes that ENTIRE
        directory tree with ``shutil.rmtree``, including any files in it
        that were not written by this extraction. When ``activations_dir``
        is None it does nothing.
    """
    extractor = Extractor(
        model,
        corpus,
        activation_names,
        activations_dir=activations_dir,
        selection_func=selection_func,
        batch_size=batch_size,
        # `or` also falls back for an empty-string column name.
        sen_column=sen_column or corpus.sen_column,
    )
    activation_reader = extractor.extract()

    def remove_activations() -> None:
        # Nothing was written to disk when no directory was given.
        if activations_dir is not None:
            shutil.rmtree(activations_dir)

    return activation_reader, remove_activations
from diagnnose.config.config_dict import create_config_dict
from diagnnose.corpus import Corpus
from diagnnose.extract import Extractor
from diagnnose.models import LanguageModel, import_model, set_init_states
from diagnnose.tokenizer.create import create_tokenizer

if __name__ == "__main__":
    # Every step below is configured by its own section of the config dict.
    # Sections are read at the point of use so the order of evaluation
    # (and of any errors) follows the pipeline.
    config_dict = create_config_dict()

    # The tokenizer is created first because the corpus is built with it.
    tokenizer_kwargs = config_dict["tokenizer"]
    tokenizer = create_tokenizer(**tokenizer_kwargs)

    corpus_kwargs = config_dict["corpus"]
    corpus: Corpus = Corpus.create(tokenizer=tokenizer, **corpus_kwargs)

    model_kwargs = config_dict["model"]
    model: LanguageModel = import_model(**model_kwargs)
    # use_default=True presumably selects the model's default initial
    # states rather than ones derived from a corpus — TODO confirm.
    set_init_states(model, tokenizer=tokenizer, use_default=True)

    extract_kwargs = config_dict["extract"]
    extractor = Extractor(model, corpus, **extract_kwargs)
    a_reader = extractor.extract()