コード例 #1
0
    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Set up the per-term loss trackers owned by this composite loss.

        Args:
            weights: per-term weights, stored unchanged on ``self.weights``.
            penalty: forwarded to ``Grad`` as ``penalty``.
            loss_mult: forwarded to ``Grad`` as ``loss_mult``.
        """
        super().__init__()
        self.weights = weights

        # Trackers for the predicted maps.
        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss, self.pred_maps_mse_loss = DiceLoss(), MSELoss()

        # Trackers for the flow field: gradient penalty and an MSE term.
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        # Label trackers, plus their "affine" counterparts.
        self.label_dice_loss, self.label_mse_loss = DiceLoss(), MSELoss()
        self.label_dice_loss_affine, self.label_mse_loss_affine = DiceLoss(), MSELoss()

        # Template trackers.
        self.template_dice_loss, self.template_mse_loss = DiceLoss(), MSELoss()

        # Atlas trackers.
        self.atlas_mse_loss, self.atlas_mse_loss_affine = MSELoss(), MSELoss()

        # Starts at zero; what advances it is not visible from this method.
        self.epoch = 0
コード例 #2
0
 def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
     """Create the sub-losses and store the template.

     Args:
         template: stored on ``self.template``. A ``numpy.ndarray`` is first
             converted to a float CUDA tensor with a new leading axis; any
             other value is stored unchanged.
         penalty: forwarded to ``Grad`` as ``penalty``.
         loss_mult: forwarded to ``Grad`` as ``loss_mult``.
         weight: stored unchanged on ``self.weight``.
     """
     super().__init__()
     self.mse_loss = MSELoss()
     self.bce_loss = BCELoss(logit=True)
     self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
     self.weight = weight
     if not isinstance(template, np.ndarray):
         self.template = template
     else:
         # numpy input: float tensor -> GPU -> add a leading axis.
         as_float = torch.from_numpy(template).float()
         self.template = as_float.cuda().unsqueeze(0)
コード例 #3
0
ファイル: train.py プロジェクト: lisurui6/CMRSegment
def main():
    """Build and run a 2D FCN segmentation training experiment from a config file.

    The configuration is parsed from ``args.conf_path`` when one is given and
    from ``TRAIN_CONF_PATH`` otherwise. That same file is used for the data
    configuration and is the one copied into the experiment directory, so the
    saved ``train.conf`` always matches the settings the run actually used.
    """
    args = parse_args()
    # Resolve the config path once so every later consumer reads the same file.
    conf_path = TRAIN_CONF_PATH if args.conf_path is None else Path(args.conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    experiment_dir = get_conf(train_conf, group="experiment", key="experiment_dir")
    if experiment_dir is not None:
        experiment_dir = Path(experiment_dir)
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=get_conf(train_conf, group="experiment", key="batch_size"),
        num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"),
        gpu=get_conf(train_conf, group="experiment", key="gpu"),
        device=get_conf(train_conf, group="experiment", key="device"),
        num_workers=get_conf(train_conf, group="experiment", key="num_workers"),
        pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"),
    )
    # NOTE(review): config.experiment_dir is used below even when the config
    # supplied no directory - presumably ExperimentConfig fills in a default;
    # confirm.
    # Archive the config that was actually parsed, not unconditionally the
    # default one.
    shutil.copy(str(conf_path),
                str(config.experiment_dir.joinpath("train.conf")))

    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf, group="network", key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf, group="network", key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
    )

    # The data configuration must come from the same file as everything else.
    training_set, validation_set = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"),
    )
    training_set.export(config.experiment_dir.joinpath("training_set.csv"))
    validation_set.export(config.experiment_dir.joinpath("validation_set.csv"))

    # Any optimizer type other than "SGD" falls back to Adam.
    if get_conf(train_conf, group="optimizer", key="type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
            momentum=get_conf(train_conf, group="optimizer", key="momentum"),
        )
    else:
        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
        )

    # Any loss type other than "FocalLoss" falls back to BCELoss.
    if get_conf(train_conf, group="loss", key="type") == "FocalLoss":
        loss = FocalLoss(
            alpha=get_conf(train_conf, group="loss", key="alpha"),
            gamma=get_conf(train_conf, group="loss", key="gamma"),
            logits=True,
        )
    else:
        loss = BCELoss()

    experiment = Experiment(
        config=config,
        network=network,
        training_set=training_set,
        validation_set=validation_set,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
    )
    experiment.train()
コード例 #4
0
ファイル: loss.py プロジェクト: lisurui6/CMRSegment
class DefSegLoss(TorchLoss):
    """Composite loss over predicted maps, warped outputs and the flow field.

    ``weights`` is indexed by :meth:`cumulate` as follows:

    ==  ==========================================
    0   BCE  of ``predicted[2]`` against the label
    1   Dice of ``predicted[2]`` against the label
    2   MSE  of ``predicted[2]`` against the label
    3   Dice of ``predicted[0]`` against the label
    4   MSE  of ``predicted[0]`` against the label
    5   Dice of ``predicted[1]`` against the template
    6   MSE  of ``predicted[1]`` against the template
    7   gradient penalty on ``predicted[3]``
    8   MSE  of ``predicted[3]`` against zeros
    ==  ==========================================
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create one tracker per loss term.

        Args:
            weights: the nine per-term weights described on the class.
            penalty: forwarded to ``Grad`` as ``penalty``.
            loss_mult: forwarded to ``Grad`` as ``loss_mult``.
        """
        super().__init__()
        self.weights = weights

        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()

        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()

        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()

        # Incremented by reset(); drives the warm-up schedule in cumulate().
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted total for one batch and add it to the running sum.

        Args:
            predicted: indexable as (warped template, warped maps, pred maps, flow).
            outputs: pair ``(label, template)``.

        Returns:
            The combined loss for this batch; its ``.item()`` is what gets
            accumulated into ``self._cum_loss``.
        """
        label, template = outputs
        # While epoch <= 10 only the BCE term on the predicted maps is
        # weighted; afterwards the configured weights apply.
        if self.epoch <= 10:
            weights = [1, 0, 0, 0, 0, 0, 0, 0, 0]
        else:
            weights = self.weights

        pred_map_bce_loss = self.pred_maps_bce_loss.cumulate(
            predicted[2], label)
        pred_map_dice_loss = self.pred_maps_dice_loss.cumulate(
            predicted[2], label)
        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(
            predicted[2], label)
        pred_map_loss = (weights[0] * pred_map_bce_loss
                         + weights[1] * pred_map_dice_loss
                         + weights[2] * pred_map_mse_loss)

        flow = predicted[3]
        grad_loss = self.grad_loss.cumulate(flow, None)
        # zeros_like keeps the zero target on the same device and dtype as the
        # flow instead of forcing a CPU allocation followed by a CUDA copy.
        deform_loss = self.deform_mse_loss.cumulate(
            flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)
        label_mse_loss = self.label_mse_loss.cumulate(predicted[0], label)
        label_loss = weights[3] * label_dice_loss + weights[4] * label_mse_loss

        template_dice_loss = self.template_dice_loss.cumulate(
            predicted[1], template)
        template_mse_loss = self.template_mse_loss.cumulate(
            predicted[1], template)
        template_loss = (weights[5] * template_dice_loss
                         + weights[6] * template_mse_loss)

        # NOTE(review): the flow terms read self.weights rather than the local
        # warm-up ``weights``, so they stay active while epoch <= 10 even
        # though the warm-up list has zeros at indices 7 and 8. Behaviour is
        # kept as is - confirm whether this is intended.
        loss = (pred_map_loss + label_loss + template_loss
                + grad_loss * self.weights[7]
                + deform_loss * self.weights[8])
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance with the same configuration.

        ``reset()`` also increments ``epoch``, so the returned instance starts
        at epoch 1; the current instance's epoch is not carried over.
        """
        new_loss = self.__class__(penalty=self.grad_loss.penalty,
                                  loss_mult=self.grad_loss.loss_mult,
                                  weights=self.weights)
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a one-line summary of the total and most sub-losses.

        The two template terms are tracked but not included in this summary.
        """
        return "total {:.4f}, pred map {}, pred map {}, pred map {}, label {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_bce_loss.description(),
            self.pred_maps_dice_loss.description(),
            self.pred_maps_mse_loss.description(),
            self.label_dice_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset this loss and every sub-loss, and advance the epoch counter."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.template_mse_loss.reset()
        self.template_dice_loss.reset()
コード例 #5
0
ファイル: train.py プロジェクト: lisurui6/CMRSegment
def main():
    """Assemble a 3D UNet training experiment from a config file and run it."""
    args = parse_args()
    # Use the default config unless a path was supplied by the caller.
    if args.conf_path is not None:
        conf_path = Path(args.conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def lookup(group, key):
        # Shorthand for reading one value out of the parsed config.
        return get_conf(train_conf, group=group, key=key)

    raw_experiment_dir = lookup("experiment", "experiment_dir")
    experiment_dir = None
    if raw_experiment_dir is not None:
        experiment_dir = Path(raw_experiment_dir)

    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=lookup("experiment", "batch_size"),
        num_epochs=lookup("experiment", "num_epochs"),
        gpu=lookup("experiment", "gpu"),
        device=lookup("experiment", "device"),
        num_workers=lookup("experiment", "num_workers"),
        pin_memory=lookup("experiment", "pin_memory"),
        n_inference=lookup("experiment", "n_inference"),
        seed=lookup("experiment", "seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)

    # Keep a copy of the config next to the experiment outputs.
    archived_conf = config.experiment_dir.joinpath("train.conf")
    shutil.copy(str(conf_path), str(archived_conf))

    network = UNet(
        in_channels=lookup("network", "in_channels"),
        n_classes=lookup("network", "n_classes"),
        n_filters=lookup("network", "n_filters"),
    )

    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=lookup("network", "feature_size"),
        n_slices=lookup("network", "n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
    )

    # Export every dataset to a CSV named after its group and its own name.
    export_plan = (
        ("training_set_{}.csv", training_sets),
        ("validation_set_{}.csv", validation_sets),
        ("extra_validation_set_{}.csv", extra_validation_sets),
    )
    for name_pattern, datasets in export_plan:
        for dataset in datasets:
            csv_name = name_pattern.format(dataset.name)
            dataset.export(config.experiment_dir.joinpath(csv_name))

    # "SGD" selects SGD with momentum; any other type selects Adam.
    if lookup("optimizer", "type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=lookup("optimizer", "learning_rate"),
            momentum=lookup("optimizer", "momentum"),
        )
    else:
        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=lookup("optimizer", "learning_rate"),
        )

    # "FocalLoss" selects the focal loss; any other type selects BCELoss.
    if lookup("loss", "type") == "FocalLoss":
        loss = FocalLoss(
            alpha=lookup("loss", "alpha"),
            gamma=lookup("loss", "gamma"),
            logits=True,
        )
    else:
        loss = BCELoss()

    experiment = FCN3DExperiment(
        config=config,
        network=network,
        training_sets=training_sets,
        validation_sets=validation_sets,
        extra_validation_sets=extra_validation_sets,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
        inference_func=inference,
    )
    experiment.train()
コード例 #6
0
class DefSegLoss(TorchLoss):
    """Loss made of an MSE term plus a weighted gradient penalty on the flow.

    A BCE tracker and a template tensor are also held on the instance, but
    neither contributes to the value computed by :meth:`cumulate`.
    """

    def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
        """Create the sub-losses and store the template.

        Args:
            template: stored on ``self.template``. A ``numpy.ndarray`` is
                converted to a float CUDA tensor with a new leading axis;
                anything else is stored unchanged.
            penalty: forwarded to ``Grad`` as ``penalty``.
            loss_mult: forwarded to ``Grad`` as ``loss_mult``.
            weight: multiplier applied to the gradient penalty.
        """
        super().__init__()
        self.mse_loss = MSELoss()
        self.bce_loss = BCELoss(logit=True)
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weight = weight
        if not isinstance(template, np.ndarray):
            self.template = template
        else:
            as_float = torch.from_numpy(template).float()
            self.template = as_float.cuda().unsqueeze(0)

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the batch loss and add it to the running totals.

        Args:
            predicted: indexable as (warped template, pred maps, flow); only
                elements 0 and 2 are read.
            outputs: target compared against ``predicted[0]``.

        Returns:
            The batch loss: MSE plus ``self.weight`` times the gradient penalty.
        """
        similarity = self.mse_loss.cumulate(predicted[0], outputs)
        smoothness = self.grad_loss.cumulate(predicted[2], None)
        total = similarity + smoothness * self.weight
        self._cum_loss += total.item()
        self._count += 1
        return total

    def new(self):
        """Return a freshly reset instance sharing this one's configuration."""
        grad = self.grad_loss
        clone = self.__class__(
            template=self.template,
            penalty=grad.penalty,
            loss_mult=grad.loss_mult,
            weight=self.weight,
        )
        clone.reset()
        return clone

    def description(self):
        """Return the MSE and gradient-penalty descriptions, comma separated."""
        mse_text = self.mse_loss.description()
        grad_text = self.grad_loss.description()
        return "{}, {}".format(mse_text, grad_text)

    def reset(self):
        """Reset this loss and all three sub-loss trackers."""
        super().reset()
        for tracker in (self.mse_loss, self.bce_loss, self.grad_loss):
            tracker.reset()
コード例 #7
0
class DefLoss(TorchLoss):
    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Set up the per-term loss trackers owned by this composite loss.

        Args:
            weights: per-term weights, stored unchanged on ``self.weights``.
            penalty: forwarded to ``Grad`` as ``penalty``.
            loss_mult: forwarded to ``Grad`` as ``loss_mult``.
        """
        super().__init__()
        self.weights = weights

        # Trackers for the predicted maps.
        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss, self.pred_maps_mse_loss = DiceLoss(), MSELoss()

        # Trackers for the flow field: gradient penalty and an MSE term.
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        # Label trackers, plus their "affine" counterparts.
        self.label_dice_loss, self.label_mse_loss = DiceLoss(), MSELoss()
        self.label_dice_loss_affine, self.label_mse_loss_affine = DiceLoss(), MSELoss()

        # Template trackers.
        self.template_dice_loss, self.template_mse_loss = DiceLoss(), MSELoss()

        # Atlas trackers.
        self.atlas_mse_loss, self.atlas_mse_loss_affine = MSELoss(), MSELoss()

        # Incremented by reset().
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted total loss for one batch and accumulate it.

        Args:
            predicted: indexable sequence laid out as
                (affine_warped_template, warped_template, pred_maps,
                preint_flow, warped_image, warped_label, affine_warped_label,
                batch_atlas).
            outputs: pair ``(label, template)``; only ``label`` is used here.

        Returns:
            The combined loss for this batch; its ``.item()`` is what gets
            added to ``self._cum_loss``.

        ``self.weights`` is read as: 0 predicted-map MSE, 1 label MSE (both the
        affine and the non-affine term), 2 gradient penalty, 3 flow MSE
        against zeros, 4 atlas MSE (both the affine and the non-affine term).
        """
        label, _ = outputs
        weights = self.weights

        flow = predicted[3]
        grad_loss = self.grad_loss.cumulate(flow, None)
        # zeros_like keeps the zero target on the same device and dtype as the
        # flow instead of forcing a CPU allocation followed by a CUDA copy.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = weights[0] * pred_map_mse_loss

        affine_label_mse_loss = self.label_mse_loss_affine.cumulate(predicted[0], label)
        affine_label_loss = weights[1] * affine_label_mse_loss

        label_mse_loss = self.label_mse_loss.cumulate(predicted[1], label)
        label_loss = weights[1] * label_mse_loss + affine_label_loss

        # Both warped labels are compared against the batch atlas.
        atlas_mse_loss = self.atlas_mse_loss.cumulate(predicted[5], predicted[7])
        atlas_mse_loss_affine = self.atlas_mse_loss_affine.cumulate(predicted[6], predicted[7])
        atlas_loss = atlas_mse_loss * weights[4] + atlas_mse_loss_affine * weights[4]

        loss = label_loss + grad_loss * weights[2] + deform_loss * weights[3] + pred_map_loss + atlas_loss

        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance built from this one's settings."""
        grad = self.grad_loss
        clone = self.__class__(
            weights=self.weights,
            penalty=grad.penalty,
            loss_mult=grad.loss_mult,
        )
        clone.reset()
        return clone

    def description(self):
        """Return a one-line summary: total, then four sub-loss descriptions."""
        parts = [
            self.pred_maps_mse_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        ]
        template = "total {:.4f}, pred map {}, label {}, grad {}, deform {}, "
        return template.format(self.log(), *parts)

    def reset(self):
        """Reset the running totals and sub-loss trackers; advance ``epoch``.

        NOTE(review): the template_* and atlas_* trackers created in
        ``__init__`` are not reset in the lines visible here - confirm whether
        that is intended or whether the method continues beyond this view.
        """
        super().reset()
        # Each reset counts as one more epoch.
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.label_dice_loss_affine.reset()
        self.label_mse_loss_affine.reset()