Esempio n. 1
0
  def test_model_creation(self, project_dim, num_proj_layers, ft_proj_idx):
    """Test creation of a SimCLR model in pretrain mode.

    Builds a ResNet-50 backbone with a projection head and a supervised
    classification head, runs a random batch through the model and checks the
    shapes of the projection and supervised outputs.

    Args:
      project_dim: output dimension passed to the projection head.
      num_proj_layers: number of projection layers passed to the head.
      ft_proj_idx: forwarded to `ProjectionHead` as `ft_proj_idx`.
    """
    batch_size = 2
    input_size = 224
    num_classes = 10
    inputs = np.random.rand(batch_size, input_size, input_size, 3)
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size, input_size, 3])

    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.ResNet(model_id=50, activation='relu',
                                input_specs=input_specs)
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=project_dim,
        num_proj_layers=num_proj_layers,
        ft_proj_idx=ft_proj_idx
    )
    # Build the head from the same `num_classes` the assertion below checks,
    # so the two cannot silently drift apart.
    supervised_head = simclr_head.ClassificationHead(
        num_classes=num_classes
    )

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=simclr_model.PRETRAIN
    )
    outputs = model(inputs)
    projection_outputs = outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    supervised_outputs = outputs[simclr_model.SUPERVISED_OUTPUT_KEY]

    self.assertAllEqual(projection_outputs.shape.as_list(),
                        [batch_size, project_dim])
    self.assertAllEqual([batch_size, num_classes],
                        supervised_outputs.numpy().shape)
Esempio n. 2
0
    def test_resnet_creation(self, model_id):
        """Builds a ResNet directly and through the factory; configs must match.

        Args:
          model_id: ResNet depth used by both construction paths.
        """
        momentum = 0.99
        epsilon = 1e-5

        direct_backbone = backbones.ResNet(
            model_id=model_id,
            se_ratio=0.0,
            norm_momentum=momentum,
            norm_epsilon=epsilon)

        resnet_cfg = backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0)
        backbone_cfg = backbones_cfg.Backbone(type='resnet', resnet=resnet_cfg)
        norm_cfg = common_cfg.NormActivation(
            norm_momentum=momentum,
            norm_epsilon=epsilon,
            use_sync_bn=False)

        built_backbone = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            backbone_config=backbone_cfg,
            norm_activation_config=norm_cfg)

        expected_config = direct_backbone.get_config()
        actual_config = built_backbone.get_config()
        self.assertEqual(expected_config, actual_config)
    def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
        """Runs a ResNet-50 classifier forward pass inside `strategy.scope()`.

        NOTE(review): the strategies are presumably TPU/GPU strategies supplied
        by the test parameterization — confirm against the decorator.

        Args:
          strategy: distribution strategy whose scope wraps model creation.
          use_sync_bn: whether the backbone uses synchronized batch norm.
        """
        batch = np.random.rand(64, 128, 128, 3)

        tf.keras.backend.set_image_data_format('channels_last')

        with strategy.scope():
            classifier = classification_model.ClassificationModel(
                backbone=backbones.ResNet(model_id=50, use_sync_bn=use_sync_bn),
                num_classes=1000,
                dropout_rate=0.2,
            )
            # Only checks that the forward pass succeeds; output is unused.
            classifier(batch)
  def test_serialize_deserialize(self):
    """Round-trips a ResNet-50 classification model through its config.

    The model rebuilt via `from_config` must be convertible to JSON and must
    report the same config as the original.
    """
    tf.keras.backend.set_image_data_format('channels_last')

    original = classification_model.ClassificationModel(
        backbone=backbones.ResNet(model_id=50), num_classes=1000)
    restored = classification_model.ClassificationModel.from_config(
        original.get_config())

    # Fails if any config value cannot be serialized to JSON.
    restored.to_json()

    self.assertAllEqual(original.get_config(), restored.get_config())
    def test_data_format_gpu(self, strategy, data_format, input_dim):
        """Test for different data formats on GPU devices.

        Builds a ResNet-50 classifier for `data_format` inside
        `strategy.scope()` and runs a forward pass on a random batch of two
        128x128 images.

        Args:
          strategy: distribution strategy whose scope wraps model creation.
          data_format: 'channels_last' selects NHWC inputs; any other value
            builds NCHW inputs.
          input_dim: number of input channels.
        """
        if data_format == 'channels_last':
            inputs = np.random.rand(2, 128, 128, input_dim)
        else:
            inputs = np.random.rand(2, input_dim, 128, 128)
        input_specs = tf.keras.layers.InputSpec(shape=inputs.shape)

        # The image data format is process-global Keras state. Restore the
        # previous value afterwards so a 'channels_first' run does not leak
        # into other tests that rely on the default.
        previous_format = tf.keras.backend.image_data_format()
        tf.keras.backend.set_image_data_format(data_format)
        try:
            with strategy.scope():
                backbone = backbones.ResNet(model_id=50,
                                            input_specs=input_specs)

                model = classification_model.ClassificationModel(
                    backbone=backbone,
                    num_classes=1000,
                    input_specs=input_specs,
                )
                _ = model(inputs)
        finally:
            tf.keras.backend.set_image_data_format(previous_format)
Esempio n. 6
0
    def test_serialize_deserialize(self):
        """Round-trips a ResNet-50 + FPN segmentation model through its config.

        The model rebuilt via `from_config` must be convertible to JSON and
        must report the same config as the original.
        """
        num_classes = 3
        backbone = backbones.ResNet(model_id=50)
        model = segmentation_model.SegmentationModel(
            backbone=backbone,
            decoder=fpn.FPN(input_specs=backbone.output_specs,
                            min_level=3,
                            max_level=7),
            head=segmentation_heads.SegmentationHead(num_classes, level=3))

        restored = segmentation_model.SegmentationModel.from_config(
            model.get_config())

        # Fails if any config value cannot be serialized to JSON.
        restored.to_json()

        self.assertAllEqual(model.get_config(), restored.get_config())
    def test_resnet_network_creation(self, input_size, resnet_model_id,
                                     activation):
        """Checks parameter counts and logits shape of a ResNet classifier.

        NOTE(review): the expected parameter counts below are hard-coded and
        presumably only hold for `resnet_model_id == 50` — confirm against
        the test parameterization.

        Args:
          input_size: height and width of the random square input images.
          resnet_model_id: ResNet depth for the backbone.
          activation: activation name passed to the backbone.
        """
        batch_size, num_classes = 2, 1000
        images = np.random.rand(batch_size, input_size, input_size, 3)

        tf.keras.backend.set_image_data_format('channels_last')

        backbone = backbones.ResNet(model_id=resnet_model_id,
                                    activation=activation)
        self.assertEqual(backbone.count_params(), 23561152)

        classifier = classification_model.ClassificationModel(
            backbone=backbone,
            num_classes=num_classes,
            dropout_rate=0.2,
        )
        self.assertEqual(classifier.count_params(), 25610152)

        self.assertAllEqual([batch_size, num_classes],
                            classifier(images).numpy().shape)
Esempio n. 8
0
    def test_segmentation_network_creation(self, input_size, level):
        """Checks the logits shape of a ResNet-50 + FPN segmentation model.

        Each spatial side of the logits is expected to be
        `input_size // 2**level`.

        Args:
          input_size: height and width of the random square input images.
          level: feature level passed to the segmentation head.
        """
        num_classes = 10
        inputs = np.random.rand(2, input_size, input_size, 3)
        tf.keras.backend.set_image_data_format('channels_last')

        backbone = backbones.ResNet(model_id=50)
        model = segmentation_model.SegmentationModel(
            backbone=backbone,
            decoder=fpn.FPN(input_specs=backbone.output_specs,
                            min_level=2,
                            max_level=7),
            head=segmentation_heads.SegmentationHead(num_classes, level=level))

        side = input_size // 2**level
        expected_shape = [2, side, side, num_classes]
        self.assertAllEqual(expected_shape, model(inputs).numpy().shape)
Esempio n. 9
0
def build_backbone(input_specs: tf.keras.layers.InputSpec,
                   model_config,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds backbone from a config.

    Args:
      input_specs: tf.keras.layers.InputSpec.
      model_config: a OneOfConfig. Model config. Must expose `backbone.type`,
        `backbone.get()` and `norm_activation`; SpineNet additionally reads
        `min_level` and `max_level`.
      l2_regularizer: tf.keras.regularizers.Regularizer instance used as the
        backbone's kernel regularizer. Defaults to None.

    Returns:
      tf.keras.Model instance of the backbone.

    Raises:
      ValueError: if the backbone type is not one of 'resnet', 'efficientnet',
        'spinenet' or 'revnet', or if a SpineNet model id is not present in
        `spinenet.SCALING_MAP`.
    """
    backbone_type = model_config.backbone.type
    # Reject unknown types before touching the rest of the config.
    if backbone_type not in ('resnet', 'efficientnet', 'spinenet', 'revnet'):
        raise ValueError(
            'Backbone {!r} not implemented'.format(backbone_type))

    backbone_cfg = model_config.backbone.get()
    norm_activation_config = model_config.norm_activation
    # Normalization/activation settings shared by every backbone type.
    norm_kwargs = dict(
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
    )

    if backbone_type == 'resnet':
        return backbones.ResNet(
            model_id=backbone_cfg.model_id,
            input_specs=input_specs,
            kernel_regularizer=l2_regularizer,
            **norm_kwargs)

    if backbone_type == 'efficientnet':
        return backbones.EfficientNet(
            model_id=backbone_cfg.model_id,
            input_specs=input_specs,
            stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
            se_ratio=backbone_cfg.se_ratio,
            kernel_regularizer=l2_regularizer,
            **norm_kwargs)

    if backbone_type == 'spinenet':
        model_id = backbone_cfg.model_id
        if model_id not in spinenet.SCALING_MAP:
            raise ValueError(
                'SpineNet-{} is not a valid architecture.'.format(model_id))
        scaling_params = spinenet.SCALING_MAP[model_id]

        return backbones.SpineNet(
            input_specs=input_specs,
            min_level=model_config.min_level,
            max_level=model_config.max_level,
            endpoints_num_filters=scaling_params['endpoints_num_filters'],
            resample_alpha=scaling_params['resample_alpha'],
            block_repeats=scaling_params['block_repeats'],
            filter_size_scale=scaling_params['filter_size_scale'],
            kernel_regularizer=l2_regularizer,
            **norm_kwargs)

    # Only 'revnet' remains after the membership check above.
    return backbones.RevNet(
        model_id=backbone_cfg.model_id,
        input_specs=input_specs,
        kernel_regularizer=l2_regularizer,
        **norm_kwargs)