def _test_export_to_tflite(self, model):
    """Round-trips `model` through a TFLite export and checks its predictions.

    The model and its label file are exported into the test's temporary
    directory. The label file must contain exactly cyan, magenta and yellow.
    Then, for each (name, rgb) entry of `self.CMY_NAMES_AND_RGB_VALUES`, an
    image filled with `rgb` is run through `model.preprocess` and the exported
    TFLite model. The label with the highest score must equal `name`.

    Args:
      model: The model under test; it must provide `export` and `preprocess`.
    """
    model_path = os.path.join(self.get_temp_dir(), 'model.tflite')
    label_path = os.path.join(self.get_temp_dir(), 'label')
    model.export(model_path, label_path)

    labels = self._load_labels(label_path)
    self.assertEqual(labels, ['cyan', 'magenta', 'yellow'])
    lite_model = self._load_lite_model(model_path)

    def assert_predicts(expected_name, preprocessed_image):
      # The arg-max over the first row of the model output picks the label.
      scores = lite_model(preprocessed_image)
      predicted_name = labels[np.argmax(scores[0])]
      self.assertEqual(expected_name, predicted_name)

    if compat.get_tf_behavior() != 1:
      # Non-TF1 behavior: `preprocess` returns tensors readable via `.numpy()`.
      for index, (name, color) in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
        batch = np.expand_dims(_fill_image(color, self.IMAGE_SIZE), 0)
        preprocessed, _ = model.preprocess(batch, index)
        assert_predicts(name, preprocessed.numpy())
      return

    # TF1 behavior: build the preprocessing graph once on placeholders, then
    # evaluate it for every sample inside a single session.
    image_input = tf.compat.v1.placeholder(
        tf.uint8, [1, self.IMAGE_SIZE, self.IMAGE_SIZE, 3])
    label_input = tf.compat.v1.placeholder(tf.int32, [1])
    preprocessed_tensor, _ = model.preprocess(image_input, label_input)
    with tf.compat.v1.Session() as session:
      for index, (name, color) in enumerate(self.CMY_NAMES_AND_RGB_VALUES):
        batch = np.expand_dims(_fill_image(color, self.IMAGE_SIZE), 0)
        preprocessed = session.run(
            preprocessed_tensor,
            feed_dict={image_input: batch, label_input: [index]})
        assert_predicts(name, preprocessed)
  def _export_tflite(self, tflite_filename, label_filename, quantized=False):
    """Converts the retrained model to tflite format and saves it.

    Under TF1 behavior, `self.model` is first saved as a SavedModel in a
    temporary directory and converted from there. Otherwise it is converted
    directly from the in-memory Keras model. The labels in
    `self.index_to_label` are written to `label_filename`, one per line.

    Args:
      tflite_filename: File name to save tflite model.
      label_filename: File name to save labels.
      quantized: boolean, if True, save quantized model.
    """
    import contextlib  # pylint: disable=g-import-not-at-top

    # Keep the temporary SavedModel directory alive until `convert()` has
    # finished. In TF 2.x releases, the v1 `from_saved_model` may return a
    # converter that reads the SavedModel directory during `convert()`.
    # Converting after the `with` block exits would then read a deleted path.
    with contextlib.ExitStack() as stack:
      if compat.get_tf_behavior() == 1:
        temp_dir = stack.enter_context(tempfile.TemporaryDirectory())
        save_path = os.path.join(temp_dir, 'saved_model')
        self.model.save(save_path, include_optimizer=False, save_format='tf')
        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
            save_path)
      else:
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)

      if quantized:
        converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
      tflite_model = converter.convert()

    with tf.io.gfile.GFile(tflite_filename, 'wb') as f:
      f.write(tflite_model)

    with tf.io.gfile.GFile(label_filename, 'w') as f:
      f.write('\n'.join(self.index_to_label))

    tf.compat.v1.logging.info('Export to tflite model %s, saved labels in %s.',
                              tflite_filename, label_filename)

    # NOTE(review): the definition below is never called or returned by the
    # enclosing method, so it has no effect here. It looks like a newer
    # version of this method (with integer post-training quantization) that
    # was pasted into this body. Confirm which version is intended, then
    # either hoist it to class level in place of the outer one or delete it.
    def _export_tflite(self,
                       tflite_filename,
                       label_filename,
                       quantized=False,
                       quantization_steps=None,
                       representative_data=None):
      """Converts the retrained model to tflite format and saves it.

      When `quantized` is True, the converter gets a representative dataset
      built from `representative_data` and `tf.lite.Optimize.DEFAULT`. It is
      restricted to TFLITE_BUILTINS_INT8 ops and uses uint8 input/output
      types.

      Args:
        tflite_filename: File name to save tflite model.
        label_filename: File name to save labels.
        quantized: boolean, if True, save quantized model.
        quantization_steps: Number of post-training quantization calibration
          steps to run. Used only if `quantized` is True. When None, falls
          back to DEFAULT_QUANTIZATION_STEPS.
        representative_data: Representative data used for post-training
          quantization. Used only if `quantized` is True, in which case it is
          required.

      Raises:
        ValueError: if `quantized` is True and `representative_data` is None.
      """
      import contextlib  # pylint: disable=g-import-not-at-top

      # Validate the arguments before doing any expensive save or conversion.
      if quantized:
        if representative_data is None:
          raise ValueError(
              'representative_data couldn\'t be None if model is quantized.')
        if quantization_steps is None:
          quantization_steps = DEFAULT_QUANTIZATION_STEPS

      # Keep the temporary SavedModel directory alive until `convert()` has
      # finished, since the v1 converter may read it during conversion.
      with contextlib.ExitStack() as stack:
        if compat.get_tf_behavior() == 1:
          temp_dir = stack.enter_context(tempfile.TemporaryDirectory())
          save_path = os.path.join(temp_dir, 'saved_model')
          self.model.save(save_path,
                          include_optimizer=False,
                          save_format='tf')
          converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
              save_path)
        else:
          converter = tf.lite.TFLiteConverter.from_keras_model(self.model)

        if quantized:
          ds = self._gen_dataset(representative_data,
                                 batch_size=1,
                                 is_training=False)
          converter.representative_dataset = tf.lite.RepresentativeDataset(
              get_representative_dataset_gen(ds, quantization_steps))

          converter.optimizations = [tf.lite.Optimize.DEFAULT]
          converter.inference_input_type = tf.uint8
          converter.inference_output_type = tf.uint8
          converter.target_spec.supported_ops = [
              tf.lite.OpsSet.TFLITE_BUILTINS_INT8
          ]
        tflite_model = converter.convert()

      with tf.io.gfile.GFile(tflite_filename, 'wb') as f:
        f.write(tflite_model)

      with tf.io.gfile.GFile(label_filename, 'w') as f:
        f.write('\n'.join(self.index_to_label))

      tf.compat.v1.logging.info(
          'Export to tflite model %s, saved labels in %s.', tflite_filename,
          label_filename)
예제 #4
0
def create(train_data,
           model_export_format=mef.ModelExportFormat.TFLITE,
           model_spec=None,
           shuffle=False,
           batch_size=32,
           epochs=2,
           validation_data=None):
    """Creates a TextClassifier and trains it on already-loaded data.

    Args:
      train_data: Raw data for training. Its `index_to_label` and
        `num_classes` attributes configure the classifier.
      model_export_format: Model export format such as saved_model / tflite.
      model_spec: Specification for the model. When None (the default), a new
        `ms.AverageWordVecModelSpec()` is created for this call.
      shuffle: Whether the data should be shuffled.
      batch_size: Batch size for training.
      epochs: Number of epochs for training.
      validation_data: Validation data. If None, skips validation process.

    Returns:
      The trained TextClassifier.

    Raises:
      ValueError: if the current TF behavior version is not listed in
        `model_spec.compat_tf_versions`.
    """
    # Build the default spec per call. An instance created in the signature
    # is evaluated once at import time and shared, and possibly mutated,
    # across every call that relies on the default.
    if model_spec is None:
        model_spec = ms.AverageWordVecModelSpec()

    tf_version = compat.get_tf_behavior()
    if tf_version not in model_spec.compat_tf_versions:
        raise ValueError(
            'Incompatible versions. Expect {}, but got {}.'.format(
                model_spec.compat_tf_versions, tf_version))

    text_classifier = TextClassifier(train_data,
                                     model_export_format,
                                     model_spec,
                                     train_data.index_to_label,
                                     train_data.num_classes,
                                     shuffle=shuffle)

    tf.compat.v1.logging.info('Retraining the models...')
    text_classifier.train(train_data, validation_data, epochs, batch_size)

    return text_classifier
예제 #5
0
def create(train_data,
           model_export_format=mef.ModelExportFormat.TFLITE,
           model_spec=ms.mobilenet_v2_spec,
           shuffle=False,
           validation_data=None,
           batch_size=None,
           epochs=None,
           train_whole_model=None,
           dropout_rate=None,
           learning_rate=None,
           momentum=None):
    """Builds an ImageClassifier and retrains it on already-loaded data.

    Any hyperparameter argument left as None keeps the value returned by
    `lib.get_default_hparams()`.

    Args:
      train_data: Training data. Its `index_to_label` and `num_classes`
        attributes configure the classifier.
      model_export_format: Model export format such as saved_model / tflite.
      model_spec: Specification for the model.
      shuffle: Whether the data should be shuffled.
      validation_data: Validation data. If None, skips validation process.
      batch_size: Number of samples per training step.
      epochs: Number of epochs for training (the `train_epochs` hparam).
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top (the `do_fine_tuning` hparam). Otherwise,
        only the top classification layer is trained.
      dropout_rate: the rate for dropout.
      learning_rate: a Python float forwarded to the optimizer.
      momentum: a Python float forwarded to the optimizer.

    Returns:
      An instance of ImageClassifier class.

    Raises:
      ValueError: if the current TF behavior version is not listed in
        `model_spec.compat_tf_versions`.
    """
    is_compatible = compat.get_tf_behavior() in model_spec.compat_tf_versions
    if not is_compatible:
        raise ValueError(
            'Incompatible versions. Expect {}, but got {}.'.format(
                model_spec.compat_tf_versions, compat.get_tf_behavior()))

    # Map each caller-facing argument onto its field in the hyperparameters
    # used by TF Hub's make_image_classifier. None means "keep the default".
    requested = {
        'batch_size': batch_size,
        'train_epochs': epochs,
        'do_fine_tuning': train_whole_model,
        'dropout_rate': dropout_rate,
        'learning_rate': learning_rate,
        'momentum': momentum,
    }
    overrides = {
        field: value for field, value in requested.items() if value is not None
    }

    hparams = lib.get_default_hparams()
    if overrides:
        hparams = hparams._replace(**overrides)

    classifier = ImageClassifier(model_export_format,
                                 model_spec,
                                 train_data.index_to_label,
                                 train_data.num_classes,
                                 shuffle=shuffle,
                                 hparams=hparams)

    tf.compat.v1.logging.info('Retraining the models...')
    classifier.train(train_data, validation_data)
    return classifier