Example #1
0
    def _build_model(self):
        """Construct an image classification model from ``self.params``.

        The input spec is ``[self._batch_size] + self._input_image_size + [3]``,
        i.e. a fixed trailing (channel) dimension of 3. No L2 regularizer is
        attached.

        Returns:
          Whatever ``factory.build_classification_model`` returns for
          ``self.params.task.model``.
        """
        full_shape = [self._batch_size] + self._input_image_size + [3]
        spec = tf.keras.layers.InputSpec(shape=full_shape)
        model_cfg = self.params.task.model
        return factory.build_classification_model(
            input_specs=spec, model_config=model_cfg, l2_regularizer=None)
Example #2
0
    def build_model(self, skip_logits_layer=False):
        """Build the classification model, store it on ``self._model``, return it.

        The input spec is ``[self._batch_size] + self._input_image_size + [3]``
        and no L2 regularizer is used.

        Args:
          skip_logits_layer: Forwarded unchanged to
            ``factory.build_classification_model``. Presumably this omits the
            final logits layer when True; confirm against the factory.

        Returns:
          The newly built model (also cached as ``self._model``).
        """
        full_shape = [self._batch_size] + self._input_image_size + [3]
        self._model = factory.build_classification_model(
            input_specs=tf.keras.layers.InputSpec(shape=full_shape),
            model_config=self._params.task.model,
            l2_regularizer=None,
            skip_logits_layer=skip_logits_layer,
        )
        return self._model
Example #3
0
 def test_builder(self, backbone_type, input_size, weight_decay):
     """Smoke-test that the factory can build a 2-class classifier.

     The input spec is ``[None, input_size[0], input_size[1], 3]``. The
     backbone type is taken from the test parameters. An L2 regularizer is
     attached only when ``weight_decay`` is truthy. The returned model is
     discarded, so the test passes as long as construction does not raise.
     """
     spec = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
     config = classification_cfg.ImageClassificationModel(
         num_classes=2,
         backbone=backbones.Backbone(type=backbone_type))
     if weight_decay:
         regularizer = tf.keras.regularizers.l2(weight_decay)
     else:
         regularizer = None
     factory.build_classification_model(
         input_specs=spec, model_config=config, l2_regularizer=regularizer)
Example #4
0
def export_model_to_tfhub(params, batch_size, input_image_size,
                          skip_logits_layer, checkpoint_path, export_path,
                          num_channels=3):
    """Export an image classification model to TF-Hub.

    Builds the model described by ``params.task.model``, restores its weights
    from ``checkpoint_path`` and writes it to ``export_path`` with
    ``model.save(..., include_optimizer=False, save_format='tf')``.

    Args:
      params: Experiment config; only ``params.task.model`` is read.
      batch_size: Leading dimension of the model's input spec.
      input_image_size: Spatial dimensions of the input image. Any sequence
        of ints is accepted; it is converted to a list before use.
      skip_logits_layer: Forwarded unchanged to
        ``factory.build_classification_model``.
      checkpoint_path: Checkpoint to restore the model weights from.
      export_path: Destination passed to ``model.save``.
      num_channels: Size of the trailing (channel) input dimension. Defaults
        to 3, the value previously hard-coded.

    Raises:
      AssertionError: Raised by ``assert_existing_objects_matched`` if any
        object in the model was not matched by the checkpoint.
    """
    # list(...) makes tuples work and prevents a NumPy array from being
    # broadcast-added (instead of concatenated) by the ``+`` below.
    input_specs = tf.keras.layers.InputSpec(
        shape=[batch_size] + list(input_image_size) + [num_channels])

    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=params.task.model,
        l2_regularizer=None,
        skip_logits_layer=skip_logits_layer)

    checkpoint = tf.train.Checkpoint(model=model)
    # Strict restore: fail loudly rather than export partially-initialised
    # weights.
    checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
    model.save(export_path, include_optimizer=False, save_format='tf')
Example #5
0
    def build_model(self):
        """Build the classification model described by ``self.task_config``.

        The input spec has an unspecified (``None``) leading dimension
        followed by ``task_config.model.input_size``. An L2 regularizer is
        attached only when ``task_config.losses.l2_weight_decay`` is truthy.

        Returns:
          The model returned by ``factory.build_classification_model``.
        """
        spec_shape = [None] + self.task_config.model.input_size
        input_specs = tf.keras.layers.InputSpec(shape=spec_shape)

        decay = self.task_config.losses.l2_weight_decay
        l2_regularizer = None
        if decay:
            # Keras l2(c) adds c * sum(w ** 2), while tf.nn.l2_loss is
            # sum(w ** 2) / 2. Halving c therefore yields decay * l2_loss.
            # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
            # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
            l2_regularizer = tf.keras.regularizers.l2(decay / 2.0)

        return factory.build_classification_model(
            input_specs=input_specs,
            model_config=self.task_config.model,
            l2_regularizer=l2_regularizer)
def create_classification_export_module(params: cfg.ExperimentConfig,
                                        input_type: str,
                                        batch_size: int,
                                        input_image_size: List[int],
                                        num_channels: int = 3):
    """Creates a classification export module.

    Args:
      params: Experiment config. ``params.task.model`` describes the model to
        build, and ``params`` itself is passed to ``export_base.ExportModule``.
      input_type: Serving input format, forwarded to ``export_utils``. When it
        is ``'tflite'`` the parsed image tensor is fed to the model without
        the per-image inference preprocessing.
      batch_size: Leading dimension of the input signature and input spec.
      input_image_size: Spatial size of the image (presumably
        ``[height, width]``; confirm with callers). Any sequence of ints is
        accepted; it is converted to a list once and used everywhere below.
      num_channels: Size of the trailing (channel) image dimension.

    Returns:
      An ``export_base.ExportModule`` wrapping the model. Its postprocessor
      returns ``{'logits': logits, 'probs': softmax(logits)}``.
    """
    # Normalise once: with a tuple the ``+`` concatenations below would raise,
    # and with a NumPy array they would silently broadcast-add instead.
    image_size = list(input_image_size)

    input_signature = export_utils.get_image_input_signatures(
        input_type, batch_size, image_size, num_channels)
    input_specs = tf.keras.layers.InputSpec(
        shape=[batch_size] + image_size + [num_channels])

    model = factory.build_classification_model(input_specs=input_specs,
                                               model_config=params.task.model,
                                               l2_regularizer=None)

    def preprocess_fn(inputs):
        """Parse serving inputs and, unless tflite, preprocess each image."""
        image_tensor = export_utils.parse_image(inputs, input_type,
                                                image_size, num_channels)
        # If input_type is `tflite`, do not apply image preprocessing.
        if input_type == 'tflite':
            return image_tensor

        def preprocess_image_fn(image):
            # Named ``image`` so it does not shadow the outer ``inputs``.
            return classification_input.Parser.inference_fn(
                image, image_size, num_channels)

        # Preprocess every image in the batch; each result is float32 with
        # shape image_size + [num_channels].
        return tf.map_fn(preprocess_image_fn,
                         elems=image_tensor,
                         fn_output_signature=tf.TensorSpec(
                             shape=image_size + [num_channels],
                             dtype=tf.float32))

    def postprocess_fn(logits):
        """Return the raw logits alongside their softmax probabilities."""
        probs = tf.nn.softmax(logits)
        return {'logits': logits, 'probs': probs}

    export_module = export_base.ExportModule(params,
                                             model=model,
                                             input_signature=input_signature,
                                             preprocessor=preprocess_fn,
                                             postprocessor=postprocess_fn)
    return export_module