def build_model(self) -> tf.keras.Model:
  """Builds the 3D segmentation model for this task.

  The input spec has an unknown (None) batch dimension, followed by
  `task_config.model.input_size` and then `task_config.model.num_channels`.
  Its dtype is taken from `task_config.train_data.dtype`. An L2 regularizer is
  attached only when `task_config.losses.l2_weight_decay` is truthy.

  Returns:
    The `tf.keras.Model` returned by `factory.build_segmentation_model_3d`.
    When neither train nor eval input partition dims are configured, the
    model has already been called once on a random dummy batch of size 1.
  """
  model_config = self.task_config.model
  # `list(...)` accepts `input_size` given as either a list or a tuple; a tuple
  # would make the list concatenation below raise TypeError.
  input_specs = tf.keras.layers.InputSpec(
      shape=[None] + list(model_config.input_size) +
      [model_config.num_channels],
      dtype=self.task_config.train_data.dtype)

  l2_weight_decay = self.task_config.losses.l2_weight_decay
  # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
  # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
  # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
  l2_regularizer = (
      tf.keras.regularizers.l2(l2_weight_decay / 2.0)
      if l2_weight_decay else None)

  model = factory.build_segmentation_model_3d(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Create a dummy input and call the model instance to initialize (build) it.
  # This is needed when launching multiple experiments that share a model
  # directory: if a trained model already exists there, no forward pass runs,
  # so the model would otherwise never be built. This is only done when
  # spatial partitioning is disabled; otherwise the extremely large input
  # would cause an OOM.
  if (not self.task_config.train_input_partition_dims) and (
      not self.task_config.eval_input_partition_dims):
    # NOTE(review): the dummy input uses tf.random.uniform's default dtype
    # (float32) regardless of `input_specs.dtype` — confirm the model casts
    # its inputs when train_data.dtype is something else.
    dummy_input = tf.random.uniform(shape=[1] + list(input_specs.shape[1:]))
    _ = model(dummy_input)

  return model
def _build_model(self) -> tf.keras.Model:
  """Builds and returns a 3D segmentation model.

  The input spec uses a fixed batch dimension of `self._batch_size`, followed
  by `self._input_image_size` and `self.params.task.model.num_channels`. No
  dtype is set on the spec, and no L2 regularizer is attached.

  Returns:
    The `tf.keras.Model` returned by `factory.build_segmentation_model_3d`.
  """
  num_channels = self.params.task.model.num_channels
  # `list(...)` accepts `_input_image_size` given as either a list or a tuple;
  # a tuple would make the list concatenation below raise TypeError.
  input_specs = tf.keras.layers.InputSpec(
      shape=[self._batch_size] + list(self._input_image_size) +
      [num_channels])
  return factory.build_segmentation_model_3d(
      input_specs=input_specs,
      model_config=self.params.task.model,
      l2_regularizer=None)
def test_unet3d_builder(self, input_size, weight_decay, use_bn):
  """Checks that the 3D segmentation factory returns a `tf.keras.Model`.

  Args:
    input_size: Indexable whose first three entries become the spatial
      dimensions of a 3-channel input with an unknown batch size.
    weight_decay: Strength passed to `tf.keras.regularizers.l2`; a falsy value
      means no regularizer is used.
    use_bn: Value assigned to `head.use_batch_normalization` in the config.
  """
  spatial_dims = [input_size[0], input_size[1], input_size[2]]
  input_specs = tf.keras.layers.InputSpec(shape=[None] + spatial_dims + [3])

  model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=3)
  model_config.head.use_batch_normalization = use_bn

  if weight_decay:
    l2_regularizer = tf.keras.regularizers.l2(weight_decay)
  else:
    l2_regularizer = None

  model = factory.build_segmentation_model_3d(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  failure_message = (
      'Output should be a tf.keras.Model instance but got %s' % type(model))
  self.assertIsInstance(model, tf.keras.Model, failure_message)