Пример #1
0
def build_basnet_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: exp_cfg.BASNetModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Assemble a BASNet model from an encoder, a decoder and a refinement unit.

    Args:
        input_specs: Input specification, forwarded to the encoder only.
        model_config: Model configuration; its `norm_activation` sub-config
            and `use_bias` field are read here.
        l2_regularizer: Passed as `kernel_regularizer` to all three
            components. Defaults to None.

    Returns:
        A `basnet_model.BASNetModel` wrapping the encoder, decoder and
        refinement unit, in that order.
    """
    norm_act = model_config.norm_activation

    # The encoder, decoder and refinement unit are all configured with the
    # same normalization/activation settings, so gather them once.
    shared_kwargs = dict(
        activation=norm_act.activation,
        use_sync_bn=norm_act.use_sync_bn,
        use_bias=model_config.use_bias,
        norm_momentum=norm_act.norm_momentum,
        norm_epsilon=norm_act.norm_epsilon,
        kernel_regularizer=l2_regularizer,
    )

    encoder = basnet_model.BASNetEncoder(
        input_specs=input_specs, **shared_kwargs)
    decoder = basnet_model.BASNetDecoder(**shared_kwargs)
    refinement = refunet.RefUnet(**shared_kwargs)

    return basnet_model.BASNetModel(encoder, decoder, refinement)
    def test_basnet_network_creation(self, input_size):
        """Check the output shape of a BASNet model built with default parts.

        A random batch of two square, 3-channel images of side `input_size`
        is fed through the model. The output stored under the last key (in
        sorted order) must have shape [2, input_size, input_size, 1].
        """
        tf.keras.backend.set_image_data_format('channels_last')
        batch = np.random.rand(2, input_size, input_size, 3)

        network = basnet_model.BASNetModel(
            backbone=basnet_model.BASNetEncoder(),
            decoder=basnet_model.BASNetDecoder(),
            refinement=refunet.RefUnet())

        outputs = network(batch)

        # The model returns a mapping; inspect the entry whose key sorts last.
        last_level = sorted(outputs.keys())[-1]
        expected_shape = [2, input_size, input_size, 1]
        self.assertAllEqual(expected_shape,
                            outputs[last_level].numpy().shape)
    def test_serialize_deserialize(self):
        """Round-trip a BASNet model through its config and compare configs."""
        original = basnet_model.BASNetModel(
            backbone=basnet_model.BASNetEncoder(),
            decoder=basnet_model.BASNetDecoder(),
            refinement=refunet.RefUnet())

        # Rebuild a second model purely from the first model's config.
        restored = basnet_model.BASNetModel.from_config(original.get_config())

        # The rebuilt model must be convertible to JSON; only the absence of
        # an exception matters, so the returned string is discarded.
        restored.to_json()

        # A successful round trip leaves both models with matching configs.
        self.assertAllEqual(original.get_config(), restored.get_config())