def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
    """Build every sub-loss used by the weighted deformation/segmentation loss.

    Args:
        weights: list of per-term loss weights, stored as-is (the same list
            object, not a copy). NOTE(review): how the entries are indexed
            is decided by the class's ``cumulate`` — confirm against it.
        penalty: forwarded unchanged to ``Grad``.
        loss_mult: forwarded unchanged to ``Grad``.
    """
    super().__init__()
    self.weights = weights
    # Terms comparing the network's predicted maps with the label.
    self.pred_maps_bce_loss = BCELoss(logit=False)
    self.pred_maps_dice_loss = DiceLoss()
    self.pred_maps_mse_loss = MSELoss()
    # Regularisers applied to the predicted flow.
    self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
    self.deform_mse_loss = MSELoss()
    # Label-alignment terms, each with a separate instance for the affine path.
    self.label_dice_loss = DiceLoss()
    self.label_mse_loss = MSELoss()
    self.label_dice_loss_affine = DiceLoss()
    self.label_mse_loss_affine = MSELoss()
    self.template_dice_loss = DiceLoss()
    self.template_mse_loss = MSELoss()
    # Atlas terms (non-affine and affine variants).
    self.atlas_mse_loss = MSELoss()
    self.atlas_mse_loss_affine = MSELoss()
    # Epoch counter, starts at 0; presumably advanced by the class's reset() — confirm.
    self.epoch = 0
def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
    """Create the MSE, BCE and Grad sub-losses and store the template.

    Args:
        template: either a numpy array, which is converted to a float tensor,
            moved to CUDA and given a new leading axis of size 1, or any other
            object, which is stored unchanged.
        penalty: forwarded unchanged to ``Grad``.
        loss_mult: forwarded unchanged to ``Grad``.
        weight: scalar stored on the instance as ``self.weight``.
    """
    super().__init__()
    self.mse_loss = MSELoss()
    self.bce_loss = BCELoss(logit=True)
    self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
    self.weight = weight

    # Non-numpy templates (e.g. an already-prepared tensor) are kept as given.
    if not isinstance(template, np.ndarray):
        self.template = template
        return

    template_tensor = torch.from_numpy(template).float()
    self.template = template_tensor.cuda().unsqueeze(0)
def main():
    """Train the 2D FCN segmentation model described by a training config.

    The config file is ``args.conf_path`` when one is given on the command
    line, otherwise ``TRAIN_CONF_PATH``. That same file is used to build the
    ``DataConfig`` and is copied into the experiment directory as
    ``train.conf``, so the archived copy matches what the run actually used.
    """
    args = parse_args()
    # Resolve the config path once. Previously the data config and the saved
    # ``train.conf`` copy always came from TRAIN_CONF_PATH, silently ignoring
    # a user-supplied --conf_path.
    if args.conf_path is None:
        conf_path = TRAIN_CONF_PATH
    else:
        conf_path = Path(args.conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def conf(group, key):
        """Read ``group``/``key`` from the parsed training config."""
        return get_conf(train_conf, group=group, key=key)

    experiment_dir = conf("experiment", "experiment_dir")
    if experiment_dir is not None:
        experiment_dir = Path(experiment_dir)
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=conf("experiment", "batch_size"),
        num_epochs=conf("experiment", "num_epochs"),
        gpu=conf("experiment", "gpu"),
        device=conf("experiment", "device"),
        num_workers=conf("experiment", "num_workers"),
        pin_memory=conf("experiment", "pin_memory"),
    )
    shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf")))

    network = FCN2DSegmentationModel(
        in_channels=conf("network", "in_channels"),
        n_classes=conf("network", "n_classes"),
        n_filters=conf("network", "n_filters"),
        up_conv_filter=conf("network", "up_conv_filter"),
        final_conv_filter=conf("network", "final_conv_filter"),
        feature_size=conf("network", "feature_size"),
    )
    # n_slices is taken from network.in_channels (there is no separate key here).
    training_set, validation_set = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=conf("network", "feature_size"),
        n_slices=conf("network", "in_channels"),
    )
    training_set.export(config.experiment_dir.joinpath("training_set.csv"))
    validation_set.export(config.experiment_dir.joinpath("validation_set.csv"))

    learning_rate = conf("optimizer", "learning_rate")
    if conf("optimizer", "type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=learning_rate,
            momentum=conf("optimizer", "momentum"),
        )
    else:
        optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

    if conf("loss", "type") == "FocalLoss":
        loss = FocalLoss(
            alpha=conf("loss", "alpha"),
            gamma=conf("loss", "gamma"),
            logits=True,
        )
    else:
        loss = BCELoss()

    experiment = Experiment(
        config=config,
        network=network,
        training_set=training_set,
        validation_set=validation_set,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
    )
    experiment.train()
class DefSegLoss(TorchLoss):
    """Weighted multi-term loss over predicted maps, warped images and flow.

    ``cumulate`` indexes ``predicted`` as
    ``(warped template, warped maps, pred maps, flow)`` and unpacks ``outputs``
    into ``(label, template)``. The nine ``weights`` entries scale, in order:

    0. BCE  of ``predicted[2]`` vs ``label``
    1. Dice of ``predicted[2]`` vs ``label``
    2. MSE  of ``predicted[2]`` vs ``label``
    3. Dice of ``predicted[0]`` vs ``label``
    4. MSE  of ``predicted[0]`` vs ``label``
    5. Dice of ``predicted[1]`` vs ``template``
    6. MSE  of ``predicted[1]`` vs ``template``
    7. ``Grad`` term on ``predicted[3]``
    8. MSE of ``predicted[3]`` vs zeros

    While ``self.epoch <= warmup_epochs`` (``self.epoch`` is advanced by
    ``reset``), entries 0-6 are replaced by ``[1, 0, 0, 0, 0, 0, 0]``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None,
                 warmup_epochs: int = 10):
        """
        Args:
            weights: nine per-term weights (see class docstring); stored as-is.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
            warmup_epochs: last epoch (inclusive) that uses the warm-up
                weights. Defaults to 10, the previously hard-coded value.
        """
        super().__init__()
        self.weights = weights
        self.warmup_epochs = warmup_epochs
        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()
        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and add it to the running total.

        predicted = (warped template, warped maps, pred maps, flow)
        outputs = (label, template)
        Returns the batch loss.
        """
        label, template = outputs
        if self.epoch <= self.warmup_epochs:
            weights = [1, 0, 0, 0, 0, 0, 0, 0, 0]
        else:
            weights = self.weights

        pred_map_bce_loss = self.pred_maps_bce_loss.cumulate(predicted[2], label)
        pred_map_dice_loss = self.pred_maps_dice_loss.cumulate(predicted[2], label)
        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = (weights[0] * pred_map_bce_loss
                         + weights[1] * pred_map_dice_loss
                         + weights[2] * pred_map_mse_loss)

        grad_loss = self.grad_loss.cumulate(predicted[3], None)
        # zeros_like keeps the target on the flow's own device/dtype instead
        # of assuming a CUDA float32 tensor.
        deform_loss = self.deform_mse_loss.cumulate(
            predicted[3], torch.zeros_like(predicted[3]))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)
        label_mse_loss = self.label_mse_loss.cumulate(predicted[0], label)
        label_loss = weights[3] * label_dice_loss + weights[4] * label_mse_loss

        template_dice_loss = self.template_dice_loss.cumulate(predicted[1], template)
        template_mse_loss = self.template_mse_loss.cumulate(predicted[1], template)
        template_loss = (weights[5] * template_dice_loss
                         + weights[6] * template_mse_loss)

        # NOTE(review): the two flow regularisers use self.weights even during
        # warm-up (where the local ``weights`` zeroes entries 7 and 8) — confirm
        # this is intended.
        loss = (pred_map_loss + label_loss + template_loss
                + grad_loss * self.weights[7] + deform_loss * self.weights[8])
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset loss with the same configuration."""
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weights=self.weights,
            warmup_epochs=self.warmup_epochs,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Summarise the running total and each sub-loss."""
        # One placeholder per argument: the template slot was missing, so
        # str.format silently shifted grad/deform and dropped the last value.
        return ("total {:.4f}, pred map {}, pred map {}, pred map {}, label {}, "
                "label {}, template {}, grad {}, deform {}, ").format(
            self.log(),
            self.pred_maps_bce_loss.description(),
            self.pred_maps_dice_loss.description(),
            self.pred_maps_mse_loss.description(),
            self.label_dice_loss.description(),
            self.label_mse_loss.description(),
            self.template_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Clear running totals, advance the epoch counter and reset sub-losses."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.template_mse_loss.reset()
        self.template_dice_loss.reset()
def main():
    """Train a UNet on 3D data (``is_3d=True``) through ``FCN3DExperiment``.

    Reads the config from ``args.conf_path`` or, if absent, ``TRAIN_CONF_PATH``;
    copies it to ``<experiment_dir>/train.conf``; exports every dataset split
    as CSV; then builds the optimizer and loss and starts training.
    """
    args = parse_args()
    conf_path = TRAIN_CONF_PATH if args.conf_path is None else Path(args.conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def setting(group, key):
        """Look up ``group``/``key`` in the parsed training config."""
        return get_conf(train_conf, group=group, key=key)

    raw_experiment_dir = setting("experiment", "experiment_dir")
    config = ExperimentConfig(
        experiment_dir=None if raw_experiment_dir is None else Path(raw_experiment_dir),
        batch_size=setting("experiment", "batch_size"),
        num_epochs=setting("experiment", "num_epochs"),
        gpu=setting("experiment", "gpu"),
        device=setting("experiment", "device"),
        num_workers=setting("experiment", "num_workers"),
        pin_memory=setting("experiment", "pin_memory"),
        n_inference=setting("experiment", "n_inference"),
        seed=setting("experiment", "seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf")))

    network = UNet(
        in_channels=setting("network", "in_channels"),
        n_classes=setting("network", "n_classes"),
        n_filters=setting("network", "n_filters"),
    )
    training_sets, validation_sets, extra_validation_sets = (
        construct_training_validation_dataset(
            DataConfig.from_conf(conf_path),
            feature_size=setting("network", "feature_size"),
            n_slices=setting("network", "n_slices"),
            is_3d=True,
            seed=config.seed,
            augmentation_config=augmentation_config,
            output_dir=config.experiment_dir,
        )
    )

    # One CSV per dataset, named after its role and the dataset's name.
    split_exports = (
        ("training_set_{}.csv", training_sets),
        ("validation_set_{}.csv", validation_sets),
        ("extra_validation_set_{}.csv", extra_validation_sets),
    )
    for file_pattern, datasets in split_exports:
        for dataset in datasets:
            dataset.export(
                config.experiment_dir.joinpath(file_pattern.format(dataset.name)))

    optimizer_type = setting("optimizer", "type")
    if optimizer_type == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=setting("optimizer", "learning_rate"),
            momentum=setting("optimizer", "momentum"),
        )
    else:
        optimizer = torch.optim.Adam(
            network.parameters(), lr=setting("optimizer", "learning_rate"))

    loss = (
        FocalLoss(
            alpha=setting("loss", "alpha"),
            gamma=setting("loss", "gamma"),
            logits=True,
        )
        if setting("loss", "type") == "FocalLoss"
        else BCELoss()
    )

    experiment = FCN3DExperiment(
        config=config,
        network=network,
        training_sets=training_sets,
        validation_sets=validation_sets,
        extra_validation_sets=extra_validation_sets,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
        inference_func=inference,
    )
    experiment.train()
class DefSegLoss(TorchLoss):
    """MSE between ``predicted[0]`` and the target plus a weighted Grad term.

    ``cumulate`` reads ``predicted`` as ``(warped template, pred maps, flow)``.
    The BCE sub-loss and the stored template are kept on the instance but are
    not used by ``cumulate``.
    """

    def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
        """
        Args:
            template: numpy array (converted to a float CUDA tensor with a new
                leading axis of size 1) or any other object, stored unchanged.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
            weight: multiplier applied to the Grad term in ``cumulate``.
        """
        super().__init__()
        self.mse_loss = MSELoss()
        self.bce_loss = BCELoss(logit=True)
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weight = weight
        if not isinstance(template, np.ndarray):
            self.template = template
        else:
            self.template = torch.from_numpy(template).float().cuda().unsqueeze(0)

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Return ``mse(predicted[0], outputs) + weight * grad(predicted[2])``.

        The batch value is also added to the running total.
        """
        warped_template, flow = predicted[0], predicted[2]
        similarity_term = self.mse_loss.cumulate(warped_template, outputs)
        grad_term = self.grad_loss.cumulate(flow, None)
        total = similarity_term + grad_term * self.weight
        self._cum_loss += total.item()
        self._count += 1
        return total

    def new(self):
        """Return a freshly reset copy sharing this loss's configuration."""
        fresh = type(self)(
            template=self.template,
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weight=self.weight,
        )
        fresh.reset()
        return fresh

    def description(self):
        """Summarise the MSE and Grad sub-losses."""
        parts = (self.mse_loss.description(), self.grad_loss.description())
        return "{}, {}".format(*parts)

    def reset(self):
        """Clear the running total and every sub-loss."""
        super().reset()
        for sub_loss in (self.mse_loss, self.bce_loss, self.grad_loss):
            sub_loss.reset()
class DefLoss(TorchLoss):
    """Weighted loss over label alignment, pred maps, atlas agreement and flow.

    ``cumulate`` unpacks ``outputs`` into ``(label, template)`` (``template``
    is not used) and reads only ``weights[0]``-``weights[4]``:

    * ``weights[0]`` - MSE of ``predicted[2]`` vs ``label``
    * ``weights[1]`` - MSE of ``predicted[0]`` and of ``predicted[1]`` vs ``label``
    * ``weights[2]`` - ``Grad`` term on ``predicted[3]``
    * ``weights[3]`` - MSE of ``predicted[3]`` vs zeros
    * ``weights[4]`` - MSE of ``predicted[5]`` and of ``predicted[6]`` vs ``predicted[7]``

    Several sub-losses (BCE, Dice, template) are constructed and reset but not
    used by ``cumulate``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """
        Args:
            weights: per-term weights (see class docstring); stored as-is.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.weights = weights
        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()
        self.label_dice_loss_affine = DiceLoss()
        self.label_mse_loss_affine = MSELoss()
        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()
        self.atlas_mse_loss = MSELoss()
        self.atlas_mse_loss_affine = MSELoss()
        # Advanced by reset(); not read anywhere in this class.
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and add it to the running total.

        predicted = affine_warped_template, warped_template, pred_maps,
        preint_flow, warped_image, warped_label, affine_warped_label,
        batch_atlas (``warped_image`` is not used).
        Returns the batch loss.
        """
        label, template = outputs
        weights = self.weights

        grad_loss = self.grad_loss.cumulate(predicted[3], None)
        # zeros_like keeps the target on the flow's own device/dtype instead
        # of assuming a CUDA float32 tensor.
        deform_loss = self.deform_mse_loss.cumulate(
            predicted[3], torch.zeros_like(predicted[3]))

        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = weights[0] * pred_map_mse_loss

        affine_label_mse_loss = self.label_mse_loss_affine.cumulate(predicted[0], label)
        affine_label_loss = weights[1] * affine_label_mse_loss
        label_mse_loss = self.label_mse_loss.cumulate(predicted[1], label)
        label_loss = weights[1] * label_mse_loss + affine_label_loss

        atlas_mse_loss = self.atlas_mse_loss.cumulate(predicted[5], predicted[7])
        atlas_mse_loss_affine = self.atlas_mse_loss_affine.cumulate(predicted[6], predicted[7])
        atlas_loss = atlas_mse_loss * weights[4] + atlas_mse_loss_affine * weights[4]

        loss = (label_loss + grad_loss * weights[2] + deform_loss * weights[3]
                + pred_map_loss + atlas_loss)
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset loss with the same configuration."""
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weights=self.weights,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Summarise the running total and the main sub-losses."""
        return "total {:.4f}, pred map {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_mse_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Clear running totals, advance the epoch counter and reset every sub-loss."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.label_dice_loss_affine.reset()
        self.label_mse_loss_affine.reset()
        # Previously never reset: the atlas terms are cumulated every batch, so
        # their running statistics silently spanned the whole training run.
        self.template_dice_loss.reset()
        self.template_mse_loss.reset()
        self.atlas_mse_loss.reset()
        self.atlas_mse_loss_affine.reset()