def __call__(
    self,
    labels: torch.Tensor,
    logits: torch.Tensor,
    shapes: Optional[torch.Tensor] = None,
):
    """Convert logits to probabilities and feed them to ``self._update``.

    The tensors are first passed through ``detach_and_move_tensors`` with
    ``device=self.device``. Probabilities are obtained with a sigmoid when
    ``self.multilabel`` is set, otherwise with a softmax over dim 1.

    When ``self.ignore_padding`` is set and the batch (dim 0 of ``logits``)
    holds more than one sample, every sample is cut individually with
    ``cut_with_padding`` (presumably removing the padding described by its
    entry in ``shapes`` -- confirm against that helper) and ``self._update``
    is called once per sample. Otherwise the whole batch is forwarded in a
    single ``self._update`` call.

    Args:
        labels: Target tensor, indexed along dim 0 per sample.
        logits: Raw model outputs, indexed along dim 0 per sample.
        shapes: Per-sample shapes; required on the per-sample path.

    Raises:
        ValueError: If padding has to be ignored but ``shapes`` is missing
            or its first dimension differs from that of ``logits``.
    """
    labels, logits, shapes = detach_and_move_tensors(
        labels,
        logits,
        shapes,
        device=self.device,
        non_blocking=True,
    )

    if self.multilabel:
        probas = torch.sigmoid(logits)
    else:
        probas = torch.softmax(logits, dim=1)

    batch_size = logits.shape[0]
    if not (self.ignore_padding and batch_size > 1):
        # Nothing to cut: forward the whole batch at once.
        self._update(labels, probas)
        return

    if shapes is None or shapes.shape[0] != batch_size:
        # The original literals were concatenated without a separating
        # space ("...should be setand have..."); fixed here.
        raise ValueError(
            "Ignoring padding, thus shapes should be set "
            "and have the same size as the input"
        )

    for idx in range(batch_size):
        shape = shapes[idx]
        # Re-add the leading dim so _update receives the same rank as on
        # the whole-batch path.
        sub_labels = (
            cut_with_padding(labels[idx], shape, self.margin)
            .unsqueeze(0)
            .contiguous()
        )
        # Fixed typo: `.unsqeeze(0)` raised AttributeError on this path.
        sub_probas = (
            cut_with_padding(probas[idx], shape, self.margin)
            .unsqueeze(0)
            .contiguous()
        )
        self._update(sub_labels, sub_probas)
def compute_with_shapes(input_tensor, shapes, reduce=torch.mean, margin=0):
    """Reduce each sample after cutting it, and accumulate the results.

    For every entry of ``shapes`` the matching item of ``input_tensor`` is
    passed through ``cut_with_padding`` and then through ``reduce``; the
    reduced values are added in place to a scalar accumulator created on
    ``input_tensor``'s device. ``reduce`` is finally applied once more to
    that accumulator and the result is returned.

    NOTE(review): the last ``reduce`` call only ever sees a single scalar,
    so with ``torch.mean`` the result is the *sum* of the per-sample means,
    not their mean -- confirm this is the intended scaling.

    Args:
        input_tensor: Tensor indexed along dim 0, one item per shape.
        shapes: Per-sample shapes, iterated along dim 0.
        reduce: Callable applied to each cut sample and to the accumulator.
        margin: Forwarded unchanged to ``cut_with_padding``.

    Returns:
        The value of ``reduce`` applied to the accumulated scalar.
    """
    accumulator = torch.tensor(0.0, device=input_tensor.device)
    for idx, shape in enumerate(shapes):
        cropped = cut_with_padding(input_tensor[idx], shape, margin)
        # In-place add keeps the accumulator's dtype and scalar shape.
        accumulator += reduce(cropped)
    return reduce(accumulator)
def log_masks(
    self,
    masks: torch.Tensor,
    iteration: int,
    prefix: str,
    shapes: Optional[torch.Tensor] = None,
):
    """Log every mask of a batch through ``self.log_mask``.

    A batch holding exactly one mask is logged under ``prefix`` as-is,
    without any cutting. Otherwise each mask is logged under
    ``self._join(prefix, str(idx))`` and, when ``self.ignore_padding`` is
    set, is first passed through ``cut_with_padding`` with its entry of
    ``shapes`` and ``self.margin``.

    Args:
        masks: Masks indexed along dim 0.
        iteration: Step value forwarded to ``self.log_mask``.
        prefix: Name (or name prefix) under which the masks are logged.
        shapes: Per-mask shapes; required when padding is ignored and more
            than one mask is logged.

    Raises:
        ValueError: If padding has to be ignored for several masks but
            ``shapes`` is None (previously an opaque ``TypeError``).
    """
    if len(masks) == 1:
        self.log_mask(masks[0], iteration, prefix)
        return

    num_masks = masks.shape[0]
    # An empty batch stays a no-op, as before, even without shapes.
    crop = self.ignore_padding and num_masks > 0
    if crop and shapes is None:
        raise ValueError("Ignoring padding, thus shapes should be set")

    for idx in range(num_masks):
        mask = masks[idx]
        if crop:
            mask = cut_with_padding(mask, shapes[idx], margin=self.margin)
        self.log_mask(mask, iteration, self._join(prefix, str(idx)))
def log_images(
    self,
    images: torch.Tensor,
    iteration: int,
    prefix: str,
    shapes: Optional[torch.Tensor] = None,
):
    """Log every image of a batch through ``self.log_image``.

    A batch holding exactly one image is logged under ``prefix`` as-is,
    without any cutting. Otherwise each image is logged under
    ``self._join(prefix, str(idx))`` and, when ``self.ignore_padding`` is
    set, is first passed through ``cut_with_padding`` with its entry of
    ``shapes`` and ``self.margin``.

    Args:
        images: Images indexed along dim 0.
        iteration: Step value forwarded to ``self.log_image``.
        prefix: Name (or name prefix) under which the images are logged.
        shapes: Per-image shapes; required when padding is ignored and more
            than one image is logged.

    Raises:
        ValueError: If padding has to be ignored for several images but
            ``shapes`` is None (previously an opaque ``TypeError``).
    """
    if len(images) == 1:
        self.log_image(images[0], iteration, prefix)
        return

    num_images = images.shape[0]
    # An empty batch stays a no-op, as before, even without shapes.
    crop = self.ignore_padding and num_images > 0
    if crop and shapes is None:
        raise ValueError("Ignoring padding, thus shapes should be set")

    for idx in range(num_images):
        image = images[idx]
        if crop:
            image = cut_with_padding(image, shapes[idx], margin=self.margin)
        self.log_image(image, iteration, self._join(prefix, str(idx)))
def log_batch(self, batch: Dict[str, torch.Tensor], iteration: int, prefix: str = ""):
    """Log a random subset of a batch as ``wandb`` images with mask overlays.

    At most ``self.max_images_to_log`` samples are picked at random (kept
    in their batch order). Each selected ``batch["input"]`` item is
    permuted with ``(1, 2, 0)``, scaled by 255 and cast to ``uint8``.
    When present, ``batch["logits"]`` yields a ``"prediction"`` mask
    (thresholded sigmoid in the multilabel case, argmax over dim 1
    otherwise) and ``batch["target"]`` yields a ``"ground_truth"`` mask.
    When ``self.ignore_padding`` is set, every image and mask is first cut
    with ``cut_with_padding`` using its entry of ``batch["shapes"]``.

    Args:
        batch: Mapping with an ``"input"`` entry and optional ``"logits"``,
            ``"target"`` and ``"shapes"`` entries.
        iteration: Step passed to ``self.wandb.log``.
        prefix: Prefix joined with ``"images"`` to build the logging key.

    Raises:
        ValueError: If padding has to be ignored but the batch carries no
            ``"shapes"`` entry.
    """
    batch_size = len(batch["input"])
    # Random subset, re-sorted so the logged samples keep their batch order.
    indices_sel, _ = torch.sort(
        torch.randperm(batch_size)[: self.max_images_to_log]
    )
    num_selected = len(indices_sel)

    shapes = batch["shapes"][indices_sel] if "shapes" in batch else None
    # Shapes are only needed when cutting; previously `shapes[idx]` was
    # evaluated unconditionally and crashed whenever "shapes" was absent.
    crop = bool(self.ignore_padding) and num_selected > 0
    if crop and shapes is None:
        raise ValueError("Ignoring padding, thus shapes should be set")

    images = batch["input"][indices_sel]

    preds = None
    if "logits" in batch:
        logits = batch["logits"][indices_sel]
        if self.color_labels.multilabel:
            one_hot = (torch.sigmoid(logits) > self.probas_threshold).to(
                torch.long
            )
            preds = self._one_hot_to_indices(one_hot)
        else:
            preds = logits.argmax(dim=1)

    gts = None
    if "target" in batch:
        gts = batch["target"][indices_sel]
        if self.color_labels.multilabel:
            gts = self._one_hot_to_indices(gts)

    # The id -> name mapping is identical for every mask: build it once,
    # and only when at least one mask will actually be logged.
    class_labels = None
    if preds is not None or gts is not None:
        label_names = self.color_labels.log_labels or [
            str(x) for x in range(len(self.color_labels.colors))
        ]
        class_labels = dict(enumerate(label_names))

    def cut(tensor, idx):
        """Cut one sample with its shape when padding is ignored."""
        if crop:
            return cut_with_padding(tensor, shapes[idx], self.margin)
        return tensor

    def to_uint8(tensor):
        """Convert a tensor to a uint8 numpy array."""
        # detach()/cpu() make this work for CUDA or grad-tracking tensors;
        # both are no-ops for plain CPU tensors.
        return tensor.detach().cpu().numpy().astype(np.uint8)

    all_images = []
    for idx in range(num_selected):
        image = to_uint8(cut(images[idx], idx).permute(1, 2, 0) * 255)

        mask_dict = {}
        for key, source in (("prediction", preds), ("ground_truth", gts)):
            if source is None:
                continue
            mask_dict[key] = {
                "mask_data": to_uint8(cut(source[idx], idx)),
                # Independent copy per mask, as in the original code.
                "class_labels": dict(class_labels),
            }

        all_images.append(self.wandb.Image(image, masks=mask_dict or None))

    self.wandb.log(
        {self._join(prefix, "images"): all_images},
        commit=False,
        step=iteration,
    )