def __init__(self, num_classes: int, encoder: torch.nn.Module,
             freeze_encoder: bool, class_weights: Optional[torch.Tensor]):
    """
    Builds a linear-evaluation module: a frozen (eval-mode) SSL encoder with a
    small classifier head on top.

    :param num_classes: Number of output classes for the classifier head.
    :param encoder: The pretrained encoder producing the embeddings.
    :param freeze_encoder: Whether the encoder weights should be kept frozen.
    :param class_weights: Optional per-class weights for the loss; None means no weighting.
    """
    super().__init__()
    self.num_classes = num_classes
    self.encoder = encoder
    self.freeze_encoder = freeze_encoder
    self.class_weights = class_weights
    # Keep the encoder in eval mode; only the classifier head is trained.
    self.encoder.eval()
    self.classifier_head = SSLEvaluator(
        n_input=get_encoder_output_dim(self.encoder),
        n_hidden=None,
        n_classes=num_classes,
        p=0.20)
    # Binary classification additionally reports AUROC and AUPR. For
    # multi-class, Accuracy05 is the standard multi-class accuracy.
    if self.num_classes == 2:
        metric_factories = [AreaUnderRocCurve, AreaUnderPrecisionRecallCurve, Accuracy05]
    else:
        metric_factories = [Accuracy05]
    # Instantiate separately so train and val keep independent metric state.
    self.train_metrics = ModuleList([factory() for factory in metric_factories])
    self.val_metrics = ModuleList([factory() for factory in metric_factories])
    def __init__(self,
                 learning_rate: float,
                 class_weights: Optional[torch.Tensor] = None,
                 **kwargs: Any) -> None:
        """
        Creates a hook to evaluate a linear model on top of an SSL embedding.

        :param learning_rate: The learning rate to use for the linear head's optimizer.
        :param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
                              no weighting will be done.
        :param kwargs: Forwarded unchanged to the superclass constructor.
        """

        super().__init__(**kwargs)
        # Weight decay is fixed; only the learning rate is configurable by callers.
        self.weight_decay = 1e-4
        self.learning_rate = learning_rate

        # Binary classification reports AUROC and AUPR in addition to accuracy;
        # otherwise Accuracy05 serves as the standard multi-class accuracy.
        # NOTE(review): self.num_classes is not set here — presumably provided by
        # the superclass __init__ via kwargs; confirm against the base class.
        self.train_metrics: List[Metric] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
                                            Accuracy05()] \
            if self.num_classes == 2 else [Accuracy05()]
        self.val_metrics: List[Metric] = [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(),
                                          Accuracy05()] \
            if self.num_classes == 2 else [Accuracy05()]
        self.class_weights = class_weights
        # State dicts captured later for checkpointing the linear head and its optimizer.
        self.evaluator_state: Optional[OrderedDict] = None
        self.optimizer_state: Optional[OrderedDict] = None
 def _get_metrics_computers(self) -> ModuleList:
     """
     Builds the metric computers for the present kind of model, for a single
     prediction target: error/variance metrics for regression models,
     threshold- and curve-based metrics for classification models.
     """
     if not self.is_classification_model:
         return ModuleList([MeanAbsoluteError(), MeanSquaredError(), ExplainedVariance()])
     classification_metrics = [
         Accuracy05(),
         AccuracyAtOptimalThreshold(),
         OptimalThreshold(),
         FalsePositiveRateOptimalThreshold(),
         FalseNegativeRateOptimalThreshold(),
         AreaUnderRocCurve(),
         AreaUnderPrecisionRecallCurve(),
         BinaryCrossEntropy(),
     ]
     return ModuleList(classification_metrics)
 # Example #4
 def create_metric_computers(self) -> ModuleDict:
     """Creates the default metric computers: plain accuracy under the default hue key."""
     accuracy_only = ModuleList([Accuracy05()])
     return ModuleDict({MetricsDict.DEFAULT_HUE_KEY: accuracy_only})