def log(logname, dice, results="/results"):
    """Log the mean and per-class Dice scores once, to a JSON file and stdout.

    Args:
        logname: File name of the JSON log; joined onto ``results``.
        dice: Per-class Dice scores. Must support ``.mean()`` and iteration
            over elements supporting ``.item()`` (presumably a 1-D tensor).
        results: Directory the JSON log is written into.
    """
    json_path = os.path.join(results, logname)
    dllogger = Logger(backends=[
        JSONStreamBackend(Verbosity.VERBOSE, json_path),
        StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
    ])
    # Mean first, then one "L<k>" entry per class (1-based), all rounded to 2 dp.
    metrics = {"Mean dice": round(dice.mean().item(), 2)}
    for class_no, score in enumerate(dice, start=1):
        metrics[f"L{class_no}"] = round(score.item(), 2)
    dllogger.log(step=(), data=metrics)
    dllogger.flush()
def log(logname, dice, epoch=None, dice_tta=None, results=None):
    """Log Dice scores (optionally with epoch and TTA scores) to JSON and stdout.

    Args:
        logname: File name of the JSON log; joined onto the results directory.
        dice: Per-class Dice scores. Must support ``.mean()`` and iteration
            over elements supporting ``.item()`` (presumably a 1-D tensor).
        epoch: If not None, logged first under the key ``"Epoch"``.
        dice_tta: If not None, per-class Dice scores obtained with test-time
            augmentation; logged as ``"Mean TTA dice"`` and ``"TTA_L<k>"``.
        results: Directory for the JSON log. When None (the default) this
            falls back to ``args.results``, where ``args`` is a global that
            must be in scope (it is not defined in this function). Passing
            ``results`` explicitly removes that hidden dependency.
    """
    if results is None:
        # Backward-compatible fallback to the original global lookup.
        results = args.results
    dllogger = Logger(backends=[
        JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, logname)),
        StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
    ])
    # Key insertion order determines the order in which the metrics are printed.
    metrics = {}
    if epoch is not None:
        metrics["Epoch"] = epoch
    metrics["Mean dice"] = round(dice.mean().item(), 2)
    if dice_tta is not None:
        metrics["Mean TTA dice"] = round(dice_tta.mean().item(), 2)
    metrics.update({f"L{j + 1}": round(m.item(), 2) for j, m in enumerate(dice)})
    if dice_tta is not None:
        metrics.update(
            {f"TTA_L{j + 1}": round(m.item(), 2) for j, m in enumerate(dice_tta)})
    dllogger.log(step=(), data=metrics)
    dllogger.flush()
class Model(pl.LightningModule):
    """Lightning module for the localization or damage-assessment task.

    ``args.type == "pre"`` selects the localization model (``UNetLoc``,
    ``n_class = 2``); any other value selects a damage-assessment U-Net built by
    ``get_dmg_unet`` (``n_class = 5``).
    """

    def __init__(self, args):
        super(Model, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.f1_score = F1(args)
        self.model = UNetLoc(args) if args.type == "pre" else get_dmg_unet(args)
        self.loss = Loss(args)
        self.best_f1 = torch.tensor(0)
        self.best_epoch = 0
        # Dimension sets flipped for test-time augmentation; presumably the
        # spatial axes of NCHW input -- TODO confirm.
        self.tta_flips = [[2], [3], [2, 3]]
        self.lr = args.lr
        self.n_class = 2 if self.args.type == "pre" else 5
        self.softmax = nn.Softmax(dim=1)
        # Running counter that gives each saved test sample a unique file name.
        self.test_idx = 0
        self.dllogger = Logger(backends=[
            JSONStreamBackend(
                Verbosity.VERBOSE,
                os.path.join(args.results, f"{args.logname}.json")),
            StdOutBackend(Verbosity.VERBOSE,
                          step_format=lambda step: f"Epoch: {step} "),
        ])

    def forward(self, img):
        """Run the model; with ``args.tta`` also average over flipped inputs."""
        pred = self.model(img)
        if self.args.tta:
            for flip_idx in self.tta_flips:
                # Predict on the flipped input, then flip the prediction back.
                pred += self.flip(self.model(self.flip(img, flip_idx)), flip_idx)
            pred /= len(self.tta_flips) + 1
        return pred

    def training_step(self, batch, _):
        img, lbl = batch["image"], batch["mask"]
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        return loss

    def validation_step(self, batch, _):
        img, lbl = batch["image"], batch["mask"]
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        self.f1_score.update(pred, lbl)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["mask"]
        pred = self.forward(img)
        self.f1_score.update(pred, lbl)
        self.save(pred, lbl)

    def compute_loss(self, preds, label):
        """Compute the loss, weighting deep-supervision heads by 0.5**k.

        With ``args.deep_supervision`` the model output is a sequence whose
        first element is the main prediction. Each auxiliary output gets a
        label resized (nearest, the ``interpolate`` default) to its spatial shape.
        """
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = torch.nn.functional.interpolate(
                    label.unsqueeze(1), pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred, downsampled_label.squeeze(1))
            # NOTE(review): the weights 1, 1/2, ..., 1/2**(n-1) sum to
            # 2 - 2**(1 - n), not 2 - 2**(-n). This is a constant rescaling,
            # kept unchanged so training behaviour is preserved.
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    @staticmethod
    def metric_mean(name, outputs):
        """Average ``out[name]`` over a list of step-output dicts."""
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    @staticmethod
    def update_damage_scores(metrics, dmgs_f1):
        """Add per-damage-class F1 scores ``D1``..``D4`` to ``metrics`` if given."""
        if dmgs_f1 is not None:
            for i in range(4):
                metrics.update({f"D{i+1}": round(dmgs_f1[i].item(), 3)})

    def on_validation_epoch_start(self):
        self.f1_score.reset()

    def on_test_epoch_start(self):
        self.f1_score.reset()

    def validation_epoch_end(self, outputs):
        loss = self.metric_mean("val_loss", outputs)
        f1_score, dmgs_f1 = self.f1_score.compute()
        self.f1_score.reset()
        if f1_score >= self.best_f1:
            self.best_f1 = f1_score
            self.best_epoch = self.current_epoch
        # Only the local rank 0 process writes to the dllogger backends.
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            metrics = {
                "f1": round(f1_score.item(), 3),
                "val_loss": round(loss.item(), 3),
                "top_f1": round(self.best_f1.item(), 3),
            }
            self.update_damage_scores(metrics, dmgs_f1)
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()
        self.log("f1_score", f1_score.cpu())
        self.log("val_loss", loss.cpu())

    def test_epoch_end(self, _):
        f1_score, dmgs_f1 = self.f1_score.compute()
        self.f1_score.reset()
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            metrics = {"f1": round(f1_score.item(), 3)}
            self.update_damage_scores(metrics, dmgs_f1)
            self.dllogger.log(step=(), data=metrics)
            self.dllogger.flush()

    def save(self, preds, targets):
        """Save per-sample predictions (.npy) and targets (PNG) under results.

        Predictions go to ``<results>/probs/test_<task>_<idx>`` (``np.save``
        appends ``.npy``). Targets go to
        ``<results>/targets/test_<task>_<idx>_target.png``. Neither directory
        is created here, so both must already exist.
        """
        if self.args.type == "pre":
            probs = torch.sigmoid(preds[:, 1])
        elif self.args.loss_str == "coral":
            # Count thresholds exceeded, shifted so labels start at 1.
            probs = torch.sum(torch.sigmoid(preds) > 0.5, dim=1) + 1
        elif self.args.loss_str == "mse":
            # Not in-place: the caller's ``preds`` tensor must not be mutated.
            probs = torch.round(F.relu(preds[:, 0])) + 1
        else:
            probs = self.softmax(preds)
        probs = probs.cpu().detach().numpy()
        targets = targets.cpu().detach().numpy().astype(np.uint8)
        task = "localization" if self.args.type == "pre" else "damage"
        probs_dir = os.path.join(self.args.results, "probs")
        targets_dir = os.path.join(self.args.results, "targets")
        for prob, target in zip(probs, targets):
            stem = f"test_{task}_{self.test_idx:05d}"
            self.test_idx += 1
            np.save(os.path.join(probs_dir, stem), prob)
            # Build the target path from its parts: replacing "probs" in the
            # full path would also rewrite it inside ``args.results``.
            Image.fromarray(target).save(
                os.path.join(targets_dir, stem + "_target.png"))

    @staticmethod
    def flip(data, axis):
        return torch.flip(data, dims=axis)

    def configure_optimizers(self):
        """Build the optimizer selected by ``args.optimizer`` and, optionally, Noam LR.

        Optimizers are constructed lazily, so only the selected one is built.
        The apex fused optimizers need apex CUDA extensions; they are no longer
        required when another optimizer is chosen. Unknown names raise
        ``KeyError``.
        """
        weight_decay = self.args.weight_decay
        optimizers = {
            "sgd": lambda: FusedSGD(self.parameters(), lr=self.lr,
                                    momentum=self.args.momentum),
            "adam": lambda: FusedAdam(self.parameters(), lr=self.lr,
                                      weight_decay=weight_decay),
            "adamw": lambda: torch.optim.AdamW(self.parameters(), lr=self.lr,
                                               weight_decay=weight_decay),
            "radam": lambda: RAdam(self.parameters(), lr=self.lr,
                                   weight_decay=weight_decay),
            "adabelief": lambda: AdaBelief(self.parameters(), lr=self.lr,
                                           weight_decay=weight_decay),
            "adabound": lambda: AdaBound(self.parameters(), lr=self.lr,
                                         weight_decay=weight_decay),
            "adamp": lambda: AdamP(self.parameters(), lr=self.lr,
                                   weight_decay=weight_decay),
            "novograd": lambda: FusedNovoGrad(self.parameters(), lr=self.lr,
                                              weight_decay=weight_decay),
        }
        optimizer = optimizers[self.args.optimizer.lower()]()
        if not self.args.use_scheduler:
            return optimizer
        # Guard against division by zero when training without GPUs.
        n_devices = max(self.args.gpus, 1)
        scheduler = {
            "scheduler": NoamLR(
                optimizer=optimizer,
                warmup_epochs=self.args.warmup,
                total_epochs=self.args.epochs,
                steps_per_epoch=len(self.train_dataloader()) // n_devices,
                init_lr=self.args.init_lr,
                max_lr=self.args.lr,
                final_lr=self.args.final_lr,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model/optimizer options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        arg = parser.add_argument
        arg(
            "--optimizer",
            type=str,
            default="adamw",
            choices=[
                "sgd", "adam", "adamw", "radam", "adabelief", "adabound",
                "adamp", "novograd"
            ],
        )
        arg(
            "--dmg_model",
            type=str,
            default="siamese",
            choices=[
                "siamese", "siameseEnc", "fused", "fusedEnc", "parallel",
                "parallelEnc", "diff", "cat"
            ],
            help="U-Net variant for damage assessment task",
        )
        arg(
            "--encoder",
            type=str,
            default="resnest200",
            choices=[
                "resnest50", "resnest101", "resnest200", "resnest269",
                "resnet50", "resnet101", "resnet152"
            ],
            help="U-Net encoder",
        )
        arg(
            "--loss_str",
            type=str,
            default="focal+dice",
            help=
            "Combination of: dice, focal, ce, ohem, mse, coral, e.g focal+dice creates the loss function as sum of focal and dice",
        )
        arg("--use_scheduler",
            action="store_true",
            help="Enable Noam learning rate scheduler")
        arg("--warmup",
            type=int,
            default=1,
            help="Warmup epochs for Noam learning rate scheduler")
        arg("--init_lr",
            type=float,
            default=1e-4,
            help="Initial learning rate for Noam scheduler")
        arg("--final_lr",
            type=float,
            default=1e-4,
            help="Final learning rate for Noam scheduler")
        arg("--lr",
            type=float,
            default=3e-4,
            help="Learning rate, or a target learning rate for Noam scheduler")
        arg("--weight_decay",
            type=float,
            default=0,
            help="Weight decay (L2 penalty)")
        arg("--momentum",
            type=float,
            default=0.9,
            help="Momentum for SGD optimizer")
        arg(
            "--dilation",
            type=int,
            choices=[1, 2, 4],
            default=1,
            help=
            "Dilation rate for a encoder, e.g dilation=2 uses dilation instead of stride in the last encoder block",
        )
        arg("--tta", action="store_true", help="Enable test time augmentation")
        arg("--ppm", action="store_true", help="Use pyramid pooling module")
        arg("--aspp", action="store_true", help="Use atrous spatial pyramid pooling")
        arg("--no_skip", action="store_true", help="Disable skip connections in UNet")
        arg("--deep_supervision", action="store_true", help="Enable deep supervision")
        arg("--attention", action="store_true", help="Enable attention module at the decoder")
        arg("--autoaugment", action="store_true", help="Use imageNet autoaugment pipeline")
        arg("--interpolate",
            action="store_true",
            help="Interpolate feature map from encoder without a decoder")
        arg("--dec_interp",
            action="store_true",
            help=
            "Use interpolation instead of transposed convolution in a decoder")
        return parser
class NNUnet(pl.LightningModule):
    """Lightning module wrapping a U-Net whose topology comes from a dataset config.

    Kernel sizes and strides are derived from the config's patch size and voxel
    spacings (see ``get_unet_params``). ``args.dim`` selects 2D or 3D
    inference paths.
    """

    def __init__(self, args):
        super(NNUnet, self).__init__()
        self.args = args
        self.save_hyperparameters()
        self.build_nnunet()
        self.loss = Loss()
        self.dice = Dice(self.n_class)
        self.best_sum = 0
        self.eval_dice = 0
        self.best_sum_epoch = 0
        self.best_dice = self.n_class * [0]
        self.best_epoch = self.n_class * [0]
        self.best_sum_dice = self.n_class * [0]
        self.learning_rate = args.learning_rate
        if self.args.exec_mode in ["train", "evaluate"]:
            self.dllogger = Logger(backends=[
                JSONStreamBackend(Verbosity.VERBOSE,
                                  os.path.join(args.results, "logs.json")),
                StdOutBackend(Verbosity.VERBOSE,
                              step_format=lambda step: f"Epoch: {step} "),
            ])
        # Every non-empty combination of the spatial axes to flip for TTA.
        self.tta_flips = ([[2], [3], [2, 3]] if self.args.dim == 2 else
                          [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]])

    def forward(self, img):
        if self.args.benchmark:
            return self.model(img)
        return self.tta_inference(img) if self.args.tta else self.do_inference(img)

    def training_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        return loss

    def validation_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        dice = self.dice(pred, lbl[:, 0])
        return {"val_loss": loss, "val_dice": dice}

    def test_step(self, batch, batch_idx):
        if self.args.exec_mode == "evaluate":
            return self.validation_step(batch, batch_idx)
        img = batch["image"]
        pred = self.forward(img)
        if self.args.save_preds:
            self.save_mask(pred, batch["fname"])

    def build_unet(self, in_channels, n_class, kernels, strides):
        return UNet(
            in_channels=in_channels,
            n_class=n_class,
            kernels=kernels,
            strides=strides,
            normalization_layer=self.args.norm,
            negative_slope=self.args.negative_slope,
            deep_supervision=self.args.deep_supervision,
            dimension=self.args.dim,
        )

    def get_unet_params(self):
        """Derive per-level kernels and strides from the dataset config.

        Each level downsamples (stride 2) along axes whose spacing is within 2x
        of the finest spacing and whose remaining size is at least 8. Stops
        when no axis can be downsampled or after 5 levels.

        Returns:
            ``(in_channels, n_class, kernels, strides, patch_size)``.
        """
        config = get_config_file(self.args)
        in_channels = config["in_channels"]
        patch_size = config["patch_size"]
        spacings = config["spacings"]
        n_class = config["n_class"]
        strides, kernels, sizes = [], [], patch_size[:]
        while True:
            spacing_ratio = [spacing / min(spacings) for spacing in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
            if len(strides) == 5:
                break
        # The first level has unit stride; the bottleneck uses 3x..x3 kernels.
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])
        return in_channels, n_class, kernels, strides, patch_size

    def build_nnunet(self):
        in_channels, n_class, kernels, strides, self.patch_size = self.get_unet_params()
        self.model = self.build_unet(in_channels, n_class, kernels, strides)
        # Presumably excludes the background class -- TODO confirm with Dice().
        self.n_class = n_class - 1
        if is_main_process():
            print(f"Filters: {self.model.filters}")
            print(f"Kernels: {kernels}")
            print(f"Strides: {strides}")

    def compute_loss(self, preds, label):
        """Compute the loss, weighting deep-supervision heads by 0.5**k.

        Auxiliary heads are compared against labels resized (nearest, the
        ``interpolate`` default) to their spatial shape.
        """
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = nn.functional.interpolate(label, pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred, downsampled_label)
            # NOTE(review): the weights 1, 1/2, ..., 1/2**(n-1) sum to
            # 2 - 2**(1 - n), not 2 - 2**(-n). This is a constant rescaling,
            # kept unchanged so training behaviour is preserved.
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    def do_inference(self, image):
        if self.args.dim == 2:
            if self.args.exec_mode == "predict" and not self.args.benchmark:
                return self.inference2d_test(image)
            return self.inference2d(image)
        return self.sliding_window_inference(image)

    def tta_inference(self, img):
        """Average predictions over the input and all its flipped variants."""
        pred = self.do_inference(img)
        for flip_idx in self.tta_flips:
            pred += flip(self.do_inference(flip(img, flip_idx)), flip_idx)
        pred /= len(self.tta_flips) + 1
        return pred

    def inference2d(self, image):
        """Run the 2D model slice-by-slice over axis 2 of ``image`` in batches.

        Assumes a batch of one volume (``squeeze(0)``). Axis 2 is padded at the
        front to a multiple of ``val_batch_size`` (or, in benchmark mode,
        trimmed) and the padding is removed from the output.
        """
        batch_modulo = image.shape[2] % self.args.val_batch_size
        if self.args.benchmark:
            image = image[:, :, batch_modulo:]
        elif batch_modulo != 0:
            batch_pad = self.args.val_batch_size - batch_modulo
            image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
        image = torch.transpose(image.squeeze(0), 0, 1)
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
        for start in range(0, image.shape[0] - self.args.val_batch_size + 1,
                           self.args.val_batch_size):
            end = start + self.args.val_batch_size
            pred = self.model(image[start:end])
            preds[start:end] = pred.data
        if batch_modulo != 0 and not self.args.benchmark:
            preds = preds[batch_pad:]
        return torch.transpose(preds, 0, 1).unsqueeze(0)

    def inference2d_test(self, image):
        """Sliding-window 2D inference applied independently to each index of axis 2."""
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
        for depth in range(image.shape[2]):
            preds[:, :, depth] = self.sliding_window_inference(image[:, :, depth])
        return preds

    def sliding_window_inference(self, image):
        # Delegates to the module-level ``sliding_window_inference`` function.
        return sliding_window_inference(
            inputs=image,
            roi_size=self.patch_size,
            sw_batch_size=self.args.val_batch_size,
            predictor=self.model,
            overlap=self.args.overlap,
            mode=self.args.val_mode,
        )

    @staticmethod
    def metric_mean(name, outputs):
        """Average ``out[name]`` over a list of step-output dicts."""
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    def validation_epoch_end(self, outputs):
        loss = self.metric_mean("val_loss", outputs)
        dice = 100 * self.metric_mean("val_dice", outputs)
        dice_sum = torch.sum(dice)
        if dice_sum >= self.best_sum:
            self.best_sum = dice_sum
            self.best_sum_dice = dice[:]
            self.best_sum_epoch = self.current_epoch
        # Track the best score for each class independently.
        for i, dice_i in enumerate(dice):
            if dice_i > self.best_dice[i]:
                self.best_dice[i], self.best_epoch[i] = dice_i, self.current_epoch

        if is_main_process():
            metrics = {}
            metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
            metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
            metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
            metrics.update({
                f"TOP_L{i+1}": round(m.item(), 2)
                for i, m in enumerate(self.best_sum_dice)
            })
            metrics.update({"val_loss": round(loss.item(), 4)})
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()

        self.log("val_loss", loss)
        self.log("dice_sum", dice_sum)

    def test_epoch_end(self, outputs):
        if self.args.exec_mode == "evaluate":
            self.eval_dice = 100 * self.metric_mean("val_dice", outputs)

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler named in ``args``.

        Both are constructed lazily, so only the selected ones are built. The
        apex ``FusedAdam`` is no longer required unless chosen, and unused
        schedulers are never attached to the optimizer. Unknown names raise
        ``KeyError``.
        """
        optimizers = {
            "sgd": lambda: torch.optim.SGD(self.parameters(),
                                           lr=self.learning_rate,
                                           momentum=self.args.momentum),
            "adam": lambda: torch.optim.Adam(self.parameters(),
                                             lr=self.learning_rate,
                                             weight_decay=self.args.weight_decay),
            "adamw": lambda: torch.optim.AdamW(self.parameters(),
                                               lr=self.learning_rate,
                                               weight_decay=self.args.weight_decay),
            "radam": lambda: optim.RAdam(self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=self.args.weight_decay),
            "fused_adam": lambda: apex.optimizers.FusedAdam(
                self.parameters(),
                lr=self.learning_rate,
                weight_decay=self.args.weight_decay),
        }
        optimizer = optimizers[self.args.optimizer.lower()]()

        schedulers = {
            "none": lambda: None,
            "multistep": lambda: torch.optim.lr_scheduler.MultiStepLR(
                optimizer, self.args.steps, gamma=self.args.factor),
            "cosine": lambda: torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, self.args.max_epochs),
            "plateau": lambda: torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=self.args.factor, patience=self.args.lr_patience),
        }
        scheduler = schedulers[self.args.scheduler.lower()]()

        opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
        if scheduler is not None:
            opt_dict.update({"lr_scheduler": scheduler})
        return opt_dict

    def save_mask(self, pred, fname):
        """Save softmax probabilities for the first sample as ``.npy`` in ``self.save_dir``.

        ``self.save_dir`` is not assigned in this class and must be set
        externally. ``fname[0]`` is presumably a tensor of UTF-8 file-name bytes
        -- TODO confirm with the data loader.
        """
        fname = str(fname[0].cpu().detach().numpy(), "utf-8").replace("_x", "_pred")
        # as_tensor avoids torch.tensor's copy-construct warning on tensor input.
        pred = nn.functional.softmax(torch.as_tensor(pred), dim=1)
        pred = pred.squeeze().cpu().detach().numpy()
        np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
# Hyperparameters can be logged one per call...
l.log(step="PARAMETER", data={"HP1": 17}, verbosity=Verbosity.DEFAULT)
l.log(step="PARAMETER", data={"HP2": 23}, verbosity=Verbosity.DEFAULT)
# ...or several together in a single call.
l.log(step="PARAMETER", data={"HP3": 1, "HP4": 2}, verbosity=Verbosity.DEFAULT)

# Attach unit/goal/stage metadata to each metric name.
l.metadata("loss", {"unit": "nat", "GOAL": "MINIMIZE", "STAGE": "TRAIN"})
l.metadata("val.loss", {"unit": "nat", "GOAL": "MINIMIZE", "STAGE": "VAL"})
l.metadata(
    "speed",
    {"unit": "speeds/s", "format": ":.3f", "GOAL": "MAXIMIZE", "STAGE": "TRAIN"},
)

# Simulated run: 2 epochs x 10 iterations. Every third iteration also logs
# 3 validation steps, and each epoch ends with one speed entry.
for epoch in range(2):
    for it in range(10):
        l.log(
            step=(epoch, it),
            data={"loss": 130 / (1 + epoch * 10 + it)},
            verbosity=Verbosity.DEFAULT,
        )
        if it % 3:
            continue
        for vit in range(3):
            l.log(
                step=(epoch, it, vit),
                data={"val.loss": 230 / (1 + epoch * 10 + it + vit)},
                verbosity=Verbosity.DEFAULT,
            )
    l.log(step=(epoch,), data={"speed": 10}, verbosity=Verbosity.DEFAULT)
l.flush()