def simple_extract(
    model: LanguageModel,
    activations_dir: str,
    corpus: Corpus,
    activation_names: ActivationNames,
    selection_func: SelectFunc = lambda sen_id, pos, item: True,
) -> None:
    """Run a single extraction pass over ``corpus``.

    Thin wrapper around ``Extractor``. It constructs one from the given
    arguments, then calls its ``extract`` method with the module-level
    ``BATCH_SIZE``, ``dynamic_dumping=False`` and ``selection_func``.
    Nothing is returned.

    Args:
        model: passed to ``Extractor`` as its first argument.
        activations_dir: passed to ``Extractor`` as its third argument.
            Presumably the directory activations are written to; TODO
            confirm against ``Extractor``.
        corpus: passed to ``Extractor`` as its second argument.
        activation_names: passed to ``Extractor`` as its fourth argument.
        selection_func: forwarded to ``Extractor.extract``. The default
            accepts ``(sen_id, pos, item)`` and always returns ``True``,
            i.e. it selects every item.
    """
    # Construct the extractor first, then grab its bound ``extract`` method.
    run_extraction = Extractor(
        model, corpus, activations_dir, activation_names
    ).extract

    # ``dynamic_dumping`` is always disabled here. ``BATCH_SIZE`` is a
    # module-level constant defined outside this function.
    run_extraction(
        batch_size=BATCH_SIZE,
        dynamic_dumping=False,
        selection_func=selection_func,
    )
def _create_init_states_from_corpus(
    self,
    init_states_corpus: str,
    vocab_path: str,
    save_init_states_to: Optional[str],
) -> ActivationTensors:
    """Derive initial states by running this model over a separate corpus.

    The corpus at ``init_states_corpus`` is imported with ``vocab_path``.
    ``self.init_states`` is then overwritten with ``self.create_zero_state()``.
    Finally, an ``Extractor`` wrapping ``self`` is run over that corpus with
    ``create_avg_eos=True``.

    Args:
        init_states_corpus: path passed to ``import_corpus``.
        vocab_path: vocabulary path passed to ``import_corpus``.
        save_init_states_to: passed to ``Extractor`` as its third argument.
            When it is None, ``only_return_avg_eos=True`` is requested.
            Presumably that means nothing is written to disk; TODO confirm
            against ``Extractor.extract``.

    Returns:
        The (non-None) value returned by ``Extractor.extract``.

    Raises:
        AssertionError: if ``extract`` returns None (only while assertions
            are enabled, i.e. not under ``python -O``).
    """
    corpus: Corpus = import_corpus(init_states_corpus, vocab_path=vocab_path)

    # Side effect: reset this model's initial states to zero before
    # extracting. Presumably so the averaged states are not biased by
    # previously loaded ones; confirm.
    self.init_states = self.create_zero_state()

    return_only_avg_eos = save_init_states_to is None
    avg_eos_states = Extractor(self, corpus, save_init_states_to).extract(
        create_avg_eos=True,
        only_return_avg_eos=return_only_avg_eos,
    )

    # Narrows the Optional result returned by ``extract``.
    assert avg_eos_states is not None
    return avg_eos_states
from diagnnose.config.arg_parser import create_arg_parser
from diagnnose.config.setup import ConfigSetup
from diagnnose.corpora.import_corpus import import_corpus_from_path
from diagnnose.extractors.base_extractor import Extractor
from diagnnose.models.import_model import import_model_from_json
from diagnnose.models.language_model import LanguageModel
from diagnnose.typedefs.corpus import Corpus

if __name__ == '__main__':
    # Argument groups exposed on the command line by this script.
    arg_groups = {'model', 'activations', 'corpus', 'extract'}
    arg_parser, required_args = create_arg_parser(arg_groups)

    # The resulting config dict is indexed by argument-group name.
    config_setup = ConfigSetup(arg_parser, required_args, arg_groups)
    config_dict = config_setup.config_dict

    model: LanguageModel = import_model_from_json(config_dict['model'])

    corpus_kwargs = config_dict['corpus']
    corpus: Corpus = import_corpus_from_path(**corpus_kwargs)

    activations_kwargs = config_dict['activations']
    extractor = Extractor(model, corpus, **activations_kwargs)

    extract_kwargs = config_dict['extract']
    extractor.extract(**extract_kwargs)
""" Select activations only when they occur on the subject's position. """ return pos == sentence.misc_info["subj_pos"] def pos_4_selection_func(pos: int, token: str, sentence: LabeledSentence): """ Select activations only on position 4. """ return pos == 4 if __name__ == "__main__": required_args = {'model', 'vocab', 'lm_module', 'corpus_path', 'activation_names', 'output_dir'} arg_groups = { 'model': {'model', 'vocab', 'lm_module', 'device'}, 'corpus': {'corpus_path'}, 'init_extract': {'activation_names', 'output_dir', 'init_lstm_states_path'}, 'extract': {'cutoff', 'print_every'}, } argparser = init_argparser() config_object = ConfigSetup(argparser, required_args, arg_groups) config_dict = config_object.config_dict model: LanguageModel = import_model_from_json(**config_dict['model']) corpus: LabeledCorpus = convert_to_labeled_corpus(**config_dict['corpus']) extractor = Extractor(model, corpus, **config_dict['init_extract']) extractor.extract(**config_dict['extract'], selection_func=pos_4_selection_func) # In case you want to extract average eos activations as well, uncomment this line # extractor.extract_average_eos_activations(print_every=config_dict['extract']['print_every'])
from diagnnose.config.arg_parser import create_arg_parser
from diagnnose.config.setup import create_config_dict
from diagnnose.corpus.import_corpus import import_corpus
from diagnnose.extractors.base_extractor import Extractor
from diagnnose.models.import_model import import_model
from diagnnose.models.lm import LanguageModel
from diagnnose.typedefs.corpus import Corpus
from diagnnose.vocab import get_vocab_from_config

if __name__ == "__main__":
    # Argument groups exposed on the command line by this script.
    arg_groups = {
        "model",
        "activations",
        "corpus",
        "extract",
        "init_states",
        "vocab",
    }
    arg_parser, required_args = create_arg_parser(arg_groups)

    # The resulting config dict is indexed by argument-group name.
    config_dict = create_config_dict(arg_parser, required_args, arg_groups)

    # The model importer receives the whole config, not just one group.
    model: LanguageModel = import_model(config_dict)

    vocab_path = get_vocab_from_config(config_dict)
    corpus: Corpus = import_corpus(vocab_path=vocab_path, **config_dict["corpus"])

    activations_kwargs = config_dict["activations"]
    extractor = Extractor(model, corpus, **activations_kwargs)

    extract_kwargs = config_dict["extract"]
    extractor.extract(**extract_kwargs)