Example 1

# Imports assumed for this example; GeoTiffPredictionMixin and
# FocalTverskyMetric are project-specific helpers from the surrounding code.
from typing import Any

import pytorch_lightning as pl
import torch
from segmentation_models_pytorch import Unet
from torchmetrics import Accuracy, JaccardIndex, Precision, Recall

class UnetEfficientnet(GeoTiffPredictionMixin, pl.LightningModule):
    def __init__(self, hparams):
        """hparams must be a dict of {weight_decay, lr, num_classes}"""
        super().__init__()
        self.save_hyperparameters(hparams)

        # Create U-Net model with an ImageNet-pretrained EfficientNet-B4 encoder
        self.model = Unet(
            encoder_name="efficientnet-b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=self.hparams.num_classes,
        )
        # Train the decoder and segmentation head; freeze the pretrained encoder
        self.model.requires_grad_(True)
        self.model.encoder.requires_grad_(False)

        # Loss function and metrics
        self.focal_tversky_loss = FocalTverskyMetric(
            self.hparams.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=self.hparams.get("ignore_index"),
        )
        self.accuracy_metric = Accuracy(
            ignore_index=self.hparams.get("ignore_index"))
        self.iou_metric = JaccardIndex(
            num_classes=self.hparams.num_classes,
            reduction="none",
            ignore_index=self.hparams.get("ignore_index"),
        )
        self.precision_metric = Precision(
            num_classes=self.hparams.num_classes,
            ignore_index=self.hparams.get("ignore_index"),
            average='weighted',
            mdmc_average='global')
        self.recall_metric = Recall(
            num_classes=self.hparams.num_classes,
            ignore_index=self.hparams.get("ignore_index"),
            average='weighted',
            mdmc_average='global')

    @property
    def example_input_array(self) -> Any:
        return torch.rand(2, 3, 512, 512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        """Init optimizer and scheduler"""
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}]

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        probs = torch.softmax(logits, dim=1)
        loss = self.focal_tversky_loss(probs, y)

        preds = logits.argmax(dim=1)
        ious = self.iou_metric(preds, y)
        acc = self.accuracy_metric(preds, y)

        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        self.log("train_miou", ious.mean(), on_epoch=True, sync_dist=True)
        self.log("train_accuracy", acc, on_epoch=True, sync_dist=True)
        for c in range(len(ious)):
            self.log(f"train_c{c}_iou", ious[c], on_epoch=True, sync_dist=True)

        return loss

    def val_test_step(self, batch, batch_idx, phase="val"):
        x, y = batch
        logits = self.model(x)
        probs = torch.softmax(logits, dim=1)
        loss = self.focal_tversky_loss(probs, y)

        preds = logits.argmax(dim=1)
        ious = self.iou_metric(preds, y)
        miou = ious.mean()
        acc = self.accuracy_metric(preds, y)
        precision = self.precision_metric(preds, y)
        recall = self.recall_metric(preds, y)

        if phase == 'val':
            self.log(f"hp_metric", miou)
        self.log(f"{phase}_loss", loss, sync_dist=True)
        self.log(f"{phase}_miou", miou, sync_dist=True)
        self.log(f"{phase}_accuracy", acc, sync_dist=True)
        self.log(f"{phase}_precision", precision, sync_dist=True)
        self.log(f"{phase}_recall", recall, sync_dist=True)
        for c in range(len(ious)):
            self.log(f"{phase}_cls{c}_iou", ious[c], sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        return self.val_test_step(batch, batch_idx, phase="val")

    def test_step(self, batch, batch_idx):
        return self.val_test_step(batch, batch_idx, phase="test")

    @staticmethod
    def ckpt2pt(ckpt_file, pt_path):
        """Extract the bare state_dict from a Lightning .ckpt checkpoint."""
        checkpoint = torch.load(ckpt_file, map_location=torch.device("cpu"))
        torch.save(checkpoint["state_dict"], pt_path)

    # @classmethod
    # def from_presence_absence_weights(cls, pt_weights_file, hparams):
    #     self = cls(hparams)
    #     weights = torch.load(pt_weights_file)
    #
    #     # Remove trained weights for previous classifier output layers
    #     del weights["model.classifier.4.weight"]
    #     del weights["model.classifier.4.bias"]
    #     del weights["model.aux_classifier.4.weight"]
    #     del weights["model.aux_classifier.4.bias"]
    #
    #     self.load_state_dict(weights, strict=False)
    #     return self

    @staticmethod
    def add_argparse_args(parser):
        group = parser.add_argument_group("UnetEfficientnet")

        group.add_argument(
            "--num_classes",
            type=int,
            default=2,
            help="The number of image classes, including background.",
        )
        group.add_argument("--lr",
                           type=float,
                           default=0.001,
                           help="the learning rate")
        group.add_argument(
            "--weight_decay",
            type=float,
            default=1e-3,
            help="The weight decay factor for L2 regularization.",
        )
        group.add_argument("--ignore_index",
                           type=int,
                           help="Label of any class to ignore.")
        group.add_argument(
            "--aux_loss_factor",
            type=float,
            default=0.3,
            help="The proportion of loss backpropagated to a classifier "
            "built only on early layers.",
        )

        return parser
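
A minimal usage sketch for the module above (hypothetical: train_loader and
val_loader are assumed to yield (image, mask) batches; the Trainer flags are
standard PyTorch Lightning):

import argparse

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = UnetEfficientnet.add_argparse_args(parser)
parser.add_argument("--max_epochs", type=int, default=100)
args = parser.parse_args()

model = UnetEfficientnet(vars(args))
trainer = pl.Trainer(max_epochs=args.max_epochs)
trainer.fit(model, train_loader, val_loader)  # loaders assumed defined
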
Example 2

# Imports assumed for this standalone example
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from segmentation_models_pytorch import Unet

#################### Model ####################

model = Unet(encoder_name='efficientnet-b3', classes=2)

#################### Solver ####################

num_epochs = 100

# Weighted cross-entropy: up-weight the (rare) foreground class
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 2.0]))

lr = 0.05
weight_decay = 5e-4
momentum = 0.9
nesterov = True
# Base lr is set to 1.0 because the LambdaLR lambda below returns an
# absolute learning rate (lr0 * decay factor), not a multiplier.
optimizer = optim.SGD(model.parameters(),
                      lr=1.0,
                      momentum=momentum,
                      weight_decay=weight_decay,
                      nesterov=nesterov)

le = len(train_loader)  # iterations per epoch (train_loader defined elsewhere)


def lambda_lr_scheduler(iteration, lr0, n, a):
    """Polynomial ("poly") decay: lr0 * (1 - iteration / n) ** a."""
    return lr0 * pow((1.0 - 1.0 * iteration / n), a)


lr_scheduler = lrs.LambdaLR(optimizer,
                            lr_lambda=partial(lambda_lr_scheduler,
                                              lr0=lr,
                                              # assumed: decay over all
                                              # iterations, poly exponent 0.9
                                              n=num_epochs * le,
                                              a=0.9))
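
A minimal training-loop sketch under the setup above (assumptions: train_loader
yields (image, mask) batches with integer class masks, and the scheduler is
stepped once per iteration, matching the iteration-based poly lambda):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = criterion.to(device)  # moves the class-weight buffer too

for epoch in range(num_epochs):
    model.train()
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)           # (N, 2, H, W) raw class scores
        loss = criterion(logits, masks)  # masks: (N, H, W) int64 labels
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # advance the per-iteration poly decay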