Beispiel #1
0
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: basnet_cfg.BASNetModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds a BASNet model from its backbone, decoder and refinement parts.

  Args:
    input_specs: Input spec handed to the backbone factory.
    model_config: BASNet model config; forwarded to both the backbone and the
      decoder factories.
    l2_regularizer: Optional regularizer forwarded to the backbone and decoder
      factories.

  Returns:
    A `basnet_model.BASNetModel` composed of the backbone, the decoder and a
    `refunet.RefUnet` refinement module.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # The decoder consumes the backbone's endpoints, so it is configured from
  # the backbone's output specs rather than the raw input specs.
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # NOTE(review): the refinement module is built with its default arguments,
  # so neither `model_config` nor `l2_regularizer` reaches it — confirm this
  # is intended.
  refinement = refunet.RefUnet()

  return basnet_model.BASNetModel(backbone, decoder, refinement)
Beispiel #2
0
    def test_network_creation(self):
        """Tests that the BASNet encoder, decoder and RefUnet compose.

        A 224x224 RGB symbolic input is run through the encoder and decoder,
        the decoder output stored under key '7' is refined by RefUnet, and the
        refined output (stored under key 'ref') is checked for shape.
        """
        input_size = 224
        tf.keras.backend.set_image_data_format('channels_last')

        inputs = tf.keras.Input(shape=(input_size, input_size, 3),
                                batch_size=1)

        backbone = basnet_en.BASNet_En()
        network = basnet_de.BASNet_De(input_specs=backbone.output_specs)
        module = refunet.RefUnet()

        endpoints = backbone(inputs)
        sups = network(endpoints)
        sups['ref'] = module(sups['7'])

        # The refined map is stored under the string key 'ref' (see above).
        self.assertIn('ref', sups)
        # Expected: single-channel output at full input resolution.
        self.assertAllEqual([1, input_size, input_size, 1],
                            sups['ref'].shape.as_list())
Beispiel #3
0
    def test_serialize_deserialize(self):
        """Tests that RefUnet's config round-trips through `from_config`."""
        # Create a network object that sets all of its config options.
        kwargs = dict(
            input_specs=layers.InputSpec(shape=[None, None, None, 1]),
            activation='relu',
            use_sync_bn=False,
            norm_momentum=0.99,
            norm_epsilon=0.001,
            kernel_initializer='VarianceScaling',
            kernel_regularizer=None,
            bias_regularizer=None,
        )
        # The options must actually be passed in; otherwise get_config()
        # reports constructor defaults (including a different InputSpec
        # object) and can never equal `kwargs`.
        module = refunet.RefUnet(**kwargs)

        expected_config = dict(kwargs)
        self.assertEqual(module.get_config(), expected_config)

        # Create another network object from the first object's config.
        new_module = refunet.RefUnet.from_config(module.get_config())

        # If the serialization was successful, the new config should match the old.
        self.assertEqual(module.get_config(), new_module.get_config())
Beispiel #4
0
        # If the serialization was successful, the new config should match the old.
        self.assertAllEqual(module.get_config(), new_module.get_config())


"""
if __name__ == '__main__':
  tf.test.main()
"""

input_size = 224
tf.keras.backend.set_image_data_format('channels_last')

inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)

backbone = basnet_en.BASNet_En()

network = basnet_de.BASNet_De(input_specs=backbone.output_specs)
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
module = refunet.RefUnet()

print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
endpoints = backbone(inputs)

print(endpoints)

sups = network(endpoints)

sups['ref'] = module(sups['7'])

print(sups)
Beispiel #5
0
print(a)
print(a.backbone.basnet_en.model_id)
"""


tf.keras.backend.set_image_data_format('channels_last')

# Fixed-size RGB input; the batch dimension is left unspecified.
input_specs = tf.keras.layers.InputSpec(shape=[None, 224, 224, 3])

backbone = backbones.BASNet_En(input_specs=input_specs)
decoder = decoders.BASNet_De(input_specs=backbone.output_specs)
refinement = refunet.RefUnet()

model = basnet_model.BASNetModel(
    backbone=backbone,
    decoder=decoder,
    refinement=refinement,
)

# Training checkpoint to restore into the assembled model.
checkpoint_path = '/home/ghpark/ckpt_basnet/ckpt-274352'

# Keras models expose `load_weights`, not `load_weight`, so per-submodule
# `load_weight(...)` calls fail with AttributeError. A checkpoint written
# during training is object-based, so it is restored once through
# tf.train.Checkpoint; variables that are not created yet are restored
# lazily when the model is first built.
# NOTE(review): assumes the trainer saved the network under the name `model`
# (tf.train.Checkpoint(model=...)) — confirm against the training code.
# expect_partial() silences warnings about checkpoint entries (e.g. optimizer
# slots) that have no counterpart here.
restore_status = tf.train.Checkpoint(model=model).restore(checkpoint_path)
restore_status.expect_partial()