Ejemplo n.º 1
0
def on_batch_end(self, last_output: Tensor, last_target: Tensor, **kwargs):
    """Fold one batch's predictions into the running confusion matrix.

    The predicted class is the argmax over dim 1 of ``last_output``; both the
    predictions and ``last_target`` are flattened to 1-D and moved to the CPU.
    On the first call (``self.n_classes == 0``) the class count is taken from
    ``last_output.shape[1]`` and ``self.x`` is set to ``arange(n_classes)``.

    ``self.cm`` holds float32 counts where ``cm[t, p]`` is the number of
    elements whose target is ``t`` and whose prediction is ``p``. It is created
    from the first batch and updated in place on later batches.
    Extra ``kwargs`` are accepted and ignored.
    """
    flat_pred = last_output.argmax(1).view(-1).cpu()
    flat_true = last_target.view(-1).cpu()
    if self.n_classes == 0:
        # Lazily discover the number of classes from the output width.
        self.n_classes = last_output.shape[1]
        self.x = torch.arange(0, self.n_classes)
    labels = self.x
    # (C, N) prediction one-hot & (C, 1, N) target one-hot broadcast to
    # (C_true, C_pred, N); summing over N yields the per-pair counts.
    hits = (flat_pred == labels[:, None]) & (flat_true == labels[:, None, None])
    batch_cm = hits.sum(dim=2, dtype=torch.float32)
    if self.cm is not None:
        self.cm += batch_cm
    else:
        self.cm = batch_cm
 def __init__(self, px:Tensor):
     """Wrap the raw image tensor ``px``.

     ``px`` is converted with ``.type(torch.FloatTensor)`` (CPU float32) and
     stored as ``_px``; the logit, flow and affine-matrix slots start as
     ``None`` and ``sample_kwargs`` starts as a fresh empty dict.
     """
     self._px = px.type(torch.FloatTensor)
     # Derived/cached state; all unset until populated elsewhere.
     self._logit_px, self._flow, self._affine_mat = None, None, None
     self.sample_kwargs = dict()
Ejemplo n.º 3
0
    def split_parts_pred(self, t: fv.Tensor):
        """
        Break a parts prediction into per-object chunks along dim 1.

        Args:
            t: Tensor of shape (bs, n_parts, h, w).

        Returns: Tuple with one tensor per entry of ``self.sections``; chunk i
            has shape (bs, n_parts_i, h, w), where n_parts_i = self.sections[i].
        """
        chunk_sizes = self.sections
        pieces = t.split(chunk_sizes, dim=1)
        return pieces
def dice(
    best_thr0,
    noise_th,
    input: Tensor,
    targs: Tensor,
    iou: bool = False,
    eps: float = 1e-8,
) -> Rank0Tensor:
    """Mean Dice (or IoU) score of thresholded class-1 predictions.

    Args:
        best_thr0: probability threshold; channel 1 of ``softmax(input, dim=1)``
            strictly above it counts as a positive pixel.
        noise_th: a sample whose binarized prediction has fewer than this many
            positive pixels is cleared to all-zero before scoring.
        input: raw logits; softmax is taken over dim 1 and index 1 of that dim
            is kept, so this assumes a (n, n_classes, ...) layout with the
            foreground at class 1 -- TODO confirm against callers.
        targs: ground truth with ``n`` samples on dim 0; flattened per sample
            and multiplied elementwise with the 0/1 prediction, so presumably
            a 0/1 mask.
        iou: if True return IoU (Jaccard) instead of Dice.
        eps: smoothing added to numerator and denominator; a sample where both
            prediction and target are empty scores 1.

    Returns:
        0-dim tensor: the per-sample score averaged over the batch.
    """
    n = targs.shape[0]
    # Class-1 probability per sample, flattened to (n, n_pixels).
    probs = torch.softmax(input, dim=1)[:, 1, ...].reshape(n, -1)
    pred = (probs > best_thr0).long()
    # Clear samples with too few positive pixels (noise suppression).
    pred[pred.sum(-1) < noise_th] = 0
    # reshape (not view) so non-contiguous targets are accepted as well.
    targs = targs.reshape(n, -1)
    intersect = (pred * targs).sum(-1).float()
    # |A| + |B| (not the set union); IoU subtracts the intersection below.
    union = (pred + targs).sum(-1).float()
    if iou:
        return ((intersect + eps) / (union - intersect + eps)).mean()
    return ((2.0 * intersect + eps) / (union + eps)).mean()
Ejemplo n.º 5
0
def get_learn(data, model, name, weighted, cut):
    """Create a U-Net learner, optionally with a class-weighted loss, in fp16.

    Args:
        data: passed straight through to ``unet_learner``.
        model: passed straight through to ``unet_learner``.
        name: forwarded as ``model_dir`` (with ``path="models"``).
        weighted: when truthy, the loss is weighted by ``get_loss_weights``.
        cut: forwarded as ``cut=`` to ``unet_learner``.

    Returns:
        The learner returned by ``learn.to_fp16()``.
    """
    metric_fns = get_metrics()
    learner_opts = dict(
        split_on=_resnet_split,
        cut=cut,
        metrics=metric_fns,
        path="models",
        model_dir=name,
        wd=1e-2,
    )
    learn = unet_learner(data, model, **learner_opts)

    # Target label 0 is ignored by the loss in both branches.
    loss_opts = {"ignore_index": 0}
    if weighted:
        # NOTE(review): .cuda() requires a CUDA device to be available.
        loss_opts["weight"] = Tensor(get_loss_weights(data, learn)).cuda()
    learn.loss_fn = CrossEntropyFlat(**loss_opts)

    return learn.to_fp16()
Ejemplo n.º 6
0
    def split_parts_gt(self,
                       obj: fv.Tensor,
                       part: fv.Tensor,
                       mark_in_obj=True):
        """
        Splits parts ground truth by object.

        Args:
            obj: Tensor of shape (bs, h, w) with object ids.
            part: Tensor of shape (bs, h, w) with part ids.
            mark_in_obj: bool. If True, pixels start as -1; pixels of an
                object that has at least one labelled part pixel in that
                sample, but are not themselves a part, are set to 0.
                If False, everything that is not a part pixel is 0.

        Returns: Tensor of shape (n_obj_with_parts, bs, h, w). Slice
            ``self.obj2idx[o]`` holds labels 1..k for the parts listed in
            ``self.tree[o][1:]`` (label = position in that list).
        """
        candidate_ids = obj.unique().cpu().tolist()
        objs_here = [oid for oid in candidate_ids if oid in self.tree]

        class_ids = torch.tensor(list(self.obj_with_parts), device=obj.device)
        # One boolean mask per object-with-parts: (n_obj_with_parts, bs, h, w).
        per_obj_mask = obj == class_ids[:, None, None, None]
        # Part ids restricted to each object's pixels (0 elsewhere).
        masked_parts = part[None] * per_obj_mask
        fill_value = -1 if mark_in_obj else 0
        gt = torch.full_like(masked_parts, fill_value)

        for oid in objs_here:
            part_ids = self.tree[oid]
            slot = self.obj2idx[oid]
            region = masked_parts[slot]
            target = gt[slot]  # view into gt; writes land in gt
            for new_label, part_id in enumerate(part_ids[1:], start=1):
                target[region == part_id] = new_label

        if mark_in_obj:
            labelled = gt > 0
            # (n_obj_with_parts, bs, 1, 1): does this object have any part pixel?
            obj_has_any_part = labelled.flatten(start_dim=2).any(
                dim=2, keepdim=True)[..., None]
            gt[obj_has_any_part * per_obj_mask * (~labelled)] = 0
        return gt
Ejemplo n.º 7
0
def niiimage2np(image: fvision.Tensor) -> np.ndarray:
    """Convert a torch-style 4-D volume to a numpy array with the first axis last.

    The tensor is moved to the CPU and permuted with ``(1, 2, 3, 0)``, so axis 0
    (presumably channels) becomes the last axis, e.g. (C, H, W, D) ->
    (H, W, D, C). If that last axis has size 1 it is dropped.
    """
    res = image.cpu().permute(1, 2, 3, 0).numpy()
    # After the permute the moved axis is the LAST one, so test that axis.
    # (Checking res.shape[2] inspected a spatial axis for 4-D input and could
    # drop real channels or fail to squeeze a singleton one.)
    return res[..., 0] if res.shape[-1] == 1 else res
Ejemplo n.º 8
0
 def forward(self, input: Tensor, target: Tensor) -> Rank0Tensor:
     """Flatten ``input`` and ``target`` to 1-D and delegate to the parent ``forward``.

     ``reshape`` is used rather than ``view`` so non-contiguous tensors (e.g. the
     result of a ``permute``/``transpose``) are accepted instead of raising; for
     contiguous tensors ``reshape`` still returns a view, so results are unchanged.
     """
     return super().forward(input.reshape(-1), target.reshape(-1))
Ejemplo n.º 9
0
 def analyze_pred(self, pred: Tensor):
     """Post-process a single (unbatched) prediction.

     Adds a leading batch axis, runs ``output_to_scaled_pred``, takes the
     first result, clamps it to [-1, 1] in place, and appends a trailing
     all-ones entry along the last dimension (named "visibility" upstream).
     """
     scaled = output_to_scaled_pred(pred[None])[0]
     scaled.clamp_(-1, 1)
     ones = scaled.new_ones(scaled.shape[:-1]).unsqueeze(-1)
     return torch.cat([scaled, ones], dim=-1)