def compile_model(self, optimizer, optimizer_kwargs, loss, metrics, **kwargs):
    """
    Compile the tf.keras Model instance stored in self.model.

    Sets the loss function(s), optimizer and metrics, then logs them.

    Args:
        optimizer:        (string) The name of a tf.keras.optimizers Optimizer
                                   (looked up in ``optimizers.__dict__``)
        optimizer_kwargs: (dict)   Key-word arguments passed to the Optimizer.
                                   None is treated as an empty dict.
        loss:             (string) The name of a tf.keras.losses or
                                   MultiPlanarUnet loss function (or a
                                   list/tuple of such names)
        metrics:          (list)   List of tf.keras.metrics or
                                   MultiPlanarUNet metrics.
        **kwargs:         (dict)   Key-word arguments passed to losses
                                   and/or metrics that accept such.

    Returns:
        self, to allow call chaining.

    Raises:
        KeyError: if ``optimizer`` is not a name in ``optimizers.__dict__``.
    """
    # Make sure sparse metrics and loss are specified as sparse
    metrics = ensure_list_or_tuple(metrics)
    losses = ensure_list_or_tuple(loss)
    # Normalize to lists before concatenating: list + tuple raises TypeError
    ensure_sparse(list(metrics) + list(losses))

    # Initialize optimizer
    optimizer_cls = optimizers.__dict__[optimizer]
    optimizer = optimizer_cls(**(optimizer_kwargs or {}))

    # Initialize loss(es) and metrics from tf.keras or MultiPlanarUNet
    losses = init_losses(losses, self.logger, **kwargs)
    metrics = init_metrics(metrics, self.logger, **kwargs)

    # Compile the model
    self.model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

    # Wrap values in a 1-tuple so '%' never tries to unpack a tuple argument.
    # Log the initialized metrics (previously the init_metrics function
    # itself was logged by mistake).
    self.logger("Optimizer: %s" % (optimizer,))
    self.logger("Loss funcs: %s" % (losses,))
    self.logger("Metrics: %s" % (metrics,))
    return self
def compile_model(self, optimizer, loss, metrics, reduction,
                  check_sparse=False, optimizer_kwargs=None,
                  loss_kwargs=None, **kwargs):
    """
    Compile the tf.keras Model instance stored in self.model.

    Sets the loss function(s), optimizer and metrics, then logs them.

    Args:
        optimizer:        (string) The name of a tf.keras.optimizers Optimizer
        loss:             (string) The name of a tf.keras.losses or
                                   MultiPlanarUnet loss class (or a
                                   list/tuple of such names). Each resolved
                                   loss must be callable with a
                                   ``reduction`` keyword argument.
        metrics:          (list)   List of tf.keras.metrics or mpunet metrics.
        reduction:        Value passed as the ``reduction`` keyword argument
                          when instantiating each loss (a
                          tf.keras.losses.Reduction type).
        check_sparse:     (bool)   If True, run ``ensure_sparse`` on the
                                   combined metrics and losses before
                                   initialization.
        optimizer_kwargs: (dict)   Key-word arguments passed to the
                                   Optimizer. Defaults to no arguments.
        loss_kwargs:      (dict)   Extra key-word arguments passed to each
                                   loss when it is instantiated. Defaults
                                   to no arguments.
        **kwargs:         (dict)   Key-word arguments passed to losses
                                   and/or metrics that accept such.

    Returns:
        self, to allow call chaining.

    Raises:
        TypeError: if a loss cannot be instantiated with the ``reduction``
                   argument (e.g. a plain keras loss function was given
                   instead of a loss class).
    """
    # Avoid mutable default arguments; None means "no extra arguments".
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs

    # Make sure sparse metrics and loss are specified as sparse
    metrics = ensure_list_or_tuple(metrics)
    losses = ensure_list_or_tuple(loss)
    if check_sparse:
        # Normalize to lists before concatenating: list + tuple raises
        # TypeError
        ensure_sparse(list(metrics) + list(losses))

    # Initialize optimizer, loss(es) and metric(s) from tf.keras or
    # mpunet
    optimizer = init_optimizer(optimizer, self.logger, **optimizer_kwargs)
    losses = init_losses(losses, self.logger, **kwargs)

    # Instantiate each loss with the requested reduction. Build a new list
    # rather than assigning in place, as init_losses may return a tuple.
    instantiated_losses = []
    for loss_cls in losses:
        try:
            instantiated_losses.append(
                loss_cls(reduction=reduction, **loss_kwargs)
            )
        except (ValueError, TypeError) as err:
            raise TypeError("All loss functions must currently be "
                            "callable and accept the 'reduction' "
                            "parameter specifying a "
                            "tf.keras.losses.Reduction type. If you "
                            "specified a keras loss function such as "
                            "'sparse_categorical_crossentropy', change "
                            "this to its corresponding loss class "
                            "'SparseCategoricalCrossentropy'. If "
                            "you implemented a custom loss function, "
                            "please raise an issue on GitHub.") from err
    losses = instantiated_losses

    metrics = init_metrics(metrics, self.logger, **kwargs)

    # Compile the model
    self.model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

    # Wrap values in a 1-tuple so '%' never tries to unpack a tuple argument.
    # Log the initialized metrics (previously the init_metrics function
    # itself was logged by mistake).
    self.logger("Optimizer: %s" % (optimizer,))
    self.logger("Loss funcs: %s" % (losses,))
    self.logger("Metrics: %s" % (metrics,))
    return self