Example #1
import os

from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity


def get_dllogger(results):
    return Logger(
        backends=[
            JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, "logs.json")),
            StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: f"Epoch: {step} "),
        ]
    )
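
A minimal usage sketch; the "/results" path and the metric value are placeholders:

logger = get_dllogger("/results")
logger.log(step=0, data={"train_loss": 0.42}, verbosity=Verbosity.VERBOSE)
logger.flush()  # force the JSON backend to write out buffered lines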
Example #2
    def __init__(self, args):
        super(Model, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.f1_score = F1(args)
        self.model = UNetLoc(args) if args.type == "pre" else get_dmg_unet(
            args)
        self.loss = Loss(args)
        self.best_f1 = torch.tensor(0)
        self.best_epoch = 0
        self.tta_flips = [[2], [3], [2, 3]]
        self.lr = args.lr
        self.n_class = 2 if self.args.type == "pre" else 5
        self.softmax = nn.Softmax(dim=1)
        self.test_idx = 0
        self.dllogger = Logger(backends=[
            JSONStreamBackend(
                Verbosity.VERBOSE,
                os.path.join(args.results, f"{args.logname}.json")),
            StdOutBackend(Verbosity.VERBOSE,
                          step_format=lambda step: f"Epoch: {step} "),
        ])
Example #3
def setup_logger(config):
    log_path = config.get("log_path", os.getcwd())
    if is_main_process():
        backends = [
            TensorBoardBackend(verbosity=dllogger.Verbosity.VERBOSE,
                               log_dir=log_path),
            JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                              filename=os.path.join(log_path, "log.json")),
            AggregatorBackend(verbosity=dllogger.Verbosity.VERBOSE,
                              agg_dict={"loss": AverageMeter}),
            StdOutBackend(
                verbosity=dllogger.Verbosity.DEFAULT,
                step_format=empty_step_format,
                metric_format=no_string_metric_format,
                prefix_format=empty_prefix_format,
            ),
        ]

        logger = Logger(backends=backends)
    else:
        logger = Logger(backends=[])
    container_setup_info = get_framework_env_vars()
    logger.log(step="PARAMETER",
               data=container_setup_info,
               verbosity=dllogger.Verbosity.DEFAULT)

    logger.metadata("loss", {
        "unit": "nat",
        "GOAL": "MINIMIZE",
        "STAGE": "TRAIN"
    })
    logger.metadata("val_loss", {
        "unit": "nat",
        "GOAL": "MINIMIZE",
        "STAGE": "VAL"
    })
    return logger
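
A hedged usage sketch; note that TensorBoardBackend, AggregatorBackend, AverageMeter and the *_format helpers above are project-local utilities rather than part of dllogger itself, and the config dict below is a stand-in:

logger = setup_logger({"log_path": "/tmp/run0"})
logger.log(step=(0, 1), data={"loss": 1.37}, verbosity=dllogger.Verbosity.DEFAULT)
logger.flush()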
Example #4
    def __init__(self, args):
        super(NNUnet, self).__init__()
        self.args = args
        self.save_hyperparameters()
        self.build_nnunet()
        self.loss = Loss()
        self.dice = Dice(self.n_class)
        self.best_sum = 0
        self.eval_dice = 0
        self.best_sum_epoch = 0
        self.best_dice = self.n_class * [0]
        self.best_epoch = self.n_class * [0]
        self.best_sum_dice = self.n_class * [0]
        self.learning_rate = args.learning_rate
        if self.args.exec_mode in ["train", "evaluate"]:
            self.dllogger = Logger(backends=[
                JSONStreamBackend(Verbosity.VERBOSE,
                                  os.path.join(args.results, "logs.json")),
                StdOutBackend(Verbosity.VERBOSE,
                              step_format=lambda step: f"Epoch: {step} "),
            ])

        self.tta_flips = ([[2], [3], [2, 3]] if self.args.dim == 2 else
                          [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]])
Example #5
def log(logname, dice, results="/results"):
    dllogger = Logger(backends=[
        JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, logname)),
        StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
    ])
    metrics = {}
    metrics.update({"Mean dice": round(dice.mean().item(), 2)})
    metrics.update({f"L{j+1}": round(m.item(), 2) for j, m in enumerate(dice)})
    dllogger.log(step=(), data=metrics)
    dllogger.flush()
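
A usage sketch with a hypothetical two-class dice tensor (assumes torch is imported and /results exists):

dice = torch.tensor([0.82, 0.74])
log("eval_logs.json", dice)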
Example #6
def log(logname, dice, epoch=None, dice_tta=None):
    dllogger = Logger(backends=[
        JSONStreamBackend(Verbosity.VERBOSE, os.path.join(
            args.results, logname)),
        StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
    ])
    metrics = {}
    if epoch is not None:
        metrics.update({"Epoch": epoch})
    metrics.update({"Mean dice": round(dice.mean().item(), 2)})
    if dice_tta is not None:
        metrics.update({"Mean TTA dice": round(dice_tta.mean().item(), 2)})
    metrics.update({f"L{j+1}": round(m.item(), 2) for j, m in enumerate(dice)})
    if dice_tta is not None:
        metrics.update({
            f"TTA_L{j+1}": round(m.item(), 2)
            for j, m in enumerate(dice_tta)
        })
    dllogger.log(step=(), data=metrics)
    dllogger.flush()
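
Unlike Example #5, this variant reads "args" from the enclosing module scope, so it only runs where such a global is defined. A hedged sketch with placeholder tensors:

dice = torch.tensor([0.82, 0.74])
log("logs.json", dice, epoch=10, dice_tta=torch.tensor([0.84, 0.75]))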
Example #7
    def __init__(self, log_path="bert_dllog.json"):
        self.logger = Logger([
            StdOutBackend(Verbosity.DEFAULT, step_format=self.format_step),
            JSONStreamBackend(Verbosity.VERBOSE, log_path),
        ])
        self.logger.metadata("mlm_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("nsp_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("avg_loss_step", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("total_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("f1", {
            "format": ":.4f",
            "GOAL": "MAXIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("precision", {
            "format": ":.4f",
            "GOAL": "MAXIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("recall", {
            "format": ":.4f",
            "GOAL": "MAXIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("mcc", {
            "format": ":.4f",
            "GOAL": "MAXIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("exact_match", {
            "format": ":.4f",
            "GOAL": "MAXIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata(
            "throughput_train",
            {
                "unit": "seq/s",
                "format": ":.3f",
                "GOAL": "MAXIMIZE",
                "STAGE": "TRAIN"
            },
        )
        self.logger.metadata(
            "throughput_inf",
            {
                "unit": "seq/s",
                "format": ":.3f",
                "GOAL": "MAXIMIZE",
                "STAGE": "VAL"
            },
        )
Example #8
class dllogger_class():
    def format_step(self, step):
        if isinstance(step, str):
            return step
        elif isinstance(step, int):
            return "Iteration: {} ".format(step)
        elif len(step) > 0:
            return "Iteration: {} ".format(step[0])
        else:
            return ""

    def __init__(self, log_path="bert_dllog.json"):
        self.logger = Logger([
            StdOutBackend(Verbosity.DEFAULT, step_format=self.format_step),
            JSONStreamBackend(Verbosity.VERBOSE, log_path),
        ])
        self.logger.metadata("mlm_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("nsp_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("avg_loss_step", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("total_loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("loss", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "TRAIN"
        })
        self.logger.metadata("f1", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("precision", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("recall", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("mcc", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata("exact_match", {
            "format": ":.4f",
            "GOAL": "MINIMIZE",
            "STAGE": "VAL"
        })
        self.logger.metadata(
            "throughput_train",
            {
                "unit": "seq/s",
                "format": ":.3f",
                "GOAL": "MAXIMIZE",
                "STAGE": "TRAIN"
            },
        )
        self.logger.metadata(
            "throughput_inf",
            {
                "unit": "seq/s",
                "format": ":.3f",
                "GOAL": "MAXIMIZE",
                "STAGE": "VAL"
            },
        )
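
A usage sketch; the log path, step tuple and loss values are placeholders:

dllog = dllogger_class(log_path="/tmp/bert_dllog.json")
dllog.logger.log(step=(1,), data={"mlm_loss": 7.1234, "nsp_loss": 0.6931})
dllog.logger.flush()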
Example #9
class Model(pl.LightningModule):
    def __init__(self, args):
        super(Model, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.f1_score = F1(args)
        self.model = UNetLoc(args) if args.type == "pre" else get_dmg_unet(
            args)
        self.loss = Loss(args)
        self.best_f1 = torch.tensor(0)
        self.best_epoch = 0
        self.tta_flips = [[2], [3], [2, 3]]
        self.lr = args.lr
        self.n_class = 2 if self.args.type == "pre" else 5
        self.softmax = nn.Softmax(dim=1)
        self.test_idx = 0
        self.dllogger = Logger(backends=[
            JSONStreamBackend(
                Verbosity.VERBOSE,
                os.path.join(args.results, f"{args.logname}.json")),
            StdOutBackend(Verbosity.VERBOSE,
                          step_format=lambda step: f"Epoch: {step} "),
        ])

    def forward(self, img):
        pred = self.model(img)
        if self.args.tta:
            for flip_idx in self.tta_flips:
                pred += self.flip(self.model(self.flip(img, flip_idx)),
                                  flip_idx)
            pred /= len(self.tta_flips) + 1
        return pred

    def training_step(self, batch, _):
        img, lbl = batch["image"], batch["mask"]
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        return loss

    def validation_step(self, batch, _):
        img, lbl = batch["image"], batch["mask"]
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        self.f1_score.update(pred, lbl)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["mask"]
        pred = self.forward(img)
        self.f1_score.update(pred, lbl)
        self.save(pred, lbl)

    def compute_loss(self, preds, label):
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = torch.nn.functional.interpolate(
                    label.unsqueeze(1), pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred,
                                                 downsampled_label.squeeze(1))
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    @staticmethod
    def metric_mean(name, outputs):
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    @staticmethod
    def update_damage_scores(metrics, dmgs_f1):
        if dmgs_f1 is not None:
            for i in range(4):
                metrics.update({f"D{i+1}": round(dmgs_f1[i].item(), 3)})

    def on_validation_epoch_start(self):
        self.f1_score.reset()

    def on_test_epoch_start(self):
        self.f1_score.reset()

    def validation_epoch_end(self, outputs):
        loss = self.metric_mean("val_loss", outputs)
        f1_score, dmgs_f1 = self.f1_score.compute()
        self.f1_score.reset()

        if f1_score >= self.best_f1:
            self.best_f1 = f1_score
            self.best_epoch = self.current_epoch

        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            metrics = {
                "f1": round(f1_score.item(), 3),
                "val_loss": round(loss.item(), 3),
                "top_f1": round(self.best_f1.item(), 3),
            }
            self.update_damage_scores(metrics, dmgs_f1)
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()

        self.log("f1_score", f1_score.cpu())
        self.log("val_loss", loss.cpu())

    def test_epoch_end(self, _):
        f1_score, dmgs_f1 = self.f1_score.compute()
        self.f1_score.reset()
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            metrics = {"f1": round(f1_score.item(), 3)}
            self.update_damage_scores(metrics, dmgs_f1)
            self.dllogger.log(step=(), data=metrics)
            self.dllogger.flush()

    def save(self, preds, targets):
        if self.args.type == "pre":
            probs = torch.sigmoid(preds[:, 1])
        else:
            if self.args.loss_str == "coral":
                probs = torch.sum(torch.sigmoid(preds) > 0.5, dim=1) + 1
            elif self.args.loss_str == "mse":
                probs = torch.round(F.relu(preds[:, 0], inplace=True)) + 1
            else:
                probs = self.softmax(preds)

        probs = probs.cpu().detach().numpy()
        targets = targets.cpu().detach().numpy().astype(np.uint8)
        for prob, target in zip(probs, targets):
            task = "localization" if self.args.type == "pre" else "damage"
            fname = os.path.join(self.args.results, "probs",
                                 f"test_{task}_{self.test_idx:05d}")
            self.test_idx += 1
            np.save(fname, prob)
            Image.fromarray(target).save(
                fname.replace("probs", "targets") + "_target.png")

    @staticmethod
    def flip(data, axis):
        return torch.flip(data, dims=axis)

    def configure_optimizers(self):
        optimizer = {
            "sgd":
            FusedSGD(self.parameters(),
                     lr=self.lr,
                     momentum=self.args.momentum),
            "adam":
            FusedAdam(self.parameters(),
                      lr=self.lr,
                      weight_decay=self.args.weight_decay),
            "adamw":
            torch.optim.AdamW(self.parameters(),
                              lr=self.lr,
                              weight_decay=self.args.weight_decay),
            "radam":
            RAdam(self.parameters(),
                  lr=self.lr,
                  weight_decay=self.args.weight_decay),
            "adabelief":
            AdaBelief(self.parameters(),
                      lr=self.lr,
                      weight_decay=self.args.weight_decay),
            "adabound":
            AdaBound(self.parameters(),
                     lr=self.lr,
                     weight_decay=self.args.weight_decay),
            "adamp":
            AdamP(self.parameters(),
                  lr=self.lr,
                  weight_decay=self.args.weight_decay),
            "novograd":
            FusedNovoGrad(self.parameters(),
                          lr=self.lr,
                          weight_decay=self.args.weight_decay),
        }[self.args.optimizer.lower()]

        if not self.args.use_scheduler:
            return optimizer

        scheduler = {
            "scheduler":
            NoamLR(
                optimizer=optimizer,
                warmup_epochs=self.args.warmup,
                total_epochs=self.args.epochs,
                steps_per_epoch=len(self.train_dataloader()) // self.args.gpus,
                init_lr=self.args.init_lr,
                max_lr=self.args.lr,
                final_lr=self.args.final_lr,
            ),
            "interval":
            "step",
            "frequency":
            1,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        arg = parser.add_argument
        arg(
            "--optimizer",
            type=str,
            default="adamw",
            choices=[
                "sgd", "adam", "adamw", "radam", "adabelief", "adabound",
                "adamp", "novograd"
            ],
        )
        arg(
            "--dmg_model",
            type=str,
            default="siamese",
            choices=[
                "siamese", "siameseEnc", "fused", "fusedEnc", "parallel",
                "parallelEnc", "diff", "cat"
            ],
            help="U-Net variant for damage assessment task",
        )
        arg(
            "--encoder",
            type=str,
            default="resnest200",
            choices=[
                "resnest50", "resnest101", "resnest200", "resnest269",
                "resnet50", "resnet101", "resnet152"
            ],
            help="U-Net encoder",
        )
        arg(
            "--loss_str",
            type=str,
            default="focal+dice",
            help=
            "Combination of: dice, focal, ce, ohem, mse, coral; e.g. focal+dice creates the loss function as the sum of focal and dice",
        )
        arg("--use_scheduler",
            action="store_true",
            help="Enable Noam learning rate scheduler")
        arg("--warmup",
            type=int,
            default=1,
            help="Warmup epochs for Noam learning rate scheduler")
        arg("--init_lr",
            type=float,
            default=1e-4,
            help="Initial learning rate for Noam scheduler")
        arg("--final_lr",
            type=float,
            default=1e-4,
            help="Final learning rate for Noam scheduler")
        arg("--lr",
            type=float,
            default=3e-4,
            help="Learning rate, or a target learning rate for Noam scheduler")
        arg("--weight_decay",
            type=float,
            default=0,
            help="Weight decay (L2 penalty)")
        arg("--momentum",
            type=float,
            default=0.9,
            help="Momentum for SGD optimizer")
        arg(
            "--dilation",
            type=int,
            choices=[1, 2, 4],
            default=1,
            help=
            "Dilation rate for the encoder, e.g. dilation=2 uses dilation instead of stride in the last encoder block",
        )
        arg("--tta", action="store_true", help="Enable test time augmentation")
        arg("--ppm", action="store_true", help="Use pyramid pooling module")
        arg("--aspp",
            action="store_true",
            help="Use atrous spatial pyramid pooling")
        arg("--no_skip",
            action="store_true",
            help="Disable skip connections in UNet")
        arg("--deep_supervision",
            action="store_true",
            help="Enable deep supervision")
        arg("--attention",
            action="store_true",
            help="Enable attention module at the decoder")
        arg("--autoaugment",
            action="store_true",
            help="Use imageNet autoaugment pipeline")
        arg("--interpolate",
            action="store_true",
            help="Interpolate feature map from encoder without a decoder")
        arg("--dec_interp",
            action="store_true",
            help=
            "Use interpolation instead of transposed convolution in the decoder")
        return parser
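
A hedged sketch of wiring these arguments into a CLI; the base parser and flag values below are hypothetical:

parent = ArgumentParser(add_help=False)
parser = Model.add_model_specific_args(parent)
args = parser.parse_args(["--optimizer", "adamw", "--tta", "--lr", "1e-4"])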
Example #10
class NNUnet(pl.LightningModule):
    def __init__(self, args):
        super(NNUnet, self).__init__()
        self.args = args
        self.save_hyperparameters()
        self.build_nnunet()
        self.loss = Loss()
        self.dice = Dice(self.n_class)
        self.best_sum = 0
        self.eval_dice = 0
        self.best_sum_epoch = 0
        self.best_dice = self.n_class * [0]
        self.best_epoch = self.n_class * [0]
        self.best_sum_dice = self.n_class * [0]
        self.learning_rate = args.learning_rate
        if self.args.exec_mode in ["train", "evaluate"]:
            self.dllogger = Logger(backends=[
                JSONStreamBackend(Verbosity.VERBOSE,
                                  os.path.join(args.results, "logs.json")),
                StdOutBackend(Verbosity.VERBOSE,
                              step_format=lambda step: f"Epoch: {step} "),
            ])

        self.tta_flips = ([[2], [3], [2, 3]] if self.args.dim == 2 else
                          [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]])

    def forward(self, img):
        if self.args.benchmark:
            return self.model(img)
        return self.tta_inference(img) if self.args.tta else self.do_inference(
            img)

    def training_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        return loss

    def validation_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        dice = self.dice(pred, lbl[:, 0])
        return {"val_loss": loss, "val_dice": dice}

    def test_step(self, batch, batch_idx):
        if self.args.exec_mode == "evaluate":
            return self.validation_step(batch, batch_idx)
        img = batch["image"]
        pred = self.forward(img)
        if self.args.save_preds:
            self.save_mask(pred, batch["fname"])

    def build_unet(self, in_channels, n_class, kernels, strides):
        return UNet(
            in_channels=in_channels,
            n_class=n_class,
            kernels=kernels,
            strides=strides,
            normalization_layer=self.args.norm,
            negative_slope=self.args.negative_slope,
            deep_supervision=self.args.deep_supervision,
            dimension=self.args.dim,
        )

    def get_unet_params(self):
        config = get_config_file(self.args)
        in_channels = config["in_channels"]
        patch_size = config["patch_size"]
        spacings = config["spacings"]
        n_class = config["n_class"]

        strides, kernels, sizes = [], [], patch_size[:]
        while True:
            spacing_ratio = [spacing / min(spacings) for spacing in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
            if len(strides) == 5:
                break
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        return in_channels, n_class, kernels, strides, patch_size

    def build_nnunet(self):
        (in_channels, n_class, kernels, strides,
         self.patch_size) = self.get_unet_params()
        self.model = self.build_unet(in_channels, n_class, kernels, strides)
        self.n_class = n_class - 1
        if is_main_process():
            print(f"Filters: {self.model.filters}")
            print(f"Kernels: {kernels}")
            print(f"Strides: {strides}")

    def compute_loss(self, preds, label):
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = nn.functional.interpolate(
                    label, pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred, downsampled_label)
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    def do_inference(self, image):
        if self.args.dim == 2:
            if self.args.exec_mode == "predict" and not self.args.benchmark:
                return self.inference2d_test(image)
            return self.inference2d(image)

        return self.sliding_window_inference(image)

    def tta_inference(self, img):
        pred = self.do_inference(img)
        for flip_idx in self.tta_flips:
            pred += flip(self.do_inference(flip(img, flip_idx)), flip_idx)
        pred /= len(self.tta_flips) + 1
        return pred

    def inference2d(self, image):
        batch_modulo = image.shape[2] % self.args.val_batch_size
        if self.args.benchmark:
            image = image[:, :, batch_modulo:]
        elif batch_modulo != 0:
            batch_pad = self.args.val_batch_size - batch_modulo
            image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)

        image = torch.transpose(image.squeeze(0), 0, 1)
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape,
                            dtype=image.dtype,
                            device=image.device)
        for start in range(0, image.shape[0] - self.args.val_batch_size + 1,
                           self.args.val_batch_size):
            end = start + self.args.val_batch_size
            pred = self.model(image[start:end])
            preds[start:end] = pred.data

        if batch_modulo != 0 and not self.args.benchmark:
            preds = preds[batch_pad:]

        return torch.transpose(preds, 0, 1).unsqueeze(0)

    def inference2d_test(self, image):
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape,
                            dtype=image.dtype,
                            device=image.device)
        for depth in range(image.shape[2]):
            preds[:, :, depth] = self.sliding_window_inference(image[:, :,
                                                                     depth])
        return preds

    def sliding_window_inference(self, image):
        return sliding_window_inference(
            inputs=image,
            roi_size=self.patch_size,
            sw_batch_size=self.args.val_batch_size,
            predictor=self.model,
            overlap=self.args.overlap,
            mode=self.args.val_mode,
        )

    @staticmethod
    def metric_mean(name, outputs):
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    def validation_epoch_end(self, outputs):
        loss = self.metric_mean("val_loss", outputs)
        dice = 100 * self.metric_mean("val_dice", outputs)
        dice_sum = torch.sum(dice)
        if dice_sum >= self.best_sum:
            self.best_sum = dice_sum
            self.best_sum_dice = dice[:]
            self.best_sum_epoch = self.current_epoch
        for i, dice_i in enumerate(dice):
            if dice_i > self.best_dice[i]:
                self.best_dice[i] = dice_i
                self.best_epoch[i] = self.current_epoch

        if is_main_process():
            metrics = {}
            metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
            metrics.update(
                {"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
            metrics.update(
                {f"L{i+1}": round(m.item(), 2)
                 for i, m in enumerate(dice)})
            metrics.update({
                f"TOP_L{i+1}": round(m.item(), 2)
                for i, m in enumerate(self.best_sum_dice)
            })
            metrics.update({"val_loss": round(loss.item(), 4)})
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()

        self.log("val_loss", loss)
        self.log("dice_sum", dice_sum)

    def test_epoch_end(self, outputs):
        if self.args.exec_mode == "evaluate":
            self.eval_dice = 100 * self.metric_mean("val_dice", outputs)

    def configure_optimizers(self):
        optimizer = {
            "sgd":
            torch.optim.SGD(self.parameters(),
                            lr=self.learning_rate,
                            momentum=self.args.momentum),
            "adam":
            torch.optim.Adam(self.parameters(),
                             lr=self.learning_rate,
                             weight_decay=self.args.weight_decay),
            "adamw":
            torch.optim.AdamW(self.parameters(),
                              lr=self.learning_rate,
                              weight_decay=self.args.weight_decay),
            "radam":
            optim.RAdam(self.parameters(),
                        lr=self.learning_rate,
                        weight_decay=self.args.weight_decay),
            "fused_adam":
            apex.optimizers.FusedAdam(self.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=self.args.weight_decay),
        }[self.args.optimizer.lower()]

        scheduler = {
            "none":
            None,
            "multistep":
            torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 self.args.steps,
                                                 gamma=self.args.factor),
            "cosine":
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       self.args.max_epochs),
            "plateau":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.args.factor,
                patience=self.args.lr_patience),
        }[self.args.scheduler.lower()]

        opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
        if scheduler is not None:
            opt_dict.update({"lr_scheduler": scheduler})
        return opt_dict

    def save_mask(self, pred, fname):
        fname = str(fname[0].cpu().detach().numpy(),
                    "utf-8").replace("_x", "_pred")
        pred = nn.functional.softmax(torch.tensor(pred), dim=1)
        pred = pred.squeeze().cpu().detach().numpy()
        np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
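
A hedged sketch of driving this LightningModule with a PyTorch Lightning Trainer (pre-2.0 API, matching the validation_epoch_end hooks above; the trainer flags are placeholders, not taken from the example itself):

model = NNUnet(args)
trainer = pl.Trainer(gpus=args.gpus, max_epochs=args.max_epochs)
trainer.fit(model)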
Example #11
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity


def format_step(step):
    if isinstance(step, str):
        return step
    s = ""
    if len(step) > 0:
        s += "Epoch: {} ".format(step[0])
    if len(step) > 1:
        s += "Iteration: {} ".format(step[1])
    if len(step) > 2:
        s += "Validation Iteration: {} ".format(step[2])
    return s


l = Logger(
    [
        StdOutBackend(Verbosity.DEFAULT, step_format=format_step),
        JSONStreamBackend(Verbosity.VERBOSE, "tmp.json"),
    ]
)

# You can log metrics in separate calls
l.log(step="PARAMETER", data={"HP1": 17}, verbosity=Verbosity.DEFAULT)
l.log(step="PARAMETER", data={"HP2": 23}, verbosity=Verbosity.DEFAULT)
# or together
l.log(step="PARAMETER", data={"HP3": 1, "HP4": 2}, verbosity=Verbosity.DEFAULT)

l.metadata("loss", {"unit": "nat", "GOAL": "MINIMIZE", "STAGE": "TRAIN"})
l.metadata("val.loss", {"unit": "nat", "GOAL": "MINIMIZE", "STAGE": "VAL"})
l.metadata(
    "speed",
    {"unit": "speeds/s", "format": ":.3f", "GOAL": "MAXIMIZE", "STAGE": "TRAIN"},
)
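
# With the metadata registered, later log calls pick up the units and
# formats (step tuple and metric values below are placeholders):
l.log(step=(0, 10), data={"loss": 0.95, "val.loss": 1.02, "speed": 1234.567})
l.flush()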