Ejemplo n.º 1
0
  def build_model(self) -> tf.keras.Model:
    """Builds the 3D segmentation model for this task.

    The input spec is `[None] + input_size + [num_channels]`, taken from the
    task's model config. The batch dimension is left unspecified, and the
    dtype comes from the training data config.

    Returns:
      The `tf.keras.Model` produced by
      `factory.build_segmentation_model_3d`. If neither train nor eval input
      partition dims are set, the model has already been called once on a
      random dummy input so that it is built.
    """
    model_config = self.task_config.model
    # `list(...)` accepts `input_size` given as either a list or a tuple;
    # concatenating a tuple directly onto a list raises TypeError.
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + list(model_config.input_size) +
        [model_config.num_channels],
        dtype=self.task_config.train_data.dtype)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    # A falsy weight decay (e.g. 0) disables the regularizer entirely.
    l2_regularizer = (
        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
        if l2_weight_decay else None)

    model = factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # Create a dummy input and call model instance to initialize the model. This
    # is needed when launching multiple experiments using the same model
    # directory. Since there is already a trained model, forward pass will not
    # run and the model will never be built. This is only done when spatial
    # partitioning is not enabled; otherwise it will fail with OOM due to
    # extremely large input.
    if (not self.task_config.train_input_partition_dims) and (
        not self.task_config.eval_input_partition_dims):
      dummy_input = tf.random.uniform(shape=[1] + list(input_specs.shape[1:]))
      _ = model(dummy_input)

    return model
Ejemplo n.º 2
0
  def _build_model(self) -> tf.keras.Model:
    """Builds and returns a 3D segmentation model.

    The input spec has a fixed batch dimension `self._batch_size`, spatial
    dimensions `self._input_image_size`, and the channel count from
    `self.params.task.model.num_channels`. The model is built with
    `l2_regularizer=None`.

    Returns:
      The `tf.keras.Model` produced by
      `factory.build_segmentation_model_3d`.
    """
    model_config = self.params.task.model
    # `list(...)` accepts an image size given as a list or a tuple;
    # concatenating a tuple directly onto a list raises TypeError.
    input_specs = tf.keras.layers.InputSpec(
        shape=[self._batch_size] + list(self._input_image_size) +
        [model_config.num_channels])

    return factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=None)
Ejemplo n.º 3
0
 def test_unet3d_builder(self, input_size, weight_decay, use_bn):
   """Checks that the 3D segmentation factory returns a Keras model.

   Args:
     input_size: Spatial input dimensions. Only the first three entries are
       used, followed by a fixed 3-channel axis.
     weight_decay: Factor passed to `tf.keras.regularizers.l2`. A falsy
       value disables the regularizer.
     use_bn: Value assigned to the head's `use_batch_normalization` flag.
   """
   model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=3)
   model_config.head.use_batch_normalization = use_bn

   spatial_dims = [input_size[0], input_size[1], input_size[2]]
   specs = tf.keras.layers.InputSpec(shape=[None] + spatial_dims + [3])

   if weight_decay:
     regularizer = tf.keras.regularizers.l2(weight_decay)
   else:
     regularizer = None

   built = factory.build_segmentation_model_3d(
       input_specs=specs,
       model_config=model_config,
       l2_regularizer=regularizer)

   failure_msg = (
       'Output should be a tf.keras.Model instance but got %s' % type(built))
   self.assertIsInstance(built, tf.keras.Model, failure_msg)