Beispiel #1
0
class BinaryHardcodedTask(Task):
    """
    Binary task variant with no training machinery configured by default: loss, per-sample loss,
    activation and decoder all default to ``None`` and metrics default to an empty tuple.

    NOTE(review): the "Hardcoded" in the name suggests predictions are supplied rather than
    learned (which would explain the missing loss) — confirm against callers.
    """
    name: str
    labels: Any
    # No default loss; callers must supply one if this task is trained.
    loss: Optional[nn.Module] = None
    per_sample_loss: Optional[nn.Module] = None
    available_func: Callable = positive_values
    inputs: Any = None
    activation: Optional[nn.Module] = None
    decoder: Optional[Decoder] = None
    # This default is evaluated once when the class body runs, so every instance that keeps
    # the default shares the same Identity() object.
    module: nn.Module = Identity()
    # NOTE(review): annotated as a single (name, metric) pair, but the default is an (empty)
    # collection of such pairs; Sequence[Tuple[str, TorchMetric]] is likely what is meant.
    metrics: Tuple[str, TorchMetric] = ()
Beispiel #2
0
class MultilabelClassificationTask(Task):
    """
    Represents a multilabel classification task: every class is scored independently, so a sample
    may belong to any number of classes. Labels should be multi-hot targets with one value in
    [0, 1] per class and the same shape as the logits, as `nn.BCEWithLogitsLoss` requires.
    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`
    """
    name: str
    labels: Any
    # The sigmoid is folded into the loss, so the loss is applied to raw logits.
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    # Element-wise BCE (reduction='none') reduced with torch.mean;
    # presumably ReducedPerSample averages over the non-batch dims — TODO confirm.
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Any = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=MultilabelClassificationDecoder)
    # Evaluated once at class-definition time and shared by instances that keep the default.
    module: nn.Module = Identity()
    # NOTE(review): annotated as a single (name, metric) pair; newer versions of this task
    # annotate the same factory's result as Sequence[Tuple[str, TorchMetric]].
    metrics: Tuple[str, TorchMetric] = field(
        default_factory=get_default_multilabel_classification_metrics)
Beispiel #3
0
class BinaryClassificationTask(Task):
    """
    Represents a normal binary classification task. Labels should be between 0 and 1.
    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`
    """
    name: str
    labels: Any
    # The sigmoid is folded into the loss, so the loss is applied to raw logits.
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    # Element-wise BCE (reduction='none') reduced with torch.mean;
    # presumably ReducedPerSample averages over the non-batch dims — TODO confirm.
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), reduction=torch.mean)
    available_func: Callable = positive_values
    inputs: Optional[Any] = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=BinaryDecoder)
    # Evaluated once at class-definition time and shared by instances that keep the default.
    module: nn.Module = Identity()
    # NOTE(review): annotated as a single (name, metric) pair; newer task versions annotate
    # their metric factories' results as Sequence[Tuple[str, TorchMetric]].
    metrics: Tuple[str, TorchMetric] = field(
        default_factory=get_default_binary_metrics)
Beispiel #4
0
class ClassificationTask(Task):
    """
    Represents a classification task. Labels should be integers from 0 to N-1, where N is the number of classes
    * activation - `nn.Softmax(dim=-1)`
    * loss - `nn.CrossEntropyLoss()`

    Extra fields:
    * class_names - optional lookup from class index to a human-readable name, used by the explainer
    * top_k - how many decoded entries the explainer renders (``None`` renders all of them)
    """
    name: str
    labels: Sequence
    loss: nn.Module = nn.CrossEntropyLoss(reduction='mean')
    # Element-wise CE (reduction='none') reduced with torch.mean;
    # presumably ReducedPerSample reduces over the non-batch dims — TODO confirm.
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.CrossEntropyLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Optional[Sequence] = None
    # An explicit dim avoids PyTorch's deprecated implicit-dim softmax (which warns and, for
    # 3-D inputs, normalizes over axis 0). dim=-1 equals the implicit choice for 1-D and 2-D
    # inputs; it assumes class scores live on the last axis.
    activation: Optional[nn.Module] = nn.Softmax(dim=-1)
    decoder: Decoder = field(default_factory=ClassificationDecoder)
    # Evaluated once at class-definition time and shared by instances that keep the default.
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_classification_metrics)

    class_names: Optional = None
    top_k: Optional[int] = 5

    def get_treelib_explainer(self) -> Callable:
        """
        Build a function that renders this task's predictions as a treelib ``Tree``.

        The returned callable takes ``(task_name, decoded, activated, logits, node_identifier)``
        and returns ``(tree, root_node)``. ``decoded`` is read as a sequence of class indices in
        the order the decoder produced them; ``activated`` and ``logits`` are indexed by those
        class indices. At most ``self.top_k`` entries are rendered (all when ``top_k`` is None).
        """
        def classification_explainer(
                task_name: str, decoded: torch.Tensor, activated: torch.Tensor,
                logits: torch.Tensor,
                node_identifier: str) -> Tuple[Tree, Node]:
            tree = Tree()
            start_node = tree.create_node(task_name, node_identifier)
            for i, idx in enumerate(decoded[:self.top_k]):
                # Normalize tensor / numpy scalars to a plain int so names and node ids
                # read e.g. "3" instead of "tensor(3)".
                idx = int(idx)
                name = idx if self.class_names is None else self.class_names[idx]
                description = f'{i}: {name} | activated: {activated[idx]:.4f}, logits: {logits[idx]:.4f}'
                tree.create_node(description,
                                 f'{node_identifier}.{idx}',
                                 parent=start_node)
            return tree, start_node

        return classification_explainer
Beispiel #5
0
class MultilabelClassificationTask(Task):
    """
    Represents a multilabel classification task: each of the N classes is scored independently,
    so a sample may belong to any number of classes. Labels should be multi-hot targets with one
    value in [0, 1] per class and the same shape as the logits, as `nn.BCEWithLogitsLoss` requires.
    * activation - `nn.Sigmoid()`
    * loss - `nn.BCEWithLogitsLoss()`

    Extra fields:
    * class_names - optional lookup from class index to a human-readable name, used by the explainer
    """
    name: str
    labels: Sequence
    # The sigmoid is folded into the loss, so the loss is applied to raw logits.
    loss: nn.Module = nn.BCEWithLogitsLoss(reduction='mean')
    # Element-wise BCE (reduction='none') reduced with torch.mean;
    # presumably ReducedPerSample averages over the non-batch dims — TODO confirm.
    per_sample_loss: nn.Module = ReducedPerSample(
        nn.BCEWithLogitsLoss(reduction='none'), torch.mean)
    available_func: Callable = positive_values
    inputs: Optional[Sequence] = None
    activation: Optional[nn.Module] = nn.Sigmoid()
    decoder: Decoder = field(default_factory=MultilabelClassificationDecoder)
    # Evaluated once at class-definition time and shared by instances that keep the default.
    module: nn.Module = Identity()
    metrics: Sequence[Tuple[str, TorchMetric]] = field(
        default_factory=get_default_multilabel_classification_metrics)

    class_names: Optional = None

    def get_treelib_explainer(self) -> Callable:
        """
        Build a function that renders every class of this task as a treelib ``Tree``.

        The returned callable takes ``(task_name, decoded, activated, logits, node_identifier)``
        and returns ``(tree, root_node)``. ``decoded``, ``activated`` and ``logits`` are indexed
        position-by-position by class index; one child node is created per entry of ``decoded``.
        """
        def explainer(task_name: str, decoded: np.ndarray,
                      activated: np.ndarray, logits: np.ndarray,
                      node_identifier: str) -> Tuple[Tree, Node]:
            tree = Tree()
            start_node = tree.create_node(task_name, node_identifier)
            for i, val in enumerate(decoded):
                name = i if self.class_names is None else self.class_names[i]
                description = (f'{i}: {name} | decoded: {val}, '
                               f'activated: {activated[i]:.4f}, '
                               f'logits: {logits[i]:.4f}')
                tree.create_node(description,
                                 f'{node_identifier}.{i}',
                                 parent=start_node)
            return tree, start_node

        return explainer