Example #1
0
    def test_serialize_deserialize(self):
        """Round-trips a UNet3D through get_config()/from_config().

        Checks that `get_config()` echoes back every constructor argument,
        that a network rebuilt from that config can be serialized to JSON, and
        that the rebuilt network reports the same config as the original.
        """
        # Create a network object that sets all of its config options.
        kwargs = dict(model_id=4,
                      pool_size=(2, 2, 2),
                      kernel_size=(3, 3, 3),
                      activation='relu',
                      base_filters=32,
                      kernel_regularizer=None,
                      norm_momentum=0.99,
                      norm_epsilon=0.001,
                      use_sync_bn=False,
                      use_batch_normalization=True)
        network = unet_3d.UNet3D(**kwargs)

        # The config is expected to contain exactly the constructor kwargs.
        expected_config = dict(kwargs)
        self.assertEqual(network.get_config(), expected_config)

        # Create another network object from the first object's config.
        new_network = unet_3d.UNet3D.from_config(network.get_config())

        # Validate that the config can be forced to JSON.
        _ = new_network.to_json()

        # If the serialization was successful, the new config should match the
        # old. Configs are plain dicts, so compare them with assertEqual (like
        # the check above) instead of assertAllEqual, which is meant for
        # array-likes and gives no per-key diff when dicts differ.
        self.assertEqual(network.get_config(), new_network.get_config())
  def test_network_creation(self, input_size, model_id):
    """Test creation of UNet3D family models.

    Builds a UNet3D backbone with `model_id` levels, calls it on a symbolic
    batch of one 3-channel volume, and checks the shape of each endpoint.

    Args:
      input_size: Indexable; `input_size[0]` sizes the first two spatial
        dimensions of the input and `input_size[1]` sizes the third.
      model_id: Number of backbone levels; endpoints '1' .. str(model_id) are
        checked.
    """
    tf.keras.backend.set_image_data_format('channels_last')
    network = unet_3d.UNet3D(model_id=model_id)
    inputs = tf.keras.Input(
        shape=(input_size[0], input_size[0], input_size[1], 3), batch_size=1)
    endpoints = network(inputs)

    for layer_depth in range(model_id):
      # Expected: every spatial dim shrinks by 2**depth, channels grow from 64
      # by 2**depth. Floor division keeps the expected dims as ints, matching
      # the ints from `shape.as_list()`; true division (`/`) produced floats
      # that can never match when a dim is not a multiple of 2**depth.
      scale = 2**layer_depth
      self.assertAllEqual([
          1, input_size[0] // scale, input_size[0] // scale,
          input_size[1] // scale, 64 * scale
      ], endpoints[str(layer_depth + 1)].shape.as_list())
  def test_network_creation(self, input_size, model_id):
    """Checks the decoder's top feature map when stacked on a UNet3D backbone.

    A UNet3D backbone of depth `model_id` feeds a UNet3DDecoder configured
    from the backbone's `output_specs`. The decoder output must contain the
    key '1', holding a map with the input's spatial dims and 64 channels.

    Args:
      input_size: `[spatial size, volume size]`; the spatial size is used for
        the first two dimensions, the volume size for the third.
      model_id: Depth shared by the backbone and the decoder.
    """
    tf.keras.backend.set_image_data_format('channels_last')

    spatial_size, volume_size = input_size[0], input_size[1]
    volume_shape = (spatial_size, spatial_size, volume_size, 3)
    inputs = tf.keras.Input(shape=volume_shape, batch_size=1)

    backbone = unet_3d.UNet3D(model_id=model_id)
    decoder = unet_3d_decoder.UNet3DDecoder(
        model_id=model_id, input_specs=backbone.output_specs)

    # The backbone runs first; its endpoint dict is the decoder's input.
    decoded = decoder(backbone(inputs))

    self.assertIn('1', decoded)
    expected_shape = [1, spatial_size, spatial_size, volume_size, 64]
    self.assertAllEqual(expected_shape, decoded['1'].shape.as_list())