def __init__(self,
             num_classes: int,
             encoder: torch.nn.Module,
             freeze_encoder: bool,
             class_weights: Optional[torch.Tensor]):
    """
    Builds a classifier made of an encoder followed by an SSLEvaluator head, plus train/val metric lists.

    :param num_classes: Number of output classes; passed to the head as ``n_classes``. When it equals 2, the
        metric lists also include AreaUnderRocCurve and AreaUnderPrecisionRecallCurve.
    :param encoder: Module whose output feeds the classifier head. It is put into eval mode here.
    :param freeze_encoder: Stored as ``self.freeze_encoder``; not otherwise used in this constructor.
    :param class_weights: Optional per-class weights, stored as ``self.class_weights``.
    """
    super().__init__()
    # Assignment order is kept as-is: nn.Module registers submodules in assignment order.
    self.num_classes = num_classes
    self.encoder = encoder
    self.freeze_encoder = freeze_encoder
    self.class_weights = class_weights
    # NOTE(review): eval() is called regardless of freeze_encoder - confirm this is intended.
    self.encoder.eval()
    encoder_dim = get_encoder_output_dim(self.encoder)
    # p=0.20 is forwarded to SSLEvaluator (presumably dropout probability - confirm against SSLEvaluator).
    self.classifier_head = SSLEvaluator(n_input=encoder_dim, n_hidden=None, n_classes=num_classes, p=0.20)

    def _fresh_metrics() -> ModuleList:
        # Each call creates new metric objects, so train and val never share instances.
        if self.num_classes != 2:
            # Note that for multi-class, Accuracy05 is the standard multi-class accuracy.
            return ModuleList([Accuracy05()])
        return ModuleList([AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(), Accuracy05()])

    self.train_metrics = _fresh_metrics()
    self.val_metrics = _fresh_metrics()
def __init__(self, learning_rate: float, class_weights: Optional[torch.Tensor] = None, **kwargs: Any) -> None:
    """
    Creates a hook to evaluate a linear model on top of an SSL embedding.

    :param learning_rate: The learning rate for the linear head, stored as ``self.learning_rate``.
    :param class_weights: The class weights to use when computing the cross entropy loss. If set to None, no
        weighting will be done.
    :param kwargs: Additional keyword arguments, forwarded unchanged to the parent class constructor. The parent
        is expected to set ``self.num_classes``, which is read below to choose the metrics.
    """
    super().__init__(**kwargs)
    self.weight_decay = 1e-4
    self.learning_rate = learning_rate

    def _create_metrics():
        # Binary problems additionally get AUROC and AUPRC; otherwise only Accuracy05 is computed.
        if self.num_classes == 2:
            return [AreaUnderRocCurve(), AreaUnderPrecisionRecallCurve(), Accuracy05()]
        return [Accuracy05()]

    # Built by two separate calls so the training and validation lists hold independent metric objects.
    self.train_metrics: List[Metric] = _create_metrics()
    self.val_metrics: List[Metric] = _create_metrics()
    self.class_weights = class_weights
    # Both start as None; nothing in this constructor populates them.
    self.evaluator_state: Optional[OrderedDict] = None
    self.optimizer_state: Optional[OrderedDict] = None
def _get_metrics_computers(self) -> ModuleList:
    """
    Returns the metric objects for a single prediction target, selected by the kind of model.

    :return: For classification models (``self.is_classification_model`` truthy): Accuracy05,
        AccuracyAtOptimalThreshold, OptimalThreshold, FalsePositiveRateOptimalThreshold,
        FalseNegativeRateOptimalThreshold, AreaUnderRocCurve, AreaUnderPrecisionRecallCurve and
        BinaryCrossEntropy. Otherwise: MeanAbsoluteError, MeanSquaredError and ExplainedVariance.
    """
    if not self.is_classification_model:
        return ModuleList([MeanAbsoluteError(), MeanSquaredError(), ExplainedVariance()])
    classification_metrics = [
        Accuracy05(),
        AccuracyAtOptimalThreshold(),
        OptimalThreshold(),
        FalsePositiveRateOptimalThreshold(),
        FalseNegativeRateOptimalThreshold(),
        AreaUnderRocCurve(),
        AreaUnderPrecisionRecallCurve(),
        BinaryCrossEntropy(),
    ]
    return ModuleList(classification_metrics)
def create_metric_computers(self) -> ModuleDict:
    """
    Returns a ModuleDict with one entry: ``MetricsDict.DEFAULT_HUE_KEY`` mapped to a ModuleList holding a
    single Accuracy05 instance.
    """
    hue_key = MetricsDict.DEFAULT_HUE_KEY
    accuracy_only = ModuleList([Accuracy05()])
    return ModuleDict({hue_key: accuracy_only})