def train_step(warper, weights, optimizer, mov, fix) -> tuple:
    """
    Train step function for backprop using gradient tape

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return:
        a tuple:

        - loss: overall loss to optimise
        - loss_image: image dissimilarity
        - loss_deform: deformation regularisation
    """
    with tf.GradientTape() as tape:
        pred = warper(inputs=[weights, mov])
        loss_image = REGISTRY.build_loss(config=image_loss_config)(
            y_true=fix,
            y_pred=pred,
        )
        loss_deform = REGISTRY.build_loss(config=deform_loss_config)(
            inputs=weights,
        )
        loss = loss_image + weight_deform_loss * loss_deform
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss, loss_image, loss_deform
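# Hedged usage sketch for the ddf train_step above: a minimal optimisation loop
# over a trainable dense displacement field. Import paths follow DeepReg
# conventions; the loss configs, regularisation weight, image tensors and
# fixed_image_size below are illustrative assumptions, not values from the
# original source.
import tensorflow as tf
import deepreg.model.layer as layer
from deepreg.registry import REGISTRY

image_loss_config = {"name": "lncc"}      # assumed image similarity config
deform_loss_config = {"name": "bending"}  # assumed regularisation config
weight_deform_loss = 0.2                  # assumed regularisation weight

fixed_image_size = (64, 64, 60)  # assumed [f_dim1, f_dim2, f_dim3]
moving_image = tf.random.uniform([1, *fixed_image_size])  # placeholder image
fixed_image = tf.random.uniform([1, *fixed_image_size])   # placeholder image

warper = layer.Warping(fixed_image_size=fixed_image_size)
var_ddf = tf.Variable(tf.zeros([1, *fixed_image_size, 3]), trainable=True)
optimizer = tf.optimizers.Adam(learning_rate=0.1)

for step in range(100):
    loss, loss_image, loss_deform = train_step(
        warper, var_ddf, optimizer, moving_image, fixed_image
    )
    if step % 10 == 0:
        tf.print("step", step, "loss", loss,
                 "image", loss_image, "deform", loss_deform)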
def _build_loss(self, name: str, inputs_dict: dict):
    """
    Build and add one weighted loss together with the metrics.

    :param name: name of loss, image / label / regularization.
    :param inputs_dict: inputs for loss function
    """
    if name not in self.config["loss"]:
        # loss config is not defined
        logger.warning(
            f"The configuration for loss {name} is not defined. "
            f"Therefore it is not used."
        )
        return

    loss_configs = self.config["loss"][name]
    if not isinstance(loss_configs, list):
        loss_configs = [loss_configs]

    for loss_config in loss_configs:
        if "weight" not in loss_config:
            # default loss weight 1
            logger.warning(
                f"The weight for loss {name} is not defined. "
                f"Default weight = 1.0 is used."
            )
            loss_config["weight"] = 1.0

        # build loss
        weight = loss_config["weight"]
        if weight == 0:
            logger.warning(
                f"The weight for loss {name} is zero. Loss is not used."
            )
            return

        # do not perform reduction over batch axis for supporting multi-device
        # training, model.fit() will average over global batch size automatically
        loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight"),
            default_args={"reduction": tf.keras.losses.Reduction.NONE},
        )
        loss_value = loss_layer(**inputs_dict)
        weighted_loss = loss_value * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metric
        self._model.add_metric(
            loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_layer.name}_weighted",
            aggregation="mean",
        )
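# Hedged sketch of the configuration _build_loss above consumes. The structure
# mirrors DeepReg's YAML loss section; the concrete loss names and extra keys
# here are illustrative assumptions. A per-name entry may be a single dict or a
# list of dicts, and "weight" defaults to 1.0 when omitted.
config = {
    "loss": {
        "image": {"name": "lncc", "weight": 1.0},
        "regularization": [
            {"name": "bending", "weight": 0.5},
            {"name": "gradient"},  # no weight given: defaults to 1.0
        ],
    }
}
# Because the loss layer is built with Reduction.NONE, loss_value keeps its
# batch axis; model.fit() then averages over the global batch size, which keeps
# the loss scale consistent between single- and multi-device training.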
def _build_loss(self, name: str, inputs_dict: dict):
    """
    Build and add one weighted loss together with the metrics.

    :param name: name of loss
    :param inputs_dict: inputs for loss function
    """
    if name not in self.config["loss"]:
        # loss config is not defined
        logging.warning(
            f"The configuration for loss {name} is not defined. "
            f"Therefore it is not used."
        )
        return

    loss_configs = self.config["loss"][name]
    if not isinstance(loss_configs, list):
        loss_configs = [loss_configs]

    for loss_config in loss_configs:
        if "weight" not in loss_config:
            # default loss weight 1
            logging.warning(
                f"The weight for loss {name} is not defined. "
                f"Default weight = 1.0 is used."
            )
            loss_config["weight"] = 1.0

        # build loss
        weight = loss_config["weight"]
        if weight == 0:
            logging.warning(
                f"The weight for loss {name} is zero. Loss is not used."
            )
            return

        loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight")
        )
        loss_value = loss_layer(**inputs_dict) / self.global_batch_size
        weighted_loss = loss_value * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metric
        self._model.add_metric(
            loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_layer.name}_weighted",
            aggregation="mean",
        )
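# Unlike the variant above, this version builds the loss with its default
# reduction and instead divides the reduced value by self.global_batch_size by
# hand, aiming at the same per-global-batch scaling under multi-device
# training. A hedged sketch of how such a global batch size is typically
# derived (the strategy and per-replica batch size are assumptions, not from
# the source):
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
batch_size_per_replica = 4  # assumed per-device batch size
global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync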
def train_step(grid, weights, optimizer, mov, fix) -> object:
    """
    Train step function for backprop using gradient tape

    :param grid: reference grid returned from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity to minimise
    """
    with tf.GradientTape() as tape:
        pred = layer_util.resample(vol=mov, loc=layer_util.warp_grid(grid, weights))
        loss = REGISTRY.build_loss(config=image_loss_config)(
            y_true=fix,
            y_pred=pred,
        )
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss
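# Hedged usage sketch for the affine train_step above: optimise a [1, 4, 3]
# affine transform initialised to identity. layer_util.get_reference_grid,
# warp_grid and resample are the DeepReg utilities the docstring names; the
# loss config, image tensors and sizes are illustrative assumptions.
import tensorflow as tf
import deepreg.model.layer_util as layer_util
from deepreg.registry import REGISTRY

image_loss_config = {"name": "lncc"}  # assumed image similarity config
fixed_image_size = (64, 64, 60)       # assumed [f_dim1, f_dim2, f_dim3]
moving_image = tf.random.uniform([1, *fixed_image_size])  # placeholder image
fixed_image = tf.random.uniform([1, *fixed_image_size])   # placeholder image

grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
var_affine = tf.Variable(  # identity affine transform as a 4x3 theta matrix
    [[[1.0, 0.0, 0.0],
      [0.0, 1.0, 0.0],
      [0.0, 0.0, 1.0],
      [0.0, 0.0, 0.0]]],
    trainable=True,
)
optimizer = tf.optimizers.Adam(learning_rate=0.01)

for step in range(100):
    loss = train_step(grid_ref, var_affine, optimizer, moving_image, fixed_image)
    if step % 10 == 0:
        tf.print("step", step, "loss", loss)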