Ejemplo n.º 1
0
  def __init__(self,
               data,
               model_export_format,
               model_spec,
               shuffle=True,
               train_whole_model=False,
               validation_ratio=0.1,
               test_ratio=0.1,
               hparams=None):
    """Init function for ImageClassifier class.

    Splits the raw input data into train/validation/test sets and builds the
    classification model on top of the Hub module referenced by `model_spec`.

    Args:
      data: Raw data to be split for training / validation / testing. It must
        provide `split(fraction, shuffle=...)` and `num_classes`; the split
        results must provide `size`.
      model_export_format: Model export format such as saved_model / tflite.
        Only `mef.ModelExportFormat.TFLITE` is currently accepted.
      model_spec: Specification for the model; its `uri` and
        `input_image_shape` are used to build the classifier.
      shuffle: Whether the data should be shuffled when it is split.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
      validation_ratio: Fraction of the whole of `data` used for validation.
      test_ratio: Fraction of the whole of `data` used for testing.
      hparams: A namedtuple of hyperparameters passed to `lib.build_model`.
        This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        When None, `lib.get_default_hparams()` is used (resolved at call time).

    Raises:
      ValueError: If `model_export_format` is not TFLITE, if either ratio is
        negative, if `validation_ratio + test_ratio >= 1.0`, or if the
        resulting training set is empty.
    """
    super(ImageClassifier,
          self).__init__(data, model_export_format, model_spec, shuffle,
                         train_whole_model, validation_ratio, test_ratio)

    # Gets pre_trained models.
    if model_export_format != mef.ModelExportFormat.TFLITE:
      raise ValueError('Model export mode %s is not supported currently.' %
                       str(model_export_format))
    self.pre_trained_model_spec = model_spec

    # Generates training, validation and testing data.
    if validation_ratio < 0 or test_ratio < 0:
      raise ValueError(
          'The ratios for validation and test data should be non-negative.')
    if validation_ratio + test_ratio >= 1.0:
      raise ValueError(
          'The total ratio for validation and test data should be less than 1.0.'
      )

    # Assumes `split(f)` returns (first `f` fraction, remainder) -- the same
    # reading the ratio guard above relies on.
    self.validation_data, rest_data = data.split(
        validation_ratio, shuffle=shuffle)
    # `test_ratio` is a fraction of the WHOLE dataset, but `rest_data` only
    # holds (1 - validation_ratio) of it, so rescale before splitting again.
    # The guards above ensure the denominator is strictly positive.
    self.test_data, self.train_data = rest_data.split(
        test_ratio / (1.0 - validation_ratio), shuffle=shuffle)

    # Checks dataset parameter.
    if self.train_data.size == 0:
      raise ValueError('Training dataset is empty.')

    # Resolve default hyperparameters per call rather than once at def time.
    if hparams is None:
      hparams = lib.get_default_hparams()

    # Creates the classifier model for retraining.
    module_layer = hub.KerasLayer(
        self.pre_trained_model_spec.uri, trainable=train_whole_model)
    self.model = lib.build_model(module_layer, hparams,
                                 self.pre_trained_model_spec.input_image_shape,
                                 data.num_classes)
Ejemplo n.º 2
0
  def _create_model(self, hparams=None):
    """Builds the Keras classifier model that will be retrained.

    Args:
      hparams: Optional hyperparameters; resolved through
        `self._get_hparams_or_default` before use.

    Returns:
      The model produced by `lib.build_model`, stacked on a
      `hub_loader.HubKerasLayerV1V2` that wraps `self.model_spec.uri`.
    """
    resolved = self._get_hparams_or_default(hparams)
    spec = self.model_spec

    # The hub layer is trainable only when fine-tuning is requested.
    feature_extractor = hub_loader.HubKerasLayerV1V2(
        spec.uri, trainable=resolved.do_fine_tuning)
    return lib.build_model(
        feature_extractor,
        resolved,
        spec.input_image_shape,
        self.num_classes,
    )
Ejemplo n.º 3
0
  def _create_model(self, hparams=None):
    """Builds the Keras classifier model that will be retrained.

    Args:
      hparams: Optional hyperparameters. When None, the instance's
        `self.hparams` is used instead.

    Returns:
      The model produced by `lib.build_model`, stacked on a `hub.KerasLayer`
      that wraps `self.model_spec.uri`.
    """
    effective_hparams = self.hparams if hparams is None else hparams
    spec = self.model_spec

    # The fine-tuning flag decides whether the hub module's weights update.
    hub_layer = hub.KerasLayer(
        spec.uri, trainable=effective_hparams.do_fine_tuning)
    return lib.build_model(
        hub_layer,
        effective_hparams,
        spec.input_image_shape,
        self.data.num_classes,
    )
Ejemplo n.º 4
0
    def create_model(self, hparams=None, with_loss_and_metrics=False):
        """Builds the retraining classifier and stores it on `self.model`.

        Args:
          hparams: Optional hyperparameters; resolved through
            `self._get_hparams_or_default` before use.
          with_loss_and_metrics: If true, `self.model` is additionally
            compiled with a categorical cross-entropy loss (label smoothing
            0.1) and an 'accuracy' metric.
        """
        resolved = self._get_hparams_or_default(hparams)
        spec = self.model_spec

        # Trainability of the hub layer follows the fine-tuning flag.
        feature_extractor = hub_loader.HubKerasLayerV1V2(
            spec.uri, trainable=resolved.do_fine_tuning)
        self.model = hub_lib.build_model(
            feature_extractor,
            resolved,
            spec.input_image_shape,
            self.num_classes,
        )

        if not with_loss_and_metrics:
            return
        # Only loss and metrics are configured here; compile() receives no
        # explicit optimizer argument.
        smoothed_loss = tf.keras.losses.CategoricalCrossentropy(
            label_smoothing=0.1)
        self.model.compile(loss=smoothed_loss, metrics=['accuracy'])