def add_image_loss(
    model: tf.keras.Model,
    fixed_image: tf.Tensor,
    pred_fixed_image: tf.Tensor,
    loss_config: dict,
) -> tf.keras.Model:
    """
    Add image dissimilarity loss of ddf into model.

    When the configured weight is positive, attaches the batch-mean image
    dissimilarity as a loss (scaled by the weight) and records both the raw
    and the weighted values as metrics.

    :param model: tf.keras.Model
    :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss
    :return: the same model, possibly with loss/metrics attached
    """
    image_cfg = loss_config["dissimilarity"]["image"]
    if image_cfg["weight"] > 0:
        # mean over the batch of the configured dissimilarity;
        # the whole image config dict is forwarded to dissimilarity_fn
        raw_loss = tf.reduce_mean(
            image_loss.dissimilarity_fn(
                y_true=fixed_image,
                y_pred=pred_fixed_image,
                **image_cfg,
            )
        )
        scaled_loss = raw_loss * image_cfg["weight"]
        model.add_loss(scaled_loss)
        model.add_metric(
            raw_loss, name="loss/image_dissimilarity", aggregation="mean"
        )
        model.add_metric(
            scaled_loss,
            name="loss/weighted_image_dissimilarity",
            aggregation="mean",
        )
    return model
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: Optional[tf.Tensor],
    pred_fixed_label: Optional[tf.Tensor],
    loss_config: dict,
) -> tf.keras.Model:
    """
    Add label dissimilarity loss of ddf into model.

    If a fixed label is provided, attaches the (weighted) label dissimilarity
    as a loss and records raw/weighted loss values plus evaluation metrics
    (binary dice, float dice, TRE, foreground proportions).

    Note: the annotations for fixed_label / pred_fixed_label were previously
    the tuple ``(tf.Tensor, None)``, which is not a valid type annotation;
    they are now ``Optional[tf.Tensor]``.

    :param model: tf.keras.Model
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3), or None
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3), or None
    :param loss_config: config for loss
    :return: the same model, possibly with loss/metrics attached
    """
    if fixed_label is not None:
        # batch-mean label dissimilarity from the configured loss function
        loss_label = tf.reduce_mean(
            label_loss.get_dissimilarity_fn(
                config=loss_config["dissimilarity"]["label"]
            )(y_true=fixed_label, y_pred=pred_fixed_label)
        )
        weighted_loss_label = (
            loss_label * loss_config["dissimilarity"]["label"]["weight"]
        )
        model.add_loss(weighted_loss_label)
        model.add_metric(
            loss_label, name="loss/label_dissimilarity", aggregation="mean"
        )
        model.add_metric(
            weighted_loss_label,
            name="loss/weighted_label_dissimilarity",
            aggregation="mean",
        )

        # metrics (not added to the loss)
        dice_binary = label_loss.dice_score(
            y_true=fixed_label, y_pred=pred_fixed_label, binary=True
        )
        dice_float = label_loss.dice_score(
            y_true=fixed_label, y_pred=pred_fixed_label, binary=False
        )
        tre = label_loss.compute_centroid_distance(
            y_true=fixed_label, y_pred=pred_fixed_label, grid=grid_fixed
        )
        foreground_label = label_loss.foreground_proportion(y=fixed_label)
        foreground_pred = label_loss.foreground_proportion(y=pred_fixed_label)
        model.add_metric(dice_binary, name="metric/dice_binary", aggregation="mean")
        model.add_metric(dice_float, name="metric/dice_float", aggregation="mean")
        model.add_metric(tre, name="metric/tre", aggregation="mean")
        model.add_metric(
            foreground_label, name="metric/foreground_label", aggregation="mean"
        )
        model.add_metric(
            foreground_pred, name="metric/foreground_pred", aggregation="mean"
        )
    return model
def add_ddf_loss(model: tf.keras.Model, ddf: tf.Tensor,
                 loss_config: dict) -> tf.keras.Model:
    """
    add regularization loss of ddf into model

    Attaches the batch-mean displacement energy of the ddf, scaled by the
    configured weight, and records raw/weighted values as metrics.

    :param model: tf.keras.Model
    :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3)
    :param loss_config: config for loss
    :return: the same model, with regularization loss/metrics attached
    """
    reg_cfg = loss_config["regularization"]
    # the whole regularization config dict is forwarded to the energy function
    raw_reg = tf.reduce_mean(
        deform_loss.local_displacement_energy(ddf, **reg_cfg)
    )
    scaled_reg = raw_reg * reg_cfg["weight"]
    model.add_loss(scaled_reg)
    model.add_metric(raw_reg, name="loss/regularization", aggregation="mean")
    model.add_metric(
        scaled_reg, name="loss/weighted_regularization", aggregation="mean"
    )
    return model
def train_and_eval(self,
                   model: tf.keras.Model,
                   epochs: Optional[int] = None,
                   sparsity: Optional[float] = None):
    """
    Trains a Keras model and returns its validation set error (1.0 - accuracy).

    :param model: A Keras model.
    :param epochs: Overrides the duration of training.
    :param sparsity: Desired sparsity level (for unstructured sparsity)
    :returns Smallest error on validation set seen during training, the error on the
             test set, pruned weights (if pruning was used)
    """
    dataset = self.config.dataset
    batch_size = self.config.batch_size
    sparsity = sparsity or 0.0  # treat None as "no pruning target"
    # shuffle buffer is 8 batches; both pipelines prefetch for throughput
    train = dataset.train_dataset() \
        .shuffle(batch_size * 8) \
        .batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    val = dataset.validation_dataset() \
        .batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    # TODO: check if this works, make sure we're excluding the last layer from the student
    if self.pruning and self.distillation:
        # simultaneous pruning + distillation is not supported
        raise NotImplementedError()
    if self.distillation:
        teacher = tf.keras.models.load_model(
            self.distillation.distill_from)
        teacher._name = "teacher_"  # avoid name clash with the student model
        teacher.trainable = False
        t, a = self.distillation.temperature, self.distillation.alpha
        # Assemble a parallel model with the teacher and student
        i = tf.keras.Input(shape=dataset.input_shape)
        cxent = tf.keras.losses.CategoricalCrossentropy()
        stud_logits = model(i)
        tchr_logits = teacher(i)
        # temperature-softened distributions for the distillation loss
        o_stud = tf.keras.layers.Softmax()(stud_logits / t)
        o_tchr = tf.keras.layers.Softmax()(tchr_logits / t)
        # teacher output acts as the target; scaled by alpha * t^2 as in
        # Hinton-style distillation
        teaching_loss = (a * t * t) * cxent(o_tchr, o_stud)
        # rebuild the trained model around the student logits only; the
        # distillation term is attached as an auxiliary loss
        model = tf.keras.Model(inputs=i, outputs=stud_logits)
        model.add_loss(teaching_loss, inputs=True)
    # NOTE(review): this reads self.dataset while the code above uses
    # self.config.dataset — confirm both refer to the same dataset object
    if self.dataset.num_classes == 2:
        loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy")
    else:
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True)
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="accuracy")
    model.compile(optimizer=self.config.optimizer(),
                  loss=loss,
                  metrics=[accuracy])
    # TODO: adjust metrics by class weight?
    class_weight = {k: v for k, v in enumerate(self.dataset.class_weight())} \
        if self.config.use_class_weight else None
    epochs = epochs or self.config.epochs
    callbacks = self.config.callbacks()
    # epoch index from which val_accuracy entries are considered valid;
    # before pruning finishes the model is not representative
    check_logs_from_epoch = 0
    pruning_cb = None
    if self.pruning and sparsity > 0.0:
        assert 0.0 < sparsity <= 1.0
        self.log.info(f"Target sparsity: {sparsity:.4f}")
        pruning_cb = DPFPruning(
            target_sparsity=sparsity,
            structured=self.pruning.structured,
            start_pruning_at_epoch=self.pruning.start_pruning_at_epoch,
            finish_pruning_by_epoch=self.pruning.finish_pruning_by_epoch)
        check_logs_from_epoch = self.pruning.finish_pruning_by_epoch
        callbacks.append(pruning_cb)
    log = model.fit(train,
                    epochs=epochs,
                    validation_data=val,
                    verbose=1 if debug_mode() else 2,
                    callbacks=callbacks,
                    class_weight=class_weight)
    test = dataset.test_dataset() \
        .batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    _, test_acc = model.evaluate(test, verbose=0)
    return {
        # best (lowest) validation error after pruning has finished
        "val_error":
        1.0 - max(log.history["val_accuracy"][check_logs_from_epoch:]),
        "test_error": 1.0 - test_acc,
        "pruned_weights": pruning_cb.weights if pruning_cb else None
    }
def apply_kernel_regularization(func: Callable, model: tf.keras.Model):
    """Apply kernel regularization on all the trainable layers of a Layer or a Model"""
    # only direct children with a kernel attribute are considered
    eligible = (
        layer for layer in model.layers
        if hasattr(layer, 'kernel') and layer.trainable
    )
    for layer in eligible:
        model.add_loss(func(layer.kernel))