Example #1
class SegModel(pl.LightningModule):
    def __init__(
        self,
        data_path: str,
        batch_size: int = 4,
        lr: float = 1e-3,
        num_layers: int = 3,
        features_start: int = 64,
        bilinear: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear

        self.net = UNet(num_classes=19,
                        num_layers=self.num_layers,
                        features_start=self.features_start,
                        bilinear=self.bilinear)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])
        self.trainset = KITTI(self.data_path,
                              split='train',
                              transform=self.transform)
        self.validset = KITTI(self.data_path,
                              split='valid',
                              transform=self.transform)
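
Example #1 pairs ToTensor with a Normalize transform built from KITTI's per-channel mean and standard deviation. As a quick illustration of what that composition does (only the mean/std values come from the example; the random 8x8 image is synthetic):

import numpy as np
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0.0, 1.0]
    transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                         std=[0.32064945, 0.32098866, 0.32325324]),
])

img = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)  # synthetic image
x = transform(img)  # each channel becomes (x - mean) / std
print(x.shape)      # torch.Size([3, 8, 8])
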
Example #2
    def build(self) -> nn.Module:
        return UNet(
            num_classes=self.num_classes,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear,
        )
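
Example #2 isolates network construction in a build() factory so every architecture choice is read from attributes rather than hard-coded at call sites. A minimal sketch of one way to host such a factory, assuming a hypothetical UNetConfig dataclass (the field names mirror the attributes build() reads; UNet itself is the class the other examples construct and is not defined here):

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class UNetConfig:  # hypothetical holder, not part of the original snippet
    num_classes: int = 19
    num_layers: int = 3
    features_start: int = 64
    bilinear: bool = False

    def build(self) -> nn.Module:
        # Assumes UNet is importable from the example's module (not shown here).
        return UNet(
            num_classes=self.num_classes,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear,
        )

# net = UNetConfig(num_layers=5).build()  # illustrative usage
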
Example #3
import os

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

# UNet and KITTI are assumed to be importable from the surrounding example
# code (the UNet architecture and the KITTI dataset wrapper, not shown here).

class SegModel(pl.LightningModule):
    """
    Semantic Segmentation Module

    This is a basic semantic segmentation module implemented with Lightning.
    It uses CrossEntropyLoss as the default loss function, which may be replaced
    with other loss functions as required.
    It is specific to the KITTI dataset, i.e. the dataloaders are for KITTI and
    the Normalize transform uses the mean and standard deviation of that dataset.
    It uses a UNet model as the network.

    The Adam optimizer is used along with a cosine annealing learning rate scheduler.

    >>> from pl_examples import _DATASETS_PATH
    >>> dataset_path = os.path.join(_DATASETS_PATH, "Kitti")
    >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
    >>> SegModel(dataset_path)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    SegModel(
      (net): UNet(
        (layers): ModuleList(
          (0): DoubleConv(...)
          (1): Down(...)
          (2): Down(...)
          (3): Up(...)
          (4): Up(...)
          (5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    """
    def __init__(
        self,
        data_path: str,
        batch_size: int = 4,
        lr: float = 1e-3,
        num_layers: int = 3,
        features_start: int = 64,
        bilinear: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear

        self.net = UNet(num_classes=19,
                        num_layers=self.num_layers,
                        features_start=self.features_start,
                        bilinear=self.bilinear)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])
        self.trainset = KITTI(self.data_path,
                              split='train',
                              transform=self.transform)
        self.validset = KITTI(self.data_path,
                              split='valid',
                              transform=self.transform)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss = F.cross_entropy(out, mask, ignore_index=250)
        log_dict = {'train_loss': loss}
        return {'loss': loss, 'log': log_dict, 'progress_bar': log_dict}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        return {'val_loss': loss_val}

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
        log_dict = {'val_loss': loss_val}
        return {
            'log': log_dict,
            'val_loss': log_dict['val_loss'],
            'progress_bar': log_dict
        }

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validset,
                          batch_size=self.batch_size,
                          shuffle=False)

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no cover
        parser = parent_parser.add_argument_group("SegModel")
        parser.add_argument("--data_path",
                            type=str,
                            help="path where dataset is stored")
        parser.add_argument("--batch_size",
                            type=int,
                            default=16,
                            help="size of the batches")
        parser.add_argument("--lr",
                            type=float,
                            default=0.001,
                            help="adam: learning rate")
        parser.add_argument("--num_layers",
                            type=int,
                            default=5,
                            help="number of layers on u-net")
        parser.add_argument("--features_start",
                            type=int,
                            default=64,
                            help="number of features in first layer")
        parser.add_argument(
            "--bilinear",
            action='store_true',
            default=False,
            help="whether to use bilinear interpolation or transposed")
        return parent_parser
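
A sketch of the usual wiring for add_model_specific_args under the argparse-based workflow these examples follow. It assumes SegModel (Example #3) is defined or imported in scope; the max_epochs value is illustrative:

from argparse import ArgumentParser

import pytorch_lightning as pl


def main():
    parser = ArgumentParser()
    parser = SegModel.add_model_specific_args(parser)
    args = parser.parse_args()

    # vars(args) forwards --data_path, --batch_size, ... straight into __init__
    model = SegModel(**vars(args))
    trainer = pl.Trainer(max_epochs=30)
    trainer.fit(model)


if __name__ == "__main__":
    main()
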
Example #4
class SegModel(pl.LightningModule):
    """
    Semantic Segmentation Module

    This is a basic semantic segmentation module implemented with Lightning.
    It uses CrossEntropyLoss as the default loss function, which may be replaced
    with other loss functions as required.
    It is specific to the KITTI dataset, i.e. the dataloaders are for KITTI and
    the Normalize transform uses the mean and standard deviation of that dataset.
    It uses a UNet model as the network.

    The Adam optimizer is used along with a cosine annealing learning rate scheduler.
    """
    def __init__(self, data_path: str, batch_size: int, lr: float,
                 num_layers: int, features_start: int, bilinear: bool,
                 **kwargs):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear

        self.net = UNet(num_classes=19,
                        num_layers=self.num_layers,
                        features_start=self.features_start,
                        bilinear=self.bilinear)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])
        self.trainset = KITTI(self.data_path,
                              split='train',
                              transform=self.transform)
        self.validset = KITTI(self.data_path,
                              split='valid',
                              transform=self.transform)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        log_dict = {'train_loss': loss_val}
        return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        return {'val_loss': loss_val}

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
        log_dict = {'val_loss': loss_val}
        return {
            'log': log_dict,
            'val_loss': log_dict['val_loss'],
            'progress_bar': log_dict
        }

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validset,
                          batch_size=self.batch_size,
                          shuffle=False)
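
configure_optimizers above returns ([opt], [sch]), Lightning's list-pair convention for one optimizer plus one scheduler. CosineAnnealingLR(opt, T_max=10) anneals the learning rate from its initial value toward eta_min (0 by default) over ten scheduler steps, following lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. A standalone sketch checking that against the closed form (the dummy parameter exists only so the optimizer has something to hold):

import math

import torch

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, illustration only
opt = torch.optim.Adam([param], lr=1e-3)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)

for t in range(11):
    expected = 0.5 * 1e-3 * (1 + math.cos(math.pi * t / 10))
    print(t, sch.get_last_lr()[0], expected)  # the two columns agree
    opt.step()  # step the optimizer before the scheduler, as torch expects
    sch.step()
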
Example #5
class SegModel(pl.LightningModule):
    def __init__(
        self,
        data_path: str,
        batch_size: int = 4,
        lr: float = 1e-3,
        num_layers: int = 3,
        features_start: int = 64,
        bilinear: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear

        self.net = UNet(
            num_classes=1,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear,
        )
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.35675976, 0.37380189, 0.3764753],
                std=[0.32064945, 0.32098866, 0.32325324],
            ),
        ])
        self.trainset = Cisti(self.data_path,
                              split="train",
                              transform=self.transform)
        self.validset = Cisti(self.data_path,
                              split="valid",
                              transform=self.transform)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask)
        log_dict = {"train_loss": loss_val}
        return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask)
        return {"val_loss": loss_val}

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
        log_dict = {"val_loss": loss_val}
        return {
            "log": log_dict,
            "val_loss": log_dict["val_loss"],
            "progress_bar": log_dict,
        }

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validset,
                          batch_size=self.batch_size,
                          shuffle=False)
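
The training and validation steps in these examples all compute F.cross_entropy over (logits, integer mask) pairs; the KITTI variants pass ignore_index=250 so void pixels are skipped, while Example #5 omits it (note also that its num_classes=1 leaves 0 as the only valid target index; binary segmentation with cross_entropy normally uses two classes). A shape sketch with synthetic values:

import torch
import torch.nn.functional as F

N, C, H, W = 2, 19, 4, 4                  # arbitrary sizes for illustration
out = torch.randn(N, C, H, W)             # raw per-pixel logits from the net
mask = torch.randint(0, C, (N, H, W))     # per-pixel class indices in [0, C-1]
mask[0, 0, 0] = 250                       # a "void" pixel outside the class range

loss = F.cross_entropy(out, mask, ignore_index=250)  # void pixel contributes nothing
print(loss.item())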