Exemplo n.º 1
0
    def __call__(
        self,
        labels: torch.Tensor,
        logits: torch.Tensor,
        shapes: Optional[torch.Tensor] = None,
    ):
        """Convert logits to probabilities and feed them to ``self._update``.

        All inputs are first passed through ``detach_and_move_tensors``
        (targeting ``self.device``). Probabilities are ``sigmoid(logits)``
        when ``self.multilabel`` is set, otherwise ``softmax`` over dim 1.

        When ``self.ignore_padding`` is set and the batch holds more than one
        sample, each sample's labels and probabilities are cropped with
        ``cut_with_padding(..., shapes[idx], self.margin)``. Each crop gets a
        leading dimension of size 1 and is passed to ``self._update`` one
        sample at a time. Otherwise the whole batch goes to
        ``self._update`` in a single call.

        Args:
            labels: Ground-truth tensor, indexed per sample along dim 0.
            logits: Raw model outputs, indexed per sample along dim 0.
            shapes: Per-sample shapes used for cropping; required when
                padding is ignored for a multi-sample batch.

        Raises:
            ValueError: if padding must be ignored but ``shapes`` is missing
                or its first dimension differs from the batch size.
        """
        labels, logits, shapes = detach_and_move_tensors(
            labels, logits, shapes, device=self.device, non_blocking=True,
        )
        if self.multilabel:
            probas = torch.sigmoid(logits)
        else:
            probas = torch.softmax(logits, dim=1)

        # Single-sample batches (or padding not ignored) are forwarded whole.
        if not (self.ignore_padding and logits.shape[0] > 1):
            self._update(labels, probas)
            return

        if shapes is None or shapes.shape[0] != logits.shape[0]:
            raise ValueError(
                "Ignoring padding, thus shapes should be set "
                "and have the same size as the input"
            )
        for idx in range(logits.shape[0]):
            shape = shapes[idx]
            # Re-add a batch dimension of 1 so _update receives batched input.
            sub_labels = (
                cut_with_padding(labels[idx], shape, self.margin)
                .unsqueeze(0)
                .contiguous()
            )
            sub_probas = (
                cut_with_padding(probas[idx], shape, self.margin)
                .unsqueeze(0)
                .contiguous()
            )
            self._update(sub_labels, sub_probas)
Exemplo n.º 2
0
def compute_with_shapes(input_tensor, shapes, reduce=torch.mean, margin=0):
    """Reduce each sample's un-padded region, then reduce across samples.

    For every ``idx`` in ``range(shapes.shape[0])`` the sample
    ``input_tensor[idx]`` is cropped with
    ``cut_with_padding(input_tensor[idx], shapes[idx], margin)`` and reduced
    with ``reduce``. The per-sample results are then stacked and reduced
    again with the same ``reduce``. With the default ``torch.mean`` this is
    the mean of the per-sample means.

    The previous implementation always summed the per-sample results, which
    only matched the intent for ``reduce=torch.sum``.

    Args:
        input_tensor: Batched tensor, indexed per sample along dim 0.
        shapes: Per-sample shapes; its first dimension sets how many samples
            are used.
        reduce: Callable mapping a tensor to a 0-dim tensor (e.g.
            ``torch.mean``, ``torch.sum``, ``torch.max``).
        margin: Forwarded to ``cut_with_padding``.

    Returns:
        A 0-dim tensor. If ``shapes`` is empty, this is 0 on
        ``input_tensor.device``.
    """
    per_sample = [
        reduce(cut_with_padding(input_tensor[idx], shapes[idx], margin))
        for idx in range(shapes.shape[0])
    ]
    if not per_sample:
        # Nothing to reduce; keep the historical zero result.
        return torch.tensor(0.0, device=input_tensor.device)
    return reduce(torch.stack(per_sample))
Exemplo n.º 3
0
 def log_masks(
     self,
     masks: torch.Tensor,
     iteration: int,
     prefix: str,
     shapes: Optional[torch.Tensor] = None,
 ):
     """Log every mask of a batch through ``self.log_mask``.

     A batch holding exactly one mask is logged under ``prefix`` as-is (no
     padding is cut). Otherwise each mask is logged under
     ``self._join(prefix, str(idx))``. When ``self.ignore_padding`` is set,
     each mask is first cropped with ``cut_with_padding`` using
     ``shapes[idx]`` and ``self.margin``.

     Raises:
         ValueError: if padding must be ignored for a multi-mask batch but
             ``shapes`` is None.
     """
     if len(masks) == 1:
         self.log_mask(masks[0], iteration, prefix)
         return
     if self.ignore_padding and shapes is None and len(masks) > 0:
         # Fail with a clear message instead of a TypeError on shapes[idx].
         raise ValueError("Ignoring padding, thus shapes should be set")
     for idx, mask in enumerate(masks):
         if self.ignore_padding:
             mask = cut_with_padding(mask, shapes[idx], margin=self.margin)
         self.log_mask(mask, iteration, self._join(prefix, str(idx)))
Exemplo n.º 4
0
 def log_images(
     self,
     images: torch.Tensor,
     iteration: int,
     prefix: str,
     shapes: Optional[torch.Tensor] = None,
 ):
     """Log every image of a batch through ``self.log_image``.

     A batch holding exactly one image is logged under ``prefix`` as-is (no
     padding is cut). Otherwise each image is logged under
     ``self._join(prefix, str(idx))``. When ``self.ignore_padding`` is set,
     each image is first cropped with ``cut_with_padding`` using
     ``shapes[idx]`` and ``self.margin``.

     Raises:
         ValueError: if padding must be ignored for a multi-image batch but
             ``shapes`` is None.
     """
     if len(images) == 1:
         self.log_image(images[0], iteration, prefix)
         return
     if self.ignore_padding and shapes is None and len(images) > 0:
         # Fail with a clear message instead of a TypeError on shapes[idx].
         raise ValueError("Ignoring padding, thus shapes should be set")
     for idx, image in enumerate(images):
         if self.ignore_padding:
             image = cut_with_padding(image,
                                      shapes[idx],
                                      margin=self.margin)
         self.log_image(image, iteration, self._join(prefix, str(idx)))
Exemplo n.º 5
0
    def log_batch(self,
                  batch: Dict[str, torch.Tensor],
                  iteration: int,
                  prefix: str = ""):
        """Log a random subset of a batch to wandb as images with masks.

        Up to ``self.max_images_to_log`` sample indices are drawn with
        ``torch.randperm`` and kept in ascending order. Each selected input
        image is scaled by 255 and logged as ``uint8``.

        Masks are attached to each image when the matching keys are present:
          * ``"prediction"``, from ``batch["logits"]``: thresholded sigmoid
            (multilabel) or argmax over dim 1.
          * ``"ground_truth"``, from ``batch["target"]``.

        With ``self.ignore_padding`` set, images and masks are cropped with
        ``cut_with_padding`` using ``batch["shapes"]`` and ``self.margin``.
        All images go in a single ``self.wandb.log`` call under
        ``self._join(prefix, "images")``, with ``commit=False`` and
        ``step=iteration``.

        Raises:
            ValueError: if ``self.ignore_padding`` is set but ``batch`` has no
                ``"shapes"`` entry.
        """
        batch_size = len(batch["input"])
        indices_sel, _ = torch.sort(
            torch.randperm(batch_size)[:self.max_images_to_log])

        shapes = batch["shapes"][indices_sel] if "shapes" in batch else None
        if self.ignore_padding and shapes is None and len(indices_sel) > 0:
            raise ValueError(
                "Ignoring padding, thus batch should contain 'shapes'")
        images = batch["input"][indices_sel]
        preds = None
        gts = None

        if "logits" in batch:
            logits = batch["logits"][indices_sel]

            if self.color_labels.multilabel:
                one_hot = (torch.sigmoid(logits) > self.probas_threshold).to(
                    torch.long)
                preds = self._one_hot_to_indices(one_hot)
            else:
                preds = logits.argmax(dim=1)

        if "target" in batch:
            gts = batch["target"][indices_sel]

            if self.color_labels.multilabel:
                gts = self._one_hot_to_indices(gts)

        # Every mask shares one class-id -> name mapping: configured label
        # names when available, otherwise the stringified class index.
        if self.color_labels.log_labels:
            class_labels = dict(enumerate(self.color_labels.log_labels))
        else:
            class_labels = {
                class_id: str(class_id)
                for class_id in range(len(self.color_labels.colors))
            }

        all_images = []
        for idx in range(len(indices_sel)):
            # shapes is only needed (and only guaranteed) when cropping.
            shape = shapes[idx] if shapes is not None else None
            image = images[idx]
            if self.ignore_padding:
                image = cut_with_padding(image, shape, self.margin)
            # detach().cpu() so tensors that require grad or live on an
            # accelerator can be converted to numpy.
            image = ((image.permute(1, 2, 0) * 255).detach().cpu().numpy()
                     .astype(np.uint8))

            mask_dict = {}
            if preds is not None:
                pred = preds[idx]
                if self.ignore_padding:
                    pred = cut_with_padding(pred, shape, self.margin)
                mask_dict["prediction"] = {
                    "mask_data": pred.detach().cpu().numpy().astype(np.uint8),
                    "class_labels": class_labels,
                }

            if gts is not None:
                gt = gts[idx]
                if self.ignore_padding:
                    gt = cut_with_padding(gt, shape, self.margin)
                mask_dict["ground_truth"] = {
                    "mask_data": gt.detach().cpu().numpy().astype(np.uint8),
                    "class_labels": class_labels,
                }

            all_images.append(
                self.wandb.Image(image, masks=mask_dict or None))
        self.wandb.log({self._join(prefix, "images"): all_images},
                       commit=False,
                       step=iteration)