Example #1
0
def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the image dissimilarity loss (and its metrics) to the model.

    Nothing is registered unless the configured image weight is strictly
    positive. The whole image sub-config (including its "weight" entry) is
    forwarded as keyword arguments to ``image_loss.dissimilarity_fn``.

    :param model: tf.keras.Model
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; reads loss_config["dissimilarity"]["image"]
    :return: the same model, with the loss and metrics registered on it
    """
    image_config = loss_config["dissimilarity"]["image"]
    weight = image_config["weight"]
    # Written as `not (weight > 0)` so that e.g. NaN weights are skipped too.
    if not (weight > 0):
        return model

    per_sample = image_loss.dissimilarity_fn(
        y_true=fixed_image,
        y_pred=pred_fixed_image,
        **image_config,
    )
    loss_image = tf.reduce_mean(per_sample)
    weighted_loss_image = loss_image * weight

    model.add_loss(weighted_loss_image)
    for value, metric_name in (
        (loss_image, "loss/image_dissimilarity"),
        (weighted_loss_image, "loss/weighted_image_dissimilarity"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model
Example #2
0
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the label dissimilarity loss and label-based metrics to the model.

    When ``fixed_label`` is None the model is returned untouched. Otherwise
    the weighted label dissimilarity is added as a loss; the raw and weighted
    values are logged as metrics. Dice score is also logged, with
    binary=True and binary=False. So are the centroid distance (logged as
    "metric/tre") and the foreground proportion of both labels.

    :param model: tf.keras.Model
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3), or None
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; reads loss_config["dissimilarity"]["label"]
    :return: the same model, with losses and metrics registered on it
    """
    if fixed_label is None:
        return model

    label_config = loss_config["dissimilarity"]["label"]
    dissimilarity_fn = label_loss.get_dissimilarity_fn(config=label_config)
    loss_label = tf.reduce_mean(
        dissimilarity_fn(y_true=fixed_label, y_pred=pred_fixed_label))
    weighted_loss_label = loss_label * label_config["weight"]

    model.add_loss(weighted_loss_label)
    for value, metric_name in (
        (loss_label, "loss/label_dissimilarity"),
        (weighted_loss_label, "loss/weighted_label_dissimilarity"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")

    # Monitoring-only metrics (not added as losses). The tuple is built
    # left-to-right, so each metric is computed and registered in this order.
    monitoring = (
        ("metric/dice_binary",
         label_loss.dice_score(y_true=fixed_label,
                               y_pred=pred_fixed_label,
                               binary=True)),
        ("metric/dice_float",
         label_loss.dice_score(y_true=fixed_label,
                               y_pred=pred_fixed_label,
                               binary=False)),
        ("metric/tre",
         label_loss.compute_centroid_distance(y_true=fixed_label,
                                              y_pred=pred_fixed_label,
                                              grid=grid_fixed)),
        ("metric/foreground_label",
         label_loss.foreground_proportion(y=fixed_label)),
        ("metric/foreground_pred",
         label_loss.foreground_proportion(y=pred_fixed_label)),
    )
    for metric_name, value in monitoring:
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model
def add_ddf_loss(model: tf.keras.Model, ddf: tf.Tensor,
                 loss_config: dict) -> tf.keras.Model:
    """
    Attach the DDF regularization loss (and its metrics) to the model.

    The whole regularization sub-config (including its "weight" entry) is
    forwarded as keyword arguments to
    ``deform_loss.local_displacement_energy``.

    :param model: tf.keras.Model
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss; reads loss_config["regularization"]
    :return: the same model, with the loss and metrics registered on it
    """
    reg_config = loss_config["regularization"]
    energy = deform_loss.local_displacement_energy(ddf, **reg_config)
    loss_reg = tf.reduce_mean(energy)
    weighted_loss_reg = loss_reg * reg_config["weight"]

    model.add_loss(weighted_loss_reg)
    for value, metric_name in (
        (loss_reg, "loss/regularization"),
        (weighted_loss_reg, "loss/weighted_regularization"),
    ):
        model.add_metric(value, name=metric_name, aggregation="mean")
    return model
    def train_and_eval(self,
                       model: tf.keras.Model,
                       epochs: Optional[int] = None,
                       sparsity: Optional[float] = None):
        """
        Trains a Keras model and evaluates it on the validation and test sets.

        :param model: A Keras model.
        :param epochs: Overrides the duration of training (defaults to
            ``self.config.epochs``).
        :param sparsity: Desired sparsity level (for unstructured sparsity).
            Pruning only runs when pruning is configured and this is > 0.
        :returns: A dict with keys:
            ``"val_error"`` -- smallest validation error (1.0 - accuracy) seen
            during training. When pruning is active, epochs before
            ``finish_pruning_by_epoch`` are ignored, if any later epochs exist.
            ``"test_error"`` -- 1.0 - test accuracy of the trained model.
            ``"pruned_weights"`` -- ``weights`` of the pruning callback, or
            None if no pruning was done.
        :raises NotImplementedError: if both pruning and distillation are set.
        :raises ValueError: if pruning is requested with ``sparsity`` > 1.0.
        """
        dataset = self.config.dataset
        batch_size = self.config.batch_size
        sparsity = sparsity or 0.0

        train = dataset.train_dataset() \
            .shuffle(batch_size * 8) \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)

        val = dataset.validation_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)

        # TODO: check if this works, make sure we're excluding the last layer from the student
        if self.pruning and self.distillation:
            raise NotImplementedError()

        if self.distillation:
            teacher = tf.keras.models.load_model(
                self.distillation.distill_from)
            teacher._name = "teacher_"
            teacher.trainable = False

            t, a = self.distillation.temperature, self.distillation.alpha

            # Assemble a parallel model with the teacher and student
            i = tf.keras.Input(shape=dataset.input_shape)
            cxent = tf.keras.losses.CategoricalCrossentropy()

            stud_logits = model(i)
            tchr_logits = teacher(i)

            # Temperature-softened distributions; the a * t^2 factor scales
            # the teaching loss before it is added on top of the task loss.
            o_stud = tf.keras.layers.Softmax()(stud_logits / t)
            o_tchr = tf.keras.layers.Softmax()(tchr_logits / t)
            teaching_loss = (a * t * t) * cxent(o_tchr, o_stud)

            model = tf.keras.Model(inputs=i, outputs=stud_logits)
            model.add_loss(teaching_loss, inputs=True)

        # NOTE(review): uses self.dataset here but self.config.dataset above;
        # presumably the same object -- confirm.
        if self.dataset.num_classes == 2:
            loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
            accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy")
        else:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True)
            accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                name="accuracy")
        model.compile(optimizer=self.config.optimizer(),
                      loss=loss,
                      metrics=[accuracy])

        # TODO: adjust metrics by class weight?
        class_weight = dict(enumerate(self.dataset.class_weight())) \
            if self.config.use_class_weight else None
        epochs = epochs or self.config.epochs
        callbacks = self.config.callbacks()
        check_logs_from_epoch = 0

        pruning_cb = None
        if self.pruning and sparsity > 0.0:
            # Explicit check instead of `assert`, which is stripped under -O.
            if sparsity > 1.0:
                raise ValueError(
                    f"sparsity must be in (0.0, 1.0], got {sparsity}")
            self.log.info(f"Target sparsity: {sparsity:.4f}")
            pruning_cb = DPFPruning(
                target_sparsity=sparsity,
                structured=self.pruning.structured,
                start_pruning_at_epoch=self.pruning.start_pruning_at_epoch,
                finish_pruning_by_epoch=self.pruning.finish_pruning_by_epoch)
            check_logs_from_epoch = self.pruning.finish_pruning_by_epoch
            callbacks.append(pruning_cb)

        log = model.fit(train,
                        epochs=epochs,
                        validation_data=val,
                        verbose=1 if debug_mode() else 2,
                        callbacks=callbacks,
                        class_weight=class_weight)

        test = dataset.test_dataset() \
            .batch(batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
        _, test_acc = model.evaluate(test, verbose=0)

        val_accuracy = log.history["val_accuracy"]
        # Only epochs from `check_logs_from_epoch` onward are considered --
        # presumably so the reported error reflects the model after pruning
        # has reached its target. If training ended before that epoch (e.g.
        # `epochs` overridden below `finish_pruning_by_epoch`, or early
        # stopping), the slice is empty and `max` would raise ValueError, so
        # fall back to every recorded epoch.
        considered = val_accuracy[check_logs_from_epoch:] or val_accuracy

        return {
            "val_error": 1.0 - max(considered),
            "test_error": 1.0 - test_acc,
            "pruned_weights": pruning_cb.weights if pruning_cb else None,
        }
Example #5
0
def apply_kernel_regularization(func: Callable, model: tf.keras.Model):
    """Apply kernel regularization on all the trainable layers of a Model.

    Each layer considered must appear directly in ``model.layers``, have a
    ``kernel`` attribute and be trainable. For each such layer,
    ``func(kernel)`` is registered as an additional loss via
    ``model.add_loss``. Layers nested inside sub-models are not inspected.

    The loss is registered as a zero-argument callable, not a precomputed
    tensor. Keras evaluates callable losses when losses are collected, so the
    penalty follows the kernel as it is trained. In TF2 eager mode Keras also
    rejects a concrete tensor passed to ``add_loss`` from outside a layer's
    ``call`` and asks for a zero-argument callable instead.

    :param func: maps a kernel variable to a loss value (presumably scalar)
    :param model: model whose top-level layers are regularized
    """
    for layer in model.layers:
        if hasattr(layer, 'kernel') and layer.trainable:
            # Bind this layer's kernel now via a default argument: a plain
            # closure over `layer` is late-binding, so every callable would
            # end up reading the kernel of the loop's last layer.
            model.add_loss(lambda kernel=layer.kernel: func(kernel))