예제 #1
0
    def __init__(
        self,
        save_dir: str,
        corpus: Corpus,
        activation_names: ActivationNames,
        activations_dir: Optional[str] = None,
        test_activations_dir: Optional[str] = None,
        test_corpus: Optional[Corpus] = None,
        model: Optional[LanguageModel] = None,
        selection_func: SelectFunc = lambda sen_id, pos, example: True,
    ) -> None:
        """Prepare the output directory, the activations and the data loader.

        Parameters
        ----------
        save_dir : str
            Directory that is created (including missing parents) if it
            does not exist yet, and stored on the instance.
        corpus : Corpus
            Corpus passed to the activation extraction step and to the
            ``DataLoader``.
        activation_names : ActivationNames
            Activation names stored on the instance and passed to the
            extraction step.
        activations_dir : str, optional
            Train activations directory. The value that is actually used
            is the one returned by ``self._extract_activations``.
        test_activations_dir : str, optional
            Test activations directory, resolved the same way.
        test_corpus : Corpus, optional
            Separate test corpus, forwarded to extraction and loading.
        model : LanguageModel, optional
            Model forwarded to ``self._extract_activations``.
        selection_func : SelectFunc, optional
            Predicate over ``(sen_id, pos, example)``. The default accepts
            every item.
        """
        self.save_dir = save_dir
        # `exist_ok=True` replaces the former exists()/mkdir() pair: it also
        # creates missing parent directories and cannot fail when another
        # process creates the directory between the check and the call.
        os.makedirs(save_dir, exist_ok=True)

        activations_dir, test_activations_dir = self._extract_activations(
            save_dir,
            corpus,
            activation_names,
            selection_func,
            activations_dir,
            test_activations_dir,
            test_corpus,
            model,
        )

        self.activation_names = activation_names
        self.data_loader = DataLoader(
            activations_dir,
            corpus,
            test_activations_dir=test_activations_dir,
            test_corpus=test_corpus,
            selection_func=selection_func,
        )
        self.classifier = LogRegCV()
예제 #2
0
    def __init__(
        self,
        save_dir: str,
        corpus: Corpus,
        activation_names: ActivationNames,
        activations_dir: Optional[str] = None,
        test_activations_dir: Optional[str] = None,
        test_corpus: Optional[Corpus] = None,
        model: Optional[LanguageModel] = None,
        train_selection_func: SelectionFunc = (
            lambda sen_id, pos, example: True
        ),
        test_selection_func: Optional[SelectionFunc] = None,
        control_task: Optional[ControlTask] = None,
        classifier_type: str = "logreg_torch",
        save_logits: bool = False,
        verbose: int = 0,
    ) -> None:
        """Validate the configuration and set up activations and data loader.

        Parameters
        ----------
        save_dir : str
            Directory that is created (including missing parents) if it
            does not exist yet, and stored on the instance.
        corpus : Corpus
            Corpus passed to the extraction step and to the ``DataLoader``.
        activation_names : ActivationNames
            Activation names stored on the instance and passed to the
            extraction step.
        activations_dir : str, optional
            Train activations directory. The value that is actually used
            is the one returned by ``self._extract_activations``.
        test_activations_dir : str, optional
            Test activations directory, resolved the same way.
        test_corpus : Corpus, optional
            Separate test corpus, forwarded to extraction and loading.
        model : LanguageModel, optional
            Model forwarded to ``self._extract_activations``.
        train_selection_func : SelectionFunc, optional
            Predicate over ``(sen_id, pos, example)`` for train items. The
            default accepts every item.
        test_selection_func : SelectionFunc, optional
            Predicate for test items, forwarded unchanged.
        control_task : ControlTask, optional
            Forwarded to the ``DataLoader``.
        classifier_type : str, optional
            Either ``"logreg_torch"`` or ``"logreg_sklearn"``.
        save_logits : bool, optional
            Stored on the instance. Defaults to False.
        verbose : int, optional
            Stored on the instance. Defaults to 0.

        Raises
        ------
        AssertionError
            If ``classifier_type`` is not one of the supported values.
        """
        # Validate before any directory is created or activations are
        # extracted, so a typo does not cost a full extraction run. The error
        # is raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        if classifier_type not in ("logreg_torch", "logreg_sklearn"):
            raise AssertionError(
                "Classifier type not understood, should be either "
                "`logreg_torch` or `logreg_sklearn`"
            )

        self.save_dir = save_dir
        # Race-free replacement for the exists()/makedirs() pair.
        os.makedirs(save_dir, exist_ok=True)

        # Must exist before `_extract_activations` is called.
        self.remove_callbacks = []
        activations_dir, test_activations_dir = self._extract_activations(
            save_dir,
            corpus,
            activation_names,
            train_selection_func,
            activations_dir,
            test_activations_dir,
            test_corpus,
            test_selection_func,
            model,
        )

        self.activation_names = activation_names
        self.data_dict: DataDict = {}
        self.data_loader = DataLoader(
            activations_dir,
            corpus,
            test_activations_dir=test_activations_dir,
            test_corpus=test_corpus,
            train_selection_func=train_selection_func,
            test_selection_func=test_selection_func,
            control_task=control_task,
        )
        self.classifier_type = classifier_type
        self.save_logits = save_logits
        self.verbose = verbose
예제 #3
0
    def __init__(self,
                 corpus: Corpus,
                 activations_dir: str,
                 activation_names: List[ActivationName],
                 save_dir: str,
                 classifier_type: str,
                 calc_class_weights: bool = False) -> None:
        """Store the configuration, create `save_dir` and build the loader.

        Parameters
        ----------
        corpus : Corpus
            Corpus passed to the ``DataLoader``.
        activations_dir : str
            Activations directory passed to the ``DataLoader``.
        activation_names : List[ActivationName]
            Activation names stored on the instance.
        save_dir : str
            Directory that is created (including missing parents) if it
            does not exist yet.
        classifier_type : str
            Stored on the instance and used by ``self._reset_classifier``.
        calc_class_weights : bool, optional
            Stored on the instance. Defaults to False.
        """
        self.activation_names: List[ActivationName] = activation_names
        self.save_dir = save_dir
        # `exist_ok=True` replaces the former exists()/mkdir() pair: it also
        # creates missing parent directories and cannot fail when another
        # process creates the directory between the check and the call.
        os.makedirs(save_dir, exist_ok=True)

        # TODO: Allow own classifier here (should adhere to some base functions, such as .fit())
        self.classifier_type = classifier_type
        self.calc_class_weights = calc_class_weights

        self.data_loader = DataLoader(activations_dir, corpus)
        self.results: ResultsDict = defaultdict(dict)

        self._reset_classifier()
예제 #4
0
    def setUpClass(cls) -> None:
        """Dump dummy activations to disk and load them for the test class."""
        # Create the directory if necessary. `exist_ok=True` avoids the
        # check-then-create race of exists()/makedirs() when several test
        # processes share ACTIVATIONS_DIR.
        os.makedirs(ACTIVATIONS_DIR, exist_ok=True)

        # Create dummy data and have the reader read it
        create_and_dump_dummy_activations(
            num_sentences=NUM_TEST_SENTENCES, activations_dim=ACTIVATIONS_DIM, max_tokens=5,
            activations_dir=ACTIVATIONS_DIR, activations_name=ACTIVATIONS_NAME, num_classes=2
        )
        cls.data_loader = DataLoader(activations_dir=ACTIVATIONS_DIR)
        cls.num_labels = cls.data_loader.data_len
예제 #5
0
    def setUpClass(cls) -> None:
        """Dump dummy activations to disk and load them for the test class."""
        # Create the directory if necessary. `exist_ok=True` avoids the
        # check-then-create race of exists()/makedirs() when several test
        # processes share ACTIVATIONS_DIR.
        os.makedirs(ACTIVATIONS_DIR, exist_ok=True)

        # Create dummy data and have the reader read it; the helper's return
        # value is kept as the number of labels.
        cls.num_labels = create_and_dump_dummy_activations(
            num_sentences=NUM_TEST_SENTENCES,
            activations_dim=ACTIVATIONS_DIM,
            max_sen_len=5,
            activations_dir=ACTIVATIONS_DIR,
            activations_name=ACTIVATIONS_NAME,
            num_classes=2,
        )
        corpus = import_corpus(f"{ACTIVATIONS_DIR}/corpus.tsv")

        cls.data_loader = DataLoader(ACTIVATIONS_DIR, corpus)
예제 #6
0
class DCTrainer:
    """ Trains Diagnostic Classifiers (DC) on extracted activation data.

    For each activation that is part of the provided activation_names
    argument a different classifier will be trained.

    Parameters
    ----------
    corpus : Corpus
        Corpus containing the token labels for each sentence.
    activations_dir : str
        Path to folder containing the activations to train on.
    activation_names : List[ActivationName]
        List of activation names on which classifiers will be trained.
    save_dir : str
        Directory to which trained models will be saved. It is created
        (including missing parents) if it does not exist yet.
    classifier_type : str
        Classifier type. As of now only `logreg` is implemented; `svm` is
        reserved and raises ``NotImplementedError``.
    calc_class_weights : bool, optional
        Set to True to calculate the classifier class weights based on
        the corpus class frequencies. Defaults to False.

    Attributes
    ----------
    data_loader : DataLoader
        Class that reads and preprocesses activation data.
    classifier : Classifier
        Current classifier that is being trained.
    results : ResultsDict
        Dictionary containing relevant results. TODO: Add preds to this instead of separate files?
    """
    def __init__(self,
                 corpus: Corpus,
                 activations_dir: str,
                 activation_names: List[ActivationName],
                 save_dir: str,
                 classifier_type: str,
                 calc_class_weights: bool = False) -> None:

        self.activation_names: List[ActivationName] = activation_names
        self.save_dir = save_dir
        # `exist_ok=True` replaces the former exists()/mkdir() pair: it also
        # creates missing parent directories and is free of the
        # check-then-create race.
        os.makedirs(save_dir, exist_ok=True)

        # TODO: Allow own classifier here (should adhere to some base functions, such as .fit())
        self.classifier_type = classifier_type
        self.calc_class_weights = calc_class_weights

        self.data_loader = DataLoader(activations_dir, corpus)
        self.results: ResultsDict = defaultdict(dict)

        self._reset_classifier()

    def train(self,
              data_subset_size: int = -1,
              train_test_split: float = 0.9) -> None:
        """ Trains, evaluates and saves one classifier per activation name.

        Parameters
        ----------
        data_subset_size : int, optional
            Forwarded to ``DataLoader.create_data_split``. Defaults to -1.
        train_test_split : float, optional
            Forwarded to ``DataLoader.create_data_split``. Defaults to 0.9.
        """
        start_t = time()

        for a_name in self.activation_names:
            data_dict = self.data_loader.create_data_split(
                a_name, data_subset_size, train_test_split)

            # Calculate class weights
            if self.calc_class_weights:
                classes, class_freqs = np.unique(data_dict['train_y'],
                                                 return_counts=True)
                norm = class_freqs.sum()  # Norm factor
                # Each class is weighted by its relative frequency.
                # NOTE(review): this weights frequent classes *more*; confirm
                # that inverse-frequency weighting is not what was intended.
                self.classifier.class_weight = {
                    label: freq / norm
                    for label, freq in zip(classes, class_freqs)
                }

            # Train
            self.fit_data(data_dict['train_x'], data_dict['train_y'], a_name)
            pred_y = self.eval_classifier(data_dict['test_x'],
                                          data_dict['test_y'], a_name)

            self.save_classifier(pred_y, a_name)
            # Start the next activation name from a fresh classifier.
            self._reset_classifier()

        self.log_results(start_t)

    def _reset_classifier(self) -> None:
        """ Replaces `self.classifier` with a fresh, untrained instance.

        Raises
        ------
        NotImplementedError
            If `classifier_type` is `svm`, which is reserved but has no
            implementation yet. Previously this silently set the
            classifier to None and failed later with an AttributeError.
        KeyError
            If `classifier_type` is unknown.
        """
        if self.classifier_type == 'logreg':
            self.classifier = LogRegCV()
        elif self.classifier_type == 'svm':
            raise NotImplementedError(
                "classifier_type 'svm' is not implemented yet, use 'logreg'")
        else:
            # Same exception type as the former dictionary lookup.
            raise KeyError(self.classifier_type)

    def fit_data(self, train_x: np.ndarray, train_y: np.ndarray,
                 activation_name: ActivationName) -> None:
        """ Fits the current classifier and prints the elapsed time. """
        print(f'\nStarting fitting model on {activation_name}...')

        start_time = time()
        self.classifier.fit(train_x, train_y)

        print(f'Fitting done in {time() - start_time:.2f}s')

    # TODO: Add more evaluation metrics here
    def eval_classifier(self, test_x: np.ndarray, test_y: np.ndarray,
                        activation_name: ActivationName) -> np.ndarray:
        """ Predicts on the test set and records the accuracy.

        The accuracy is printed and stored in
        ``self.results[activation_name]['acc']``.

        Returns
        -------
        pred_y : np.ndarray
            Output of ``self.classifier.predict(test_x)``.
        """
        pred_y = self.classifier.predict(test_x)

        acc = accuracy_score(test_y, pred_y)

        print(f'{activation_name} acc.:', acc)

        self.results[activation_name]['acc'] = acc

        return pred_y

    def save_classifier(self, pred_y: np.ndarray,
                        activation_name: ActivationName) -> None:
        """ Writes the predictions and the fitted classifier to `save_dir`.

        The activation name is unpacked into two parts that together form
        the file names ``{name}_l{l}_preds.pickle`` and
        ``{name}_l{l}.joblib``.
        """
        layer, name = activation_name

        preds_path = os.path.join(self.save_dir,
                                  f'{name}_l{layer}_preds.pickle')
        model_path = os.path.join(self.save_dir, f'{name}_l{layer}.joblib')

        dump_pickle(pred_y, preds_path)
        joblib.dump(self.classifier, model_path)

    @staticmethod
    def load_classifier(path: str) -> Any:
        """ Loads a classifier that was stored with joblib.

        NOTE: joblib files are pickle-based, so loading one can execute
        arbitrary code. Only load files from a trusted source.
        """
        return joblib.load(path)

    def log_results(self, start_t: float) -> None:
        """ Prints the total run time and pickles a summary log.

        The log is written to ``log.pickle`` inside `save_dir`.

        Parameters
        ----------
        start_t : float
            Start timestamp, as returned by ``time()``.
        """
        total_time = time() - start_t
        m, s = divmod(total_time, 60)

        print(f'Total classification time took {m:.0f}m {s:.1f}s')

        log = {
            'activation_names': self.activation_names,
            'classifier_type': self.classifier_type,
            'results': self.results,
            'total_time': total_time,
        }

        log_path = os.path.join(self.save_dir, 'log.pickle')
        dump_pickle(log, log_path)
예제 #7
0
class DCTrainer:
    """ Trains Diagnostic Classifiers (DC) on extracted activation data.

    For each activation that is part of the provided activation_names
    argument a different classifier will be trained.

    Parameters
    ----------
    save_dir : str, optional
        Directory to which trained models will be saved, if provided.
        Pass None to skip saving; this is only possible when no new
        activations have to be extracted.
    corpus : Corpus
        Corpus containing the token labels for each sentence.
    activation_names : List[ActivationName]
        List of activation names on which classifiers will be trained.
    activations_dir : str, optional
        Path to folder containing the activations to train on. If not
        provided newly extracted activations will be saved to
        `save_dir`.
    test_activations_dir : str, optional
        Directory containing the extracted test activations. If not
        provided the train activation set will be split and partially
        used as test set.
    test_corpus : Corpus, optional
        Corpus containing the test labels for each sentence. If
        provided without `test_activations_dir` newly extracted
        activations will be saved to `save_dir`.
    model : LanguageModel, optional
        LanguageModel that should be provided if new activations need
        to be extracted prior to training the classifiers.
    selection_func : SelectFunc, optional
        Selection function that determines whether a corpus item should
        be taken into account for training. If such a function has been
        used during extraction, make sure to pass it along here as well.

    Attributes
    ----------
    data_loader : DataLoader
        Class that reads and preprocesses activation data.
    classifier : Classifier
        Current classifier that is being trained.
    """

    def __init__(
        self,
        save_dir: Optional[str],
        corpus: Corpus,
        activation_names: ActivationNames,
        activations_dir: Optional[str] = None,
        test_activations_dir: Optional[str] = None,
        test_corpus: Optional[Corpus] = None,
        model: Optional[LanguageModel] = None,
        selection_func: SelectFunc = lambda sen_id, pos, example: True,
    ) -> None:
        self.save_dir = save_dir
        # `save_dir` may be None (see `_train`), so only create it when it
        # is given. `exist_ok=True` also creates missing parents and avoids
        # the check-then-create race of exists()/mkdir().
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        activations_dir, test_activations_dir = self._extract_activations(
            save_dir,
            corpus,
            activation_names,
            selection_func,
            activations_dir,
            test_activations_dir,
            test_corpus,
            model,
        )

        self.activation_names = activation_names
        self.data_loader = DataLoader(
            activations_dir,
            corpus,
            test_activations_dir=test_activations_dir,
            test_corpus=test_corpus,
            selection_func=selection_func,
        )
        self.classifier = LogRegCV()

    def train(
        self,
        calc_class_weights: bool = False,
        data_subset_size: int = -1,
        train_test_split: float = 0.9,
    ) -> None:
        """ Trains DCs on multiple activation names.

        Parameters
        ----------
        calc_class_weights : bool, optional
            Set to True to calculate the classifier class weights based on
            the corpus class frequencies. Defaults to False.
        data_subset_size : int, optional
            Size of the subset on which training will be performed. Defaults
            to the full set of activations.
        train_test_split : float, optional
            Percentage of the train/test split. If separate test
            activations are provided this split won't be used.
            Defaults to 0.9/0.1.
        """
        for activation_name in self.activation_names:
            self._train(
                activation_name,
                calc_class_weights=calc_class_weights,
                data_subset_size=data_subset_size,
                train_test_split=train_test_split,
            )

    def _train(
        self,
        activation_name: ActivationName,
        calc_class_weights: bool = False,
        data_subset_size: int = -1,
        train_test_split: float = 0.9,
    ) -> None:
        """ Initiates training the DC on 1 activation type. """
        # Every activation name gets a fresh, untrained classifier.
        self._reset_classifier()

        data_dict = self.data_loader.create_data_split(
            activation_name, data_subset_size, train_test_split
        )

        # Calculate class weights
        if calc_class_weights:
            self._set_class_weights(data_dict["train_y"])

        # Train
        self._fit(data_dict["train_x"], data_dict["train_y"], activation_name)
        results = self._eval(data_dict["test_x"], data_dict["test_y"])

        if self.save_dir is not None:
            self._save(results, activation_name)

    def _fit(
        self, train_x: Tensor, train_y: Tensor, activation_name: ActivationName
    ) -> None:
        """ Fits the current classifier and prints the elapsed time. """
        start_time = time()
        print(f"\nStarting fitting model on {activation_name}...")

        self.classifier.fit(train_x, train_y)

        print(f"Fitting done in {time() - start_time:.2f}s")

    def _eval(self, test_x: Tensor, test_y: Tensor) -> Dict[str, Any]:
        """ Evaluates the current classifier on the test set.

        Accuracy and confusion matrix are printed.

        Returns
        -------
        results : Dict[str, Any]
            Dictionary with the keys ``accuracy``, ``confusion matrix``
            and ``pred_y``.
        """
        pred_y = self.classifier.predict(test_x)

        acc = accuracy_score(test_y, pred_y)
        cm = confusion_matrix(test_y, pred_y)

        results = {"accuracy": acc, "confusion matrix": cm}
        for k, v in results.items():
            print(k, v, "", sep="\n")
        # Added after printing so the raw predictions are not dumped.
        results["pred_y"] = pred_y

        return results

    def _save(self, results: Dict[str, Any], activation_name: ActivationName) -> None:
        """ Writes the results and the fitted classifier to `save_dir`. """
        layer, name = activation_name

        preds_path = os.path.join(
            self.save_dir, f"{name}_l{layer}_results.pickle"
        )
        model_path = os.path.join(self.save_dir, f"{name}_l{layer}.joblib")

        dump_pickle(results, preds_path)
        joblib.dump(self.classifier, model_path)

    def _reset_classifier(self) -> None:
        """ Replaces `self.classifier` with a fresh instance. """
        self.classifier = LogRegCV()

    def _set_class_weights(self, train_y: Tensor) -> None:
        """ Sets `classifier.class_weight` to the relative class frequencies.

        NOTE(review): frequent classes receive a *larger* weight here;
        confirm that inverse-frequency weighting is not what was intended.
        """
        classes, class_freqs = torch.unique(train_y, return_counts=True)
        norm = class_freqs.sum().item()
        self.classifier.class_weight = {
            label: freq / norm
            for label, freq in zip(classes.tolist(), class_freqs.tolist())
        }

    @staticmethod
    def _extract_activations(
        save_dir: Optional[str],
        corpus: Corpus,
        activation_names: ActivationNames,
        selection_func: SelectFunc,
        activations_dir: Optional[str],
        test_activations_dir: Optional[str],
        test_corpus: Optional[Corpus],
        model: Optional[LanguageModel],
    ) -> Tuple[str, Optional[str]]:
        """ Extracts train and/or test activations that are not on disk yet.

        Train activations are extracted when `activations_dir` is None,
        test activations when a `test_corpus` is given without a
        `test_activations_dir`. New activations are written to
        subdirectories of `save_dir`.

        Returns
        -------
        activations_dir : str
            Directory holding the train activations.
        test_activations_dir : str, optional
            Directory holding the test activations, if any.

        Raises
        ------
        ValueError
            If activations have to be extracted but `model` or
            `save_dir` is None.
        """
        extract_train = activations_dir is None
        extract_test = test_corpus is not None and test_activations_dir is None

        # Fail with a clear message instead of an obscure error from deep
        # inside the extraction code or `os.path.join`.
        if extract_train or extract_test:
            if model is None:
                raise ValueError(
                    "A `model` is required to extract new activations; "
                    "pass `model` or point to existing activation dirs."
                )
            if save_dir is None:
                raise ValueError(
                    "A `save_dir` is required to store newly extracted "
                    "activations."
                )

        if extract_train:
            activations_dir = os.path.join(save_dir, "activations")
            simple_extract(
                model, activations_dir, corpus, activation_names, selection_func
            )

        if extract_test:
            test_activations_dir = os.path.join(save_dir, "test_activations")
            simple_extract(
                model,
                test_activations_dir,
                test_corpus,
                activation_names,
                selection_func,
            )

        return activations_dir, test_activations_dir
예제 #8
0
class DCTrainer:
    """Trains Diagnostic Classifiers (DC) on extracted activation data.

    For each activation that is part of the provided activation_names
    argument a different classifier will be trained.

    Parameters
    ----------
    save_dir : str
        Directory to which trained models will be saved. It is created
        (including missing parents) if it does not exist yet.
    corpus : Corpus
        Corpus containing the token labels for each sentence.
    activation_names : ActivationNames
        List of activation names on which classifiers will be trained.
    activations_dir : str, optional
        Path to folder containing the activations to train on. If not
        provided newly extracted activations will be saved to
        `save_dir`.
    test_activations_dir : str, optional
        Directory containing the extracted test activations. If not
        provided the train activation set will be split and partially
        used as test set.
    test_corpus : Corpus, optional
        Corpus containing the test labels for each sentence. If
        provided without `test_activations_dir` newly extracted
        activations will be saved to `save_dir`.
    model : LanguageModel, optional
        LanguageModel that should be provided if new activations need
        to be extracted prior to training the classifiers.
    train_selection_func : SelectionFunc, optional
        Selection function that determines whether a corpus item should
        be taken into account for training. If not provided all
        extracted activations will be used and split into a random
        train/test split.
    test_selection_func : SelectionFunc, optional
        Selection function that determines whether a corpus item should
        be taken into account for testing. If not provided all
        extracted activations will be used and split into a random
        train/test split.
    classifier_type : str, optional
        Either `logreg_torch`, using a torch logreg model, or
        `logreg_sklearn`, using a LogisticRegressionCV model of sklearn.
    control_task : ControlTask, optional
        Control task function of Hewitt et al. (2019), mapping a corpus
        item to a random label. If not provided the corpus labels will
        be used instead.
    save_logits : bool, optional
        Toggle to store the output logits of the classifier on the test
        set. Defaults to False.
    verbose : int, optional
        Set to any positive number for verbosity. Defaults to 0.

    Attributes
    ----------
    data_loader : DataLoader
        Class that reads and preprocesses activation data.
    classifier : Classifier
        Current classifier that is being trained.

    Raises
    ------
    AssertionError
        If `classifier_type` is not one of the supported values.
    """
    def __init__(
        self,
        save_dir: str,
        corpus: Corpus,
        activation_names: ActivationNames,
        activations_dir: Optional[str] = None,
        test_activations_dir: Optional[str] = None,
        test_corpus: Optional[Corpus] = None,
        model: Optional[LanguageModel] = None,
        train_selection_func: SelectionFunc = (
            lambda sen_id, pos, example: True
        ),
        test_selection_func: Optional[SelectionFunc] = None,
        control_task: Optional[ControlTask] = None,
        classifier_type: str = "logreg_torch",
        save_logits: bool = False,
        verbose: int = 0,
    ) -> None:
        # Validate before any directory is created or activations are
        # extracted, so a typo does not cost a full extraction run. The error
        # is raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        if classifier_type not in ("logreg_torch", "logreg_sklearn"):
            raise AssertionError(
                "Classifier type not understood, should be either "
                "`logreg_torch` or `logreg_sklearn`"
            )

        self.save_dir = save_dir
        # Race-free replacement for the exists()/makedirs() pair.
        os.makedirs(save_dir, exist_ok=True)

        # Filled by `_extract_activations`; must exist before it is called.
        self.remove_callbacks = []
        activations_dir, test_activations_dir = self._extract_activations(
            save_dir,
            corpus,
            activation_names,
            train_selection_func,
            activations_dir,
            test_activations_dir,
            test_corpus,
            test_selection_func,
            model,
        )

        self.activation_names = activation_names
        self.data_dict: DataDict = {}
        self.data_loader = DataLoader(
            activations_dir,
            corpus,
            test_activations_dir=test_activations_dir,
            test_corpus=test_corpus,
            train_selection_func=train_selection_func,
            test_selection_func=test_selection_func,
            control_task=control_task,
        )
        self.classifier_type = classifier_type
        self.save_logits = save_logits
        self.verbose = verbose

    def train(
        self,
        calc_class_weights: bool = False,
        data_subset_size: int = -1,
        train_test_split: float = 0.9,
        store_activations: bool = True,
        rank: Optional[int] = None,
        max_epochs: int = 10,
        classifier_name: Optional[str] = None,
    ) -> Dict[ActivationName, Any]:
        """Trains DCs on multiple activation names.

        Parameters
        ----------
        calc_class_weights : bool, optional
            Set to True to calculate the classifier class weights based on
            the corpus class frequencies. Defaults to False.
        data_subset_size : int, optional
            Size of the subset on which training will be performed.
            Defaults to the full set of activations.
        train_test_split : float, optional
            Percentage of the train/test split. If separate test
            activations are provided this split won't be used.
            Defaults to 0.9/0.1.
        store_activations : bool, optional
            Set to True to store the extracted activations. Defaults to
            True.
        rank : int, optional
            Matrix rank of the linear classifier. Defaults to the full
            rank if not provided.
        max_epochs : int, optional
            Maximum number of training epochs used by skorch.
            Defaults to 10.
        classifier_name : str, optional
            Name for the trained classifier that is saved. If not
            provided `{name}_l{layer}.pt` will be used.

        Returns
        -------
        full_results_dict : Dict[ActivationName, Any]
            The results dictionary of each activation name.
        """

        full_results_dict = {}

        for activation_name in self.activation_names:
            full_results_dict[activation_name] = self._train(
                activation_name,
                calc_class_weights,
                data_subset_size,
                train_test_split,
                rank,
                max_epochs,
                classifier_name,
            )

        # The callbacks were collected while extracting in `__init__`.
        if not store_activations:
            for remove_callback in self.remove_callbacks:
                remove_callback()

        return full_results_dict

    def _train(
        self,
        activation_name: ActivationName,
        calc_class_weights: bool,
        data_subset_size: int,
        train_test_split: float,
        rank: Optional[int],
        max_epochs: int,
        classifier_name: Optional[str],
    ) -> Dict[str, Any]:
        """ Initiates training the DC on 1 activation type. """
        self.data_dict = self.data_loader.create_data_split(
            activation_name, data_subset_size, train_test_split)

        # Needs `self.data_dict`: the torch classifier reads its input size
        # from the train data.
        self._reset_classifier(rank, max_epochs)
        if self.verbose > 0:
            train_size = self.data_dict["train_x"].size(0)
            test_size = self.data_dict["test_x"].size(0)
            print(f"train/test: {train_size}/{test_size}")

        # Calculate class weights
        if calc_class_weights:
            self._set_class_weights(self.data_dict["train_y"])

        # Train
        self._fit(activation_name)
        results_dict = self._eval(self.data_dict["test_y"])

        # Save before the control task, which replaces `self.classifier`.
        self._save_classifier(activation_name, classifier_name)

        if self.data_dict["train_y_control"] is not None:
            self._control_task(rank, results_dict, max_epochs)

        self._save_results(results_dict, activation_name)

        return results_dict

    def _fit(self, activation_name: ActivationName) -> None:
        """ Fits the current classifier on the current train split. """
        start_time = time()
        if self.verbose > 0:
            print(f"\nStarting fitting model on {activation_name}...")

        self.classifier.fit(self.data_dict["train_x"],
                            self.data_dict["train_y"])

        if self.verbose > 0:
            print(f"Fitting done in {time() - start_time:.2f}s")

    def _eval(self, labels: Tensor) -> Dict[str, Any]:
        """ Evaluates the current classifier on the current test split.

        Parameters
        ----------
        labels : Tensor
            Gold labels the test predictions are compared against.

        Returns
        -------
        results_dict : Dict[str, Any]
            Dictionary with the keys ``accuracy``, ``f1``, ``mcc`` and
            ``confusion_matrix``, plus ``logits`` if `save_logits` is
            set.
        """
        pred_y = self.classifier.predict(self.data_dict["test_x"])

        results_dict = {
            "accuracy": metrics.accuracy_score(labels, pred_y),
            "f1": metrics.f1_score(labels, pred_y, average="micro"),
            "mcc": metrics.matthews_corrcoef(labels, pred_y),
            "confusion_matrix": metrics.confusion_matrix(labels, pred_y),
        }

        if self.save_logits and self.classifier_type == "logreg_torch":
            logits = self.classifier.infer(self.data_dict["test_x"],
                                           create_softmax=False)
            results_dict["logits"] = logits.detach()
        elif self.save_logits and self.classifier_type == "logreg_sklearn":
            # NOTE: for sklearn the `logits` entry holds the output of
            # `predict_proba`.
            logits = self.classifier.predict_proba(self.data_dict["test_x"])
            results_dict["logits"] = torch.from_numpy(logits)

        return results_dict

    def _control_task(self, rank: Optional[int],
                      results_dict: Dict[str, Any],
                      max_epochs: int = 10) -> None:
        """ Trains a fresh classifier on the control labels.

        The control results are added to `results_dict` in place with a
        ``_control`` suffix, together with ``selectivity``: the accuracy
        on the real task minus the accuracy on the control task.

        Parameters
        ----------
        rank : int, optional
            Matrix rank passed to the fresh classifier.
        results_dict : Dict[str, Any]
            Results of the real task, must contain ``accuracy``.
        max_epochs : int, optional
            Maximum number of epochs passed to the fresh classifier.
            Defaults to 10, the default of `train`.
        """
        if self.verbose > 0:
            print("Starting fitting the control task...")
        # `max_epochs` used to be omitted here, which raised a TypeError
        # because `_reset_classifier` required it.
        self._reset_classifier(rank, max_epochs)
        self.classifier.fit(self.data_dict["train_x"],
                            self.data_dict["train_y_control"])

        results_dict_control = self._eval(self.data_dict["test_y_control"])
        for k, v in results_dict_control.items():
            results_dict[f"{k}_control"] = v
        results_dict["selectivity"] = (results_dict["accuracy"] -
                                       results_dict["accuracy_control"])

    def _save_classifier(self, activation_name: ActivationName,
                         classifier_name: Optional[str]) -> None:
        """ Saves the current classifier to `save_dir`, if that is set.

        The torch model is saved as a state dict (``.pt``), the sklearn
        model with joblib (``.joblib``).
        """
        if self.save_dir is not None:
            layer, name = activation_name
            fn = classifier_name if classifier_name else f"{name}_l{layer}"
            if self.classifier_type == "logreg_torch":
                model_path = os.path.join(self.save_dir, fn + ".pt")
                torch.save(self.classifier.module.state_dict(), model_path)
            elif self.classifier_type == "logreg_sklearn":
                model_path = os.path.join(self.save_dir, fn + ".joblib")
                joblib.dump(self.classifier, model_path)

    def _save_results(self, results_dict: Dict[str, Any],
                      activation_name: ActivationName) -> None:
        """ Prints the results if verbose, and pickles them to `save_dir`. """
        if self.verbose > 0:
            for k, v in results_dict.items():
                print(k, v, "", sep="\n")
            print("Label vocab:", self.data_loader.label_vocab.itos)

        if self.save_dir is not None:
            layer, name = activation_name
            preds_path = os.path.join(self.save_dir,
                                      f"{name}_l{layer}_results.pickle")
            dump_pickle(results_dict, preds_path)

    def _reset_classifier(self, rank: Optional[int],
                          max_epochs: int = 10) -> None:
        """ Replaces `self.classifier` with a fresh, untrained instance.

        For `logreg_torch` the current `self.data_dict` must already hold
        the train data, as the input size is read from it.
        """
        if self.classifier_type == "logreg_torch":
            ninp = self.data_dict["train_x"].size(1)
            nout = len(self.data_loader.label_vocab)
            self.classifier = L1NeuralNetClassifier(
                LogRegModule(ninp=ninp, nout=nout, rank=rank),
                lr=0.01,
                max_epochs=max_epochs,
                verbose=self.verbose,
                optimizer=torch.optim.Adam,
                lambda1=0.005,
            )
        elif self.classifier_type == "logreg_sklearn":
            self.classifier = LogisticRegressionCV(tol=1e-2,
                                                   max_iter=max_epochs)

    # TODO: comply with skorch
    def _set_class_weights(self, labels: Tensor) -> None:
        """ Sets `classifier.class_weight` to the relative class frequencies.

        NOTE(review): frequent classes receive a *larger* weight here;
        confirm that inverse-frequency weighting is not what was intended.
        """
        classes, class_freqs = torch.unique(labels, return_counts=True)
        norm = class_freqs.sum().item()
        self.classifier.class_weight = {
            label: freq / norm
            for label, freq in zip(classes.tolist(), class_freqs.tolist())
        }

    def _extract_activations(
        self,
        save_dir: str,
        corpus: Corpus,
        activation_names: ActivationNames,
        selection_func: SelectionFunc,
        activations_dir: Optional[str],
        test_activations_dir: Optional[str],
        test_corpus: Optional[Corpus],
        test_selection_func: Optional[SelectionFunc],
        model: Optional[LanguageModel],
    ) -> Tuple[str, Optional[str]]:
        """ Extracts train and/or test activations that are not on disk yet.

        Train activations are extracted when `activations_dir` is None,
        test activations when a `test_corpus` is given without a
        `test_activations_dir`. The value returned by each extraction is
        appended to `self.remove_callbacks`.

        Returns
        -------
        activations_dir : str
            Directory holding the train activations.
        test_activations_dir : str, optional
            Directory holding the test activations, if any.

        Raises
        ------
        ValueError
            If activations have to be extracted but `model` is None.
        """
        extract_train = activations_dir is None
        extract_test = test_corpus is not None and test_activations_dir is None

        # Fail with a clear message instead of an obscure error from deep
        # inside the extraction code.
        if (extract_train or extract_test) and model is None:
            raise ValueError(
                "A `model` is required to extract new activations; "
                "pass `model` or point to existing activation dirs."
            )

        if extract_train:
            # We combine the 2 selection funcs to extract train and test activations simultaneously.
            if test_corpus is None and test_selection_func is not None:

                def new_selection_func(idx, pos, item):
                    return selection_func(
                        idx, pos, item) or test_selection_func(idx, pos, item)

            else:
                new_selection_func = selection_func

            activations_dir = os.path.join(save_dir, "activations")
            remove_callback = simple_extract(
                model,
                corpus,
                activation_names,
                activations_dir=activations_dir,
                selection_func=new_selection_func,
            )
            self.remove_callbacks.append(remove_callback)

        # If a separate test_corpus is provided we extract these activations separately.
        if extract_test:
            test_activations_dir = os.path.join(save_dir, "test_activations")
            remove_callback = simple_extract(
                model,
                test_corpus,
                activation_names,
                activations_dir=test_activations_dir,
                selection_func=test_selection_func
                or (lambda sen_id, pos, example: True),
            )
            self.remove_callbacks.append(remove_callback)

        return activations_dir, test_activations_dir