class BinaryHardcodedTask(Task):
    """
    Represents a binary task whose values are already available (hardcoded),
    so it has no loss, activation or decoder and its module is the identity.
    """
    name: str
    labels: Any
    loss: Optional[nn.Module] = None
    per_sample_loss: Optional[nn.Module] = None
    available_func: Callable = positive_values
    inputs: Any = None
    activation: Optional[nn.Module] = None
    decoder: Optional[Decoder] = None
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = ()
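
# --- Usage sketch (illustrative, not part of the library code) ---------------
# A hardcoded binary task simply carries precomputed 0/1 values, e.g. the
# output of a fixed rule rather than a learned head. The helper name, task name
# and labels array below are made-up examples; construction assumes these Task
# subclasses are dataclasses (as the `field(...)` defaults elsewhere suggest)
# and that `np` is the `numpy` import used in this module.
def _example_hardcoded_task():
    labels = np.array([0, 1, 1, 0], dtype=np.float32)  # precomputed 0/1 values
    return BinaryHardcodedTask(name='is_glasses_present', labels=labels)
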

class MultilabelClassificationTask(Task):
    """
    Represents a multilabel classification task. Labels should be multi-hot
    binary vectors of length N, where N is the number of classes.

    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`
    """
    name: str
    labels: Any
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Any = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=MultilabelClassificationDecoder)
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_multilabel_classification_metrics)

class BinaryClassificationTask(Task):
    """
    Represents a normal binary classification task. Labels should be between
    0 and 1.

    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`
    """
    name: str
    labels: Any
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Any = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=BinaryDecoder)
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_binary_metrics)
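
# --- Usage sketch (illustrative, not part of the library code) ---------------
# A binary task pairs a head producing a single logit per sample with
# `BCEWithLogitsLoss` and a sigmoid activation. The helper name, feature size,
# head and label tensor below are made-up examples; construction assumes these
# Task subclasses are dataclasses (as the `field(...)` defaults suggest).
def _example_binary_task(num_features: int = 128):
    head = nn.Linear(num_features, 1)                  # one logit per sample
    labels = torch.tensor([[0.], [1.], [1.], [0.]])    # one 0/1 target per sample
    return BinaryClassificationTask(name='is_face', labels=labels, module=head)
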

class ClassificationTask(Task):
    """
    Represents a multi-class classification task. Labels should be integers
    from 0 to N-1, where N is the number of classes.

    * activation - `nn.Softmax(dim=-1)`
    * loss - `nn.CrossEntropyLoss()`
    """
    name: str
    labels: Sequence
    loss: nn.Module = nn.CrossEntropyLoss(reduction='mean')
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.CrossEntropyLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Optional[Sequence] = None
    activation: Optional[nn.Module] = nn.Softmax(dim=-1)
    decoder: Decoder = field(default_factory=ClassificationDecoder)
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_classification_metrics)
    class_names: Optional[Sequence[str]] = None
    top_k: Optional[int] = 5

    def get_treelib_explainer(self) -> Callable:
        def classification_explainer(task_name: str,
                                     decoded: torch.Tensor,
                                     activated: torch.Tensor,
                                     logits: torch.Tensor,
                                     node_identifier: str) -> Tuple[Tree, Node]:
            # Build a small tree whose root is the task and whose children are
            # the top-k decoded classes with their activations and logits.
            tree = Tree()
            start_node = tree.create_node(task_name, node_identifier)
            for i, idx in enumerate(decoded[:self.top_k]):
                name = idx if self.class_names is None else self.class_names[idx]
                description = (f'{i}: {name} | activated: {activated[idx]:.4f}, '
                               f'logits: {logits[idx]:.4f}')
                tree.create_node(description,
                                 f'{node_identifier}.{idx}',
                                 parent=start_node)
            return tree, start_node
        return classification_explainer
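
# --- Usage sketch (illustrative, not part of the library code) ---------------
# A multi-class task pairs an N-way head with `CrossEntropyLoss`, a softmax
# activation and a top-k tree explainer. The helper name, feature size, head,
# label tensor and class names below are made-up examples.
def _example_classification_task(num_features: int = 128, num_classes: int = 10):
    head = nn.Linear(num_features, num_classes)    # one logit per class
    labels = torch.tensor([3, 7, 1])               # integer class indices in [0, N-1]
    return ClassificationTask(name='digit',
                              labels=labels,
                              module=head,
                              class_names=[str(i) for i in range(num_classes)],
                              top_k=3)
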

class MultilabelClassificationTask(Task):
    """
    Represents a multilabel classification task. Labels should be multi-hot
    binary vectors of length N, where N is the number of classes.

    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`
    """
    name: str
    labels: Sequence
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Optional[Sequence] = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=MultilabelClassificationDecoder)
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_multilabel_classification_metrics)
    class_names: Optional[Sequence[str]] = None

    def get_treelib_explainer(self) -> Callable:
        def explainer(task_name: str,
                      decoded: np.ndarray,
                      activated: np.ndarray,
                      logits: np.ndarray,
                      node_identifier: str) -> Tuple[Tree, Node]:
            # Build a small tree whose root is the task and whose children list
            # every class with its decoded value, activation and logit.
            tree = Tree()
            start_node = tree.create_node(task_name, node_identifier)
            for i, val in enumerate(decoded):
                name = i if self.class_names is None else self.class_names[i]
                description = (f'{i}: {name} | decoded: {val}, '
                               f'activated: {activated[i]:.4f}, '
                               f'logits: {logits[i]:.4f}')
                tree.create_node(description,
                                 f'{node_identifier}.{i}',
                                 parent=start_node)
            return tree, start_node
        return explainer
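
# --- Usage sketch (illustrative, not part of the library code) ---------------
# A multilabel task expects multi-hot targets: each sample has an independent
# 0/1 entry per class, so `BCEWithLogitsLoss` with a per-class sigmoid is used
# instead of softmax. The helper name, feature size, head, labels and class
# names below are made-up examples.
def _example_multilabel_task(num_features: int = 128, num_classes: int = 4):
    head = nn.Linear(num_features, num_classes)    # one logit per class
    labels = torch.tensor([[1., 0., 1., 0.],       # multi-hot targets per sample
                           [0., 0., 1., 1.]])
    return MultilabelClassificationTask(
        name='attributes',
        labels=labels,
        module=head,
        class_names=['hat', 'glasses', 'beard', 'scarf'])
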