def build_basnet_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: basnet_cfg.BASNetModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds BASNet model.

    Assembles three sub-networks and wraps them in a single
    `basnet_model.BASNetModel`:

    * a backbone from `backbones.factory.build_backbone`,
    * a decoder from `decoder_factory.build_decoder`, fed with the
      backbone's `output_specs`,
    * a `refunet.RefUnet` refinement module built with its defaults.

    Args:
      input_specs: Input specification forwarded to the backbone builder.
      model_config: BASNet model config forwarded to both the backbone and
        the decoder builders.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder builders. It is not passed to the refinement module.

    Returns:
      The `basnet_model.BASNetModel` built from the backbone, decoder and
      refinement module.
    """
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # The decoder consumes whatever endpoints the backbone declares.
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    refinement = refunet.RefUnet()

    # No segmentation head is attached: the model is composed of the
    # backbone, decoder and refinement module only.
    return basnet_model.BASNetModel(backbone, decoder, refinement)
def test_network_creation(self):
    """Test creation of the refinement output on top of the BASNet decoder.

    Runs a 224x224x3 input (batch size 1) through the BASNet encoder and
    decoder, feeds the decoder output stored under key '7' to `RefUnet`,
    and checks that the refined map is stored under key 'ref' with shape
    [1, input_size, input_size, 1].
    """
    input_size = 224
    tf.keras.backend.set_image_data_format('channels_last')

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    backbone = basnet_en.BASNet_En()
    network = basnet_de.BASNet_De(input_specs=backbone.output_specs)
    module = refunet.RefUnet()

    endpoints = backbone(inputs)
    sups = network(endpoints)
    sups['ref'] = module(sups['7'])

    # The refined output is stored under the literal key 'ref'. The previous
    # assertions looked it up through `str(ref)`, where `ref` is not defined
    # in this method, so the test raised NameError before asserting anything.
    self.assertIn('ref', sups)
    self.assertAllEqual(
        [1, input_size, input_size, 1],
        sups['ref'].shape.as_list())
def test_serialize_deserialize(self):
    """Checks that a RefUnet config round-trips through `from_config`."""
    # Config the module is expected to report from `get_config()`.
    # NOTE(review): these values are never passed to the constructor below,
    # so the first assertion only holds if they equal RefUnet's defaults --
    # confirm against the RefUnet signature.
    expected_config = {
        'input_specs': layers.InputSpec(shape=[None, None, None, 1]),
        'activation': 'relu',
        'use_sync_bn': False,
        'norm_momentum': 0.99,
        'norm_epsilon': 0.001,
        'kernel_initializer': 'VarianceScaling',
        'kernel_regularizer': None,
        'bias_regularizer': None,
    }

    refiner = refunet.RefUnet()
    self.assertEqual(refiner.get_config(), expected_config)

    # Rebuild a second module purely from the first module's config.
    restored = refunet.RefUnet.from_config(refiner.get_config())

    # A successful round trip leaves both configs identical.
    self.assertAllEqual(refiner.get_config(), restored.get_config())
# If the serialization was successful, the new config should match the old. self.assertAllEqual(module.get_config(), new_module.get_config()) """ if __name__ == '__main__': tf.test.main() """ input_size = 224 tf.keras.backend.set_image_data_format('channels_last') inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1) backbone = basnet_en.BASNet_En() network = basnet_de.BASNet_De(input_specs=backbone.output_specs) print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@") module = refunet.RefUnet() print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@") endpoints = backbone(inputs) print(endpoints) sups = network(endpoints) sups['ref'] = module(sups['7']) print(sups)
print(a)
print(a.backbone.basnet_en.model_id)
"""
# NOTE(review): only one triple-quote delimiter is visible in this excerpt.
# The comments below assume it closes a string opened earlier, which would
# make the statements that follow live code -- confirm against the full file.

# Builds the BASNet encoder, decoder and refinement module for a
# 224x224x3 channels-last input with an unspecified batch dimension.
tf.keras.backend.set_image_data_format('channels_last')
input_specs = tf.keras.layers.InputSpec(
    shape=[None] + [224, 224, 3])
backbone = backbones.BASNet_En(
    input_specs=input_specs)
decoder = decoders.BASNet_De(
    input_specs=backbone.output_specs)
refinement = refunet.RefUnet()

# Each component is pointed at the same hard-coded, machine-specific
# checkpoint path.
# NOTE(review): the Keras method for restoring weights is `load_weights`
# (plural). Unless these project classes define their own `load_weight`,
# the three calls below will raise AttributeError -- verify.
backbone.model.load_weight('/home/ghpark/ckpt_basnet/ckpt-274352')
decoder.load_weight('/home/ghpark/ckpt_basnet/ckpt-274352')
refinement.load_weight('/home/ghpark/ckpt_basnet/ckpt-274352')

# Wraps the three components in the full BASNet model.
model = basnet_model.BASNetModel(
    backbone=backbone,
    decoder=decoder,
    refinement=refinement
)
#model.load_weight('/home/ghpark/ckpt_basnet/ckpt-274352')