예제 #1
0
 def validation_step(self, batch, batch_idx):
     """Evaluate one validation batch and accumulate the dice metric.

     Args:
         batch: Mapping that provides the ``"image"`` and ``"label"`` entries.
         batch_idx: Index of the batch; not used by this method.

     Returns:
         ``None`` while ``self.current_epoch`` is below
         ``args.skip_first_n_eval``; otherwise ``{"val_loss": loss}``.
     """
     # Validation is skipped entirely for the first configured epochs.
     if self.current_epoch < self.args.skip_first_n_eval:
         return None

     image = batch["image"]
     label = batch["label"]
     if self.args.hpus:
         hpu = torch.device("hpu")
         image = image.to(hpu, non_blocking=False)
         label = label.to(hpu, non_blocking=False)

     prediction = self.forward(image)
     val_loss = self.loss(prediction, label)
     # The metric receives index 0 of the label's second axis
     # (presumably the channel axis -- TODO confirm).
     self.dice.update(prediction, label[:, 0])
     # NOTE(review): mark_step is defined elsewhere; presumably it flushes
     # the lazy-mode graph -- confirm against the helper.
     mark_step(self.args.run_lazy_mode)
     return {"val_loss": val_loss}
예제 #2
0
 def get_train_data(self, batch):
     """Extract the image/label pair from a batch and prepare it for training.

     Steps, each controlled by ``self.args``:
       * ``dim == 2`` and ``data2d_dim == 3``: pass both tensors through
         ``layout_2d``.
       * ``hpus``: move both tensors to the ``"hpu"`` device (blocking copy).
       * ``channels_last``: make both tensors contiguous in ``channels_last``
         (4-D input) or ``channels_last_3d`` (5-D input) memory format and
         then call ``mark_step``.

     Args:
         batch: Mapping that provides the ``"image"`` and ``"label"`` entries.

     Returns:
         Tuple ``(img, lbl)``.
     """
     args = self.args
     img = batch["image"]
     lbl = batch["label"]

     if args.dim == 2 and args.data2d_dim == 3:
         img, lbl = layout_2d(img, lbl)

     if args.hpus:
         hpu = torch.device("hpu")
         img = img.to(hpu, non_blocking=False)
         lbl = lbl.to(hpu, non_blocking=False)

     if not args.channels_last:
         return img, lbl

     # The image rank picks the memory format; any other rank is left as is.
     memory_format = {
         4: torch.channels_last,
         5: torch.channels_last_3d,
     }.get(img.ndim)
     if memory_format is not None:
         img = img.contiguous(memory_format=memory_format)
         lbl = lbl.contiguous(memory_format=memory_format)
     mark_step(args.run_lazy_mode)
     return img, lbl
예제 #3
0
    def test_step(self, batch, batch_idx):
        print("Start test")
        if self.args.exec_mode == "evaluate":
            return self.validation_step(batch, batch_idx)
        img = batch["image"]
        if self.args.hpus:
            img = img.to(torch.device("hpu"), non_blocking=False)
        if self.args.channels_last:
            if img.ndim == 4 or self.args.dim == 2:
                img = img.contiguous(memory_format=torch.channels_last)
            elif img.ndim == 5 and self.args.dim == 3:
                img = img.contiguous(memory_format=torch.channels_last_3d)
            mark_step(self.args.run_lazy_mode)

        pred = self.forward(img)
        mark_step(self.args.run_lazy_mode)

        if self.args.save_preds:
            meta = batch["meta"][0].cpu().detach().numpy()
            original_shape = meta[2]
            min_d, max_d = meta[0, 0], meta[1, 0]
            min_h, max_h = meta[0, 1], meta[1, 1]
            min_w, max_w = meta[0, 2], meta[1, 2]

            final_pred = torch.zeros((1, pred.shape[1], *original_shape),
                                     device=img.device)
            final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred
            final_pred = nn.functional.softmax(final_pred, dim=1)
            final_pred = final_pred.squeeze(0).cpu().detach().numpy()

            if not all(original_shape == final_pred.shape[1:]):
                class_ = final_pred.shape[0]
                resized_pred = np.zeros((class_, *original_shape))
                for i in range(class_):
                    resized_pred[i] = resize(final_pred[i],
                                             original_shape,
                                             order=3,
                                             mode="edge",
                                             cval=0,
                                             clip=True,
                                             anti_aliasing=False)
                final_pred = resized_pred

            self.save_mask(final_pred)
예제 #4
0
 def inference2d(self, image):
     """Predict a volume with the 2D model, one batch of slices at a time.

     Axis 2 of ``image`` is zero-padded at the front until its length is a
     multiple of ``args.val_batch_size``. The leading axis is then squeezed
     and the slice axis is swapped to position 0 so that it serves as the
     batch axis for ``self.model``. Predictions belonging to the padded
     slices are dropped before the axes are restored.

     Args:
         image: Input tensor. Assumed to be laid out as (1, C, D, H, W)
             -- TODO confirm against the caller.

     Returns:
         Tensor of predictions with the slice axis moved back to position 1
         of the model output and a leading singleton axis re-added.
     """
     batch_size = self.args.val_batch_size
     batch_modulo = image.shape[2] % batch_size
     if batch_modulo != 0:
         batch_pad = batch_size - batch_modulo
         # Pads only the front of the third-from-last axis.
         image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
         mark_step(self.args.run_lazy_mode)
     image = torch.transpose(image.squeeze(0), 0, 1)
     starts = range(0, image.shape[0] - batch_size + 1, batch_size)

     if self.args.hpus:
         # On HPU the result is grown by concatenation rather than written
         # into a preallocated buffer.
         preds = None
         for start in starts:
             pred = self.model(image[start:start + batch_size])
             # Identity test: ``== None`` would dispatch to the tensor's
             # overloaded equality operator once ``preds`` holds a tensor.
             if preds is None:
                 preds = pred
             else:
                 preds = torch.cat((preds, pred), dim=0)
             mark_step(self.args.run_lazy_mode)
         if batch_modulo != 0:
             preds = preds[batch_pad:]
             mark_step(self.args.run_lazy_mode)
     else:
         preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
         preds = torch.zeros(preds_shape,
                             dtype=image.dtype,
                             device=image.device)
         for start in starts:
             end = start + batch_size
             preds[start:end] = self.model(image[start:end]).data
         if batch_modulo != 0:
             preds = preds[batch_pad:]
     return torch.transpose(preds, 0, 1).unsqueeze(0)
예제 #5
0
 def on_after_backward(self):
     """Call ``mark_step`` with the configured lazy-mode flag.

     NOTE(review): the name suggests the training framework invokes this
     hook after the backward pass -- confirm against the base class.
     """
     lazy_mode = self.args.run_lazy_mode
     mark_step(lazy_mode)
예제 #6
0
 def on_before_zero_grad(self, optimizer):
     """Call ``mark_step`` with the configured lazy-mode flag.

     Args:
         optimizer: Not used by this method.

     NOTE(review): the name suggests the training framework invokes this
     hook before gradients are zeroed -- confirm against the base class.
     """
     lazy_mode = self.args.run_lazy_mode
     mark_step(lazy_mode)
예제 #7
0
 def training_step(self, batch, batch_idx):
     """Compute the training loss for one batch.

     Args:
         batch: Raw batch, passed to ``get_train_data`` to obtain the
             image/label pair.
         batch_idx: Index of the batch; not used by this method.

     Returns:
         The value produced by ``compute_loss`` for the model output and
         the label.
     """
     image, label = self.get_train_data(batch)
     loss = self.compute_loss(self.model(image), label)
     # mark_step runs after the loss is built and before it is returned.
     mark_step(self.args.run_lazy_mode)
     return loss
예제 #8
0
 def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                    optimizer_closure, on_tpu, using_native_amp,
                    using_lbfgs):
     """Step the optimizer with its closure, then call ``mark_step``.

     Only ``optimizer`` and ``optimizer_closure`` are used; the remaining
     parameters are accepted but ignored by this method.
     """
     optimizer.step(closure=optimizer_closure)
     lazy_mode = self.args.run_lazy_mode
     mark_step(lazy_mode)