Beispiel #1
0
    def test_unet_3d_decoder_creation(self, model_id):
        """Test that factory-built and directly built UNet 3D decoders match.

        Args:
          model_id: Passed as `model_id` to both decoders; one input spec is
            created for each level in `range(model_id)`.
        """
        # Input specs keyed '1'..str(model_id). Level k uses 128 // 2**k for
        # each of the three middle dimensions, with leading/trailing dims of 1.
        input_specs = {
            str(level + 1): tf.TensorShape([
                1, 128 // (2**level), 128 // (2**level), 128 // (2**level), 1
            ])
            for level in range(model_id)
        }

        network = decoders.UNet3DDecoder(model_id=model_id,
                                         input_specs=input_specs,
                                         use_sync_bn=True,
                                         use_batch_normalization=True,
                                         use_deconvolution=True)

        # Build the same decoder through the config-driven factory path.
        # NOTE(review): the factory path does not set use_sync_bn /
        # use_batch_normalization / use_deconvolution explicitly, so it relies
        # on the config defaults matching the explicit True values above.
        model_config = (
            semantic_segmentation_3d_exp.SemanticSegmentationModel3D())
        model_config.num_classes = 2
        model_config.num_channels = 1
        model_config.input_size = [None, None, None]
        model_config.decoder = decoders_cfg.Decoder(
            type='unet_3d_decoder',
            unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        # Both construction paths must produce identical layer configs; on
        # mismatch assertEqual prints a diff, so no debug printing is needed.
        self.assertEqual(network.get_config(), factory_network.get_config())
Beispiel #2
0
  def test_serialize_deserialize(self):
    """Round-trip a UNet3D segmentation model through its config."""
    unet_backbone = backbones.UNet3D(model_id=4)
    unet_decoder = decoders.UNet3DDecoder(
        model_id=4, input_specs=unet_backbone.output_specs)
    seg_head = segmentation_heads_3d.SegmentationHead3D(
        3, level=1, num_convs=0)
    original = segmentation_model.SegmentationModel(
        backbone=unet_backbone, decoder=unet_decoder, head=seg_head)

    # Rebuild a second model from the first model's config.
    restored = segmentation_model.SegmentationModel.from_config(
        original.get_config())

    # JSON serialization of the rebuilt model must not raise.
    restored.to_json()

    # A lossless round trip leaves the config unchanged.
    self.assertAllEqual(original.get_config(), restored.get_config())
Beispiel #3
0
  def test_segmentation_network_unet3d_creation(self, input_size, depth):
    """Test a UNet3D segmentation model produces logits of the expected shape.

    Args:
      input_size: Indexable size; the input array is built with shape
        (2, input_size[0], input_size[0], input_size[1], 3).
      depth: `model_id` used for both the UNet3D backbone and decoder.
    """
    num_classes = 2
    inputs = np.random.rand(2, input_size[0], input_size[0], input_size[1], 3)
    # set_image_data_format mutates global Keras state; restore the previous
    # value when the test finishes so other tests are not affected.
    self.addCleanup(tf.keras.backend.set_image_data_format,
                    tf.keras.backend.image_data_format())
    tf.keras.backend.set_image_data_format('channels_last')
    backbone = backbones.UNet3D(model_id=depth)

    decoder = decoders.UNet3DDecoder(
        model_id=depth, input_specs=backbone.output_specs)
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes, level=1, num_convs=0)

    model = segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)

    outputs = model(inputs)
    # Logits keep the input's leading dims and replace the last dim with
    # num_classes.
    self.assertAllEqual(
        [2, input_size[0], input_size[0], input_size[1], num_classes],
        outputs['logits'].numpy().shape)