예제 #1
0
    def build_model(self) -> tf.keras.Model:
        """Constructs the 3D segmentation model described by the task config.

        The input spec has an unspecified (``None``) batch dimension followed by
        ``task_config.model.input_size`` and ``task_config.model.num_channels``,
        with dtype taken from ``task_config.train_data.dtype``. An L2 kernel
        regularizer is attached only when ``losses.l2_weight_decay`` is truthy.

        Returns:
            The model produced by ``factory.build_segmentation_model_3d``. When
            neither train nor eval input partitioning is configured, it has
            already been called once on a random batch of size 1.
        """
        model_cfg = self.task_config.model
        spec_shape = [None] + model_cfg.input_size + [model_cfg.num_channels]
        input_specs = tf.keras.layers.InputSpec(
            shape=spec_shape, dtype=self.task_config.train_data.dtype)

        decay = self.task_config.losses.l2_weight_decay
        if decay:
            # Halve the factor so the Keras L2 penalty (decay * sum(w**2))
            # matches the tf.nn.l2_loss convention (sum(w**2) / 2).
            # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
            # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
            regularizer = tf.keras.regularizers.l2(decay / 2.0)
        else:
            regularizer = None

        model = factory.build_segmentation_model_3d(
            input_specs=input_specs,
            model_config=model_cfg,
            l2_regularizer=regularizer)

        partitioned = (self.task_config.train_input_partition_dims or
                       self.task_config.eval_input_partition_dims)
        if not partitioned:
            # Run one forward pass on a random batch of size 1 to build the
            # model. Without it, relaunching an experiment into a model
            # directory that already holds a trained model never runs a
            # forward pass, so the model is never built. This is skipped under
            # spatial partitioning, where the full-size input would cause OOM.
            dummy_shape = [1] + list(input_specs.shape[1:])
            model(tf.random.uniform(shape=dummy_shape))

        return model
    def _build_model(self) -> tf.keras.Model:
        """Builds and returns a 3D segmentation model with a fixed batch size.

        The input spec shape is
        ``[self._batch_size, *self._input_image_size, num_channels]``, where
        ``num_channels`` comes from ``self.params.task.model``. No dtype is set
        on the spec, and no L2 regularizer is attached.

        Returns:
            The model built by ``factory.build_segmentation_model_3d``.
        """
        model_config = self.params.task.model
        # list() accepts any sequence for the image size (e.g. a tuple).
        # Concatenating a tuple onto a list would raise TypeError.
        input_shape = ([self._batch_size] + list(self._input_image_size) +
                       [model_config.num_channels])
        input_specs = tf.keras.layers.InputSpec(shape=input_shape)

        return factory.build_segmentation_model_3d(
            input_specs=input_specs,
            model_config=model_config,
            l2_regularizer=None)
예제 #3
0
 def test_unet3d_builder(self, input_size, weight_decay):
   # Build a 3-class model from SemanticSegmentationModel3D's defaults for a
   # 3-channel input of the parameterized spatial size. Optionally attach an
   # L2 regularizer, then check that the factory returns a tf.keras.Model.
   depth, height, width = input_size[0], input_size[1], input_size[2]
   specs = tf.keras.layers.InputSpec(shape=[None, depth, height, width, 3])
   config = exp_cfg.SemanticSegmentationModel3D(num_classes=3)

   regularizer = None
   if weight_decay:
     regularizer = tf.keras.regularizers.l2(weight_decay)

   model = factory.build_segmentation_model_3d(
       input_specs=specs, model_config=config, l2_regularizer=regularizer)

   self.assertIsInstance(
       model,
       tf.keras.Model,
       'Output should be a tf.keras.Model instance but got %s' % type(model))