Beispiel #1
0
    def test_export_constant(self):
        """Exporting a module constant registers exactly one matching Symbol."""
        FOO = 1  # pylint: disable=invalid-name,unused-variable
        api_util.mm_export('foo.FOO').export_constant(__name__, 'FOO')

        # NOTE(review): assumes the registry starts empty for each test
        # (e.g. cleared in setUp, which is not visible here).
        self.assertLen(api_util.NAME_TO_SYMBOL, 1)
        # The module field mirrors the module name handed to export_constant
        # above. Use `__name__` instead of a hard-coded '__main__' so the test
        # also passes when a test runner imports this file under its real name.
        expected = api_util.Symbol('foo.FOO', ['foo', 'FOO'], None, __name__,
                                   'FOO')
        self.assertEqual(api_util.NAME_TO_SYMBOL['foo.FOO'], expected)
Beispiel #2
0
    def test_call_global_function(self):
        """Exporting a function registers it with the expected import line."""

        def test_func():
            """Func to test export."""

        # Dotted export name: the exporter must hand back the function itself
        # (so it is usable as a decorator), the import uses the function's own
        # name, and the package is the dotted prefix.
        exporter = api_util.mm_export('foo.bar.test_func')
        ret_func = exporter(test_func)
        self.assertEqual(ret_func, test_func)
        func = api_util.NAME_TO_SYMBOL['foo.bar.test_func']
        # test_func is defined in this module, so its import source is
        # `__name__`; hard-coding '__main__' only works when run as a script.
        self.assertEqual(func.gen_import(),
                         'from {} import test_func'.format(__name__))
        self.assertEqual(func.get_package_name(), 'foo.bar')

        # Export name differs from the function name and has no package: the
        # import line is expected to alias it and the package to be empty.
        exporter = api_util.mm_export('fn')
        exporter(test_func)
        func = api_util.NAME_TO_SYMBOL['fn']
        self.assertEqual(func.gen_import(),
                         'from {} import test_func as fn'.format(__name__))
        self.assertEqual(func.get_package_name(), '')
Beispiel #3
0
      tflite_model = converter.convert()

      with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
        f.write(tflite_model)


# EfficientDet-Lite0: an EfficientDetModelSpec factory with the lite0 model
# name and its TF Hub feature-vector URI pre-bound.
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1')
# Attach a docstring to the partial; util.wrap_doc (not visible here)
# presumably combines this summary with EfficientDetModelSpec's docs.
efficientdet_lite0_spec.__doc__ = util.wrap_doc(
    EfficientDetModelSpec,
    'Creates EfficientDet-Lite0 model spec. See also: '
    '`tflite_model_maker.object_detector.EfficientDetSpec`.')
# Publish the module-level name under the public API path.
mm_export('object_detector.EfficientDetLite0Spec').export_constant(
    __name__, 'efficientdet_lite0_spec')

# EfficientDet-Lite1: an EfficientDetModelSpec factory with the lite1 model
# name and its TF Hub feature-vector URI pre-bound.
efficientdet_lite1_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite1',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite1/feature-vector/1')
# Attach a docstring to the partial; util.wrap_doc (not visible here)
# presumably combines this summary with EfficientDetModelSpec's docs.
efficientdet_lite1_spec.__doc__ = util.wrap_doc(
    EfficientDetModelSpec,
    'Creates EfficientDet-Lite1 model spec. See also: '
    '`tflite_model_maker.object_detector.EfficientDetSpec`.')
# Publish the module-level name under the public API path.
mm_export('object_detector.EfficientDetLite1Spec').export_constant(
    __name__, 'efficientdet_lite1_spec')

efficientdet_lite2_spec = functools.partial(
    EfficientDetModelSpec,
      do_train: Whether to run training.

    Returns:
      An instance based on ObjectDetector.
    """
    model_spec = ms.get(model_spec)
    if epochs is not None:
      model_spec.config.num_epochs = epochs
    if batch_size is not None:
      model_spec.config.batch_size = batch_size
    if train_whole_model:
      model_spec.config.var_freeze_expr = None
    if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
      raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
          model_spec.compat_tf_versions, compat.get_tf_behavior()))

    object_detector = cls(model_spec, train_data.label_map, train_data)

    if do_train:
      tf.compat.v1.logging.info('Retraining the models...')
      object_detector.train(train_data, validation_data, epochs, batch_size)
    else:
      object_detector.create_model()

    return object_detector


# Shortcut: expose the ObjectDetector.create classmethod as a module-level
# function and publish it under the API path `object_detector.create`.
create = ObjectDetector.create
mm_export('object_detector.create').export_constant(
    __name__, 'create')