Пример #1
0
  def test_rnn_cell(self):
    """Runs the RNN-cell export binary, then the binary that consumes it."""
    model_dir = self.get_temp_dir()

    # First command: the exporter, pointed at the temp dir via --export_dir.
    exporter = resource_loader.get_path_to_datafile("export_rnn_cell")
    self.assertCommandSucceeded(exporter, export_dir=model_dir)

    # Second command: the consumer, given the same dir via --model_dir.
    consumer = resource_loader.get_path_to_datafile("use_rnn_cell")
    self.assertCommandSucceeded(consumer, model_dir=model_dir)
Пример #2
0
  def test_mnist_cnn(self):
    """Runs the MNIST CNN export and use binaries in fast test mode."""
    # Both binaries receive the same two flags.
    shared_flags = dict(export_dir=self.get_temp_dir(), fast_test_mode="true")

    exporter = resource_loader.get_path_to_datafile("export_mnist_cnn")
    self.assertCommandSucceeded(exporter, **shared_flags)

    consumer = resource_loader.get_path_to_datafile("use_mnist_cnn")
    self.assertCommandSucceeded(consumer, **shared_flags)
Пример #3
0
  def test_text_embedding_in_sequential_keras(self):
    """Exports the simple text embedding and feeds it to the Keras user."""
    model_dir = self.get_temp_dir()

    # Export step (flag name: export_dir).
    exporter = resource_loader.get_path_to_datafile(
        "export_simple_text_embedding")
    self.assertCommandSucceeded(exporter, export_dir=model_dir)

    # Consume step (flag name: model_dir), reading from the same directory.
    consumer = resource_loader.get_path_to_datafile(
        "use_model_in_sequential_keras")
    self.assertCommandSucceeded(consumer, model_dir=model_dir)
Пример #4
0
def load_op_library(path):
  """Loads a contrib op library from the given path.

  NOTE(mrry): On Windows, we currently assume that some contrib op
  libraries are statically linked into the main TensorFlow Python
  extension DLL - use dynamically linked ops if the .so is present.

  Args:
    path: An absolute path to a shared object file.

  Returns:
    A Python module containing the Python wrappers for Ops defined in the
    plugin, or None on Windows when the corresponding .dll does not exist.
  """
  on_windows = os.name == 'nt'
  if on_windows:
    # Callers always pass a ".so" name; rewrite the extension to ".dll" here
    # so that individual user_ops do not have to be aware of Windows.
    path = re.sub(r'\.so$', '.dll', path)

    # Only some user_ops are currently built as dlls on Windows, so a
    # missing file is not an error: skip the load.
    # TODO(mrry): Once we have all of them this check should be removed.
    if not os.path.exists(path):
      return None

  resolved_path = resource_loader.get_path_to_datafile(path)
  library = load_library.load_op_library(resolved_path)
  assert library, 'Could not load %s' % resolved_path
  return library
Пример #5
0
 def testInvokeBeforeReady(self):
   """invoke() without a prior allocate_tensors() must raise RuntimeError."""
   model_file = resource_loader.get_path_to_datafile(
       'testdata/permute_float.tflite')
   interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
   # allocate_tensors() is deliberately not called before invoke().
   with self.assertRaisesRegexp(RuntimeError,
                                'Invoke called on model that is not ready'):
     interpreter.invoke()
Пример #6
0
 def setUp(self):
   """Builds an allocated interpreter plus the fixtures the tests share."""
   model_file = resource_loader.get_path_to_datafile(
       'testdata/permute_float.tflite')
   self.interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
   self.interpreter.allocate_tensors()
   # Tensor index of the model's first input.
   first_input = self.interpreter.get_input_details()[0]
   self.input0 = first_input['index']
   self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)
Пример #7
0
  def testFloat(self):
    """Checks metadata and output of the float permute model."""
    model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')
    interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()

    # One float32 input called 'input', shape [1, 4], quantization (0.0, 0).
    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    in_spec = input_details[0]
    self.assertEqual('input', in_spec['name'])
    self.assertEqual(np.float32, in_spec['dtype'])
    self.assertTrue(([1, 4] == in_spec['shape']).all())
    self.assertEqual((0.0, 0), in_spec['quantization'])

    # One float32 output called 'output' with the same shape/quantization.
    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    out_spec = output_details[0]
    self.assertEqual('output', out_spec['name'])
    self.assertEqual(np.float32, out_spec['dtype'])
    self.assertTrue(([1, 4] == out_spec['shape']).all())
    self.assertEqual((0.0, 0), out_spec['quantization'])

    # The expected output is the input row in reversed order.
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(in_spec['index'], test_input)
    interpreter.invoke()

    actual_output = interpreter.get_tensor(out_spec['index'])
    self.assertTrue((expected_output == actual_output).all())
Пример #8
0
  def testUint8(self):
    """Checks the uint8 permute model when loaded from in-memory bytes."""
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    # The model is handed over as raw content rather than as a file path.
    with io.open(model_path, 'rb') as model_file:
      model_bytes = model_file.read()

    interpreter = interpreter_wrapper.Interpreter(model_content=model_bytes)
    interpreter.allocate_tensors()

    # One uint8 input called 'input', shape [1, 4], quantization (1.0, 0).
    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    in_spec = input_details[0]
    self.assertEqual('input', in_spec['name'])
    self.assertEqual(np.uint8, in_spec['dtype'])
    self.assertTrue(([1, 4] == in_spec['shape']).all())
    self.assertEqual((1.0, 0), in_spec['quantization'])

    # One uint8 output called 'output' with the same shape/quantization.
    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    out_spec = output_details[0]
    self.assertEqual('output', out_spec['name'])
    self.assertEqual(np.uint8, out_spec['dtype'])
    self.assertTrue(([1, 4] == out_spec['shape']).all())
    self.assertEqual((1.0, 0), out_spec['quantization'])

    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    # Resize the input to the test shape and re-allocate before feeding data.
    interpreter.resize_tensor_input(in_spec['index'], test_input.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(in_spec['index'], test_input)
    interpreter.invoke()

    actual_output = interpreter.get_tensor(out_spec['index'])
    self.assertTrue((expected_output == actual_output).all())
Пример #9
0
  def testString(self):
    """Checks the string gather model: metadata and gathered values."""
    model_file = resource_loader.get_path_to_datafile(
        'testdata/gather_string.tflite')
    interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()

    # Two inputs: the string tensor 'input' ([10]) and int64 'indices' ([3]).
    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    strings_spec = input_details[0]
    self.assertEqual('input', strings_spec['name'])
    self.assertEqual(np.string_, strings_spec['dtype'])
    self.assertTrue(([10] == strings_spec['shape']).all())
    self.assertEqual((0.0, 0), strings_spec['quantization'])
    indices_spec = input_details[1]
    self.assertEqual('indices', indices_spec['name'])
    self.assertEqual(np.int64, indices_spec['dtype'])
    self.assertTrue(([3] == indices_spec['shape']).all())
    self.assertEqual((0.0, 0), indices_spec['quantization'])

    # One string output of shape [3].
    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    out_spec = output_details[0]
    self.assertEqual('output', out_spec['name'])
    self.assertEqual(np.string_, out_spec['dtype'])
    self.assertTrue(([3] == out_spec['shape']).all())
    self.assertEqual((0.0, 0), out_spec['quantization'])

    # Feed the indices first, then the strings to gather from.
    test_input = np.array([1, 2, 3], dtype=np.int64)
    interpreter.set_tensor(indices_spec['index'], test_input)

    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
    expected_output = np.array([b'b', b'c', b'd'])
    interpreter.set_tensor(strings_spec['index'], test_input)
    interpreter.invoke()

    actual_output = interpreter.get_tensor(out_spec['index'])
    self.assertTrue((expected_output == actual_output).all())
Пример #10
0
def _load_library(name, op_list=None):
  """Loads a .so file containing the specified operators.

  Args:
    name: The name of the .so file to load.
    op_list: A list of names of operators that the library should have. If None
        then the .so file's contents will not be verified.

  Returns:
    The loaded library module, or None (after logging a warning) when loading
    raises errors.NotFoundError.

  Raises:
    NameError if one of the required ops is missing.
  """
  try:
    so_path = resource_loader.get_path_to_datafile(name)
    library = load_library.load_op_library(so_path)
    # Every requested op name must match one of the ops the library exposes.
    for wanted_op in (op_list or []):
      if not any(op.name == wanted_op for op in library.OP_LIST.op):
        raise NameError(
            'Could not find operator %s in dynamic library %s' %
            (wanted_op, name))
    return library
  except errors.NotFoundError:
    logging.warning('%s file could not be loaded.', name)
Пример #11
0
  def __init__(self):
    """Locates flatc and builds the table of supported schema versions."""
    # TODO(aselle): make this work in the open source version with better
    # path.
    candidate_paths = (
        "../../../../flatbuffers/flatc",  # not bazel
        "../../../../external/flatbuffers/flatc",  # bazel
    )
    # Keep the first candidate that exists; if none exists, the attribute is
    # left holding the last candidate that was tried.
    for candidate in candidate_paths:
      self._flatc_path = resource_loader.get_path_to_datafile(candidate)
      if os.path.exists(self._flatc_path):
        break

    def FindSchema(base_name):
      return resource_loader.get_path_to_datafile("%s" % base_name)

    # Supported schemas for upgrade, as tuples of
    # (version, schema path, flag, upgrade function).
    self._schemas = [
        (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1),
        (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2),
        (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3),
        (3, FindSchema("schema_v3.fbs"), False, None)  # Non-callable by design.
    ]
    # Ensure schemas are sorted, and extract latest version and upgrade
    # dispatch function table.
    self._schemas.sort()
    newest = self._schemas[-1]
    self._new_version = newest[0]
    self._new_schema = newest[1]
    self._upgrade_dispatch = {
        entry[0]: entry[3] for entry in self._schemas
    }
Пример #12
0
def _maybe_load_nccl_ops_so():
  """Loads the nccl ops shared object once, guarded by the module lock."""
  global _nccl_ops_so

  with _module_lock:
    # A truthy cached value means a previous call already loaded the library.
    if _nccl_ops_so:
      return
    so_path = resource_loader.get_path_to_datafile('_nccl_ops.so')
    _nccl_ops_so = loader.load_op_library(so_path)
Пример #13
0
 def testInvalidIndex(self):
   """Asking for details of tensor index 4 must raise ValueError."""
   model_file = resource_loader.get_path_to_datafile(
       'testdata/permute_float.tflite')
   interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
   interpreter.allocate_tensors()
   # Invalid tensor index passed.
   with self.assertRaisesRegexp(ValueError, 'Tensor with no shape found.'):
     interpreter._get_tensor_details(4)
Пример #14
0
 def assertCommandSucceeded(self, script_name, **flags):
   """Runs a test script via run_script.

   Args:
     script_name: Value exported to the child as the SCRIPT_NAME environment
       variable.
     **flags: Each key/value pair is passed on the command line as
       --key=value.

   Raises:
     subprocess.CalledProcessError: If the command exits with non-zero status.
   """
   launcher = resource_loader.get_path_to_datafile("run_script")
   command_parts = [launcher] + [
       "--%s=%s" % (flag_key, flag_value)
       for flag_key, flag_value in flags.items()
   ]
   extra_env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
   logging.info(
       "Running: %s with environment flags %s" % (command_parts, extra_env))
   # The child inherits the current environment plus the two extra variables.
   child_env = dict(os.environ, **extra_env)
   subprocess.check_call(command_parts, env=child_env)
Пример #15
0
 def _initObjectDetectionArgs(self):
   """Sets the graph file and array attributes for the detection model."""
   # Single image input and its fixed shape.
   input_name = 'normalized_input_image_tensor'
   # The post-processing op exposes four outputs (default plus :1 to :3).
   postprocess_name = 'TFLite_Detection_PostProcess'

   self._graph_def_file = resource_loader.get_path_to_datafile(
       'testdata/tflite_graph.pb')
   self._input_arrays = [input_name]
   self._output_arrays = [
       postprocess_name, postprocess_name + ':1',
       postprocess_name + ':2', postprocess_name + ':3'
   ]
   self._input_shapes = {input_name: [1, 300, 300, 3]}
Пример #16
0
def Load():
  """Load the inference ops library and return the loaded module.

  The module is cached in `_inference_ops`; the load itself happens at most
  once and is serialized by `_ops_lock`.
  """
  global _inference_ops
  with _ops_lock:
    if _inference_ops:
      return _inference_ops

    library_path = resource_loader.get_path_to_datafile(INFERENCE_OPS_FILE)
    logging.info('data path: %s', library_path)
    _inference_ops = load_library.load_op_library(library_path)

    assert _inference_ops, 'Could not load inference_ops.so'
    return _inference_ops
Пример #17
0
def Load():
  """Load training ops library and return the loaded module.

  The module is cached in `_training_ops`; the load itself happens at most
  once and is serialized by `_ops_lock`.
  """
  global _training_ops
  with _ops_lock:
    if _training_ops:
      return _training_ops

    library_path = resource_loader.get_path_to_datafile(TRAINING_OPS_FILE)
    logging.info('data path: %s', library_path)
    _training_ops = load_library.load_op_library(library_path)

    assert _training_ops, 'Could not load _training_ops.so'
    return _training_ops
Пример #18
0
def Load():
    """Load the data ops library and return the loaded module.

    The module is cached in `_data_ops`; the load itself happens at most
    once and is serialized by `_ops_lock`.
    """
    global _data_ops
    with _ops_lock:
        if _data_ops:
            return _data_ops

        library_path = resource_loader.get_path_to_datafile(DATA_OPS_FILE)
        logging.info("data path: %s", library_path)
        _data_ops = load_library.load_op_library(library_path)

        assert _data_ops, "Could not load _data_ops.so"
        return _data_ops
Пример #19
0
  def test_empty_calibrator_gen(self):
    """Calibrating with a generator that yields nothing raises RuntimeError."""
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')
    # Read through a context manager so the file handle is closed promptly.
    with open(model_path, 'rb') as model_file:
      float_model = model_file.read()
    quantizer = _calibrator.Calibrator(float_model)

    # A generator function whose generator is exhausted immediately.
    def empty_input_gen():
      for i in ():
        yield i

    with self.assertRaises(RuntimeError):
      quantizer.calibrate_and_quantize(empty_input_gen)
Пример #20
0
  def test_invalid_type_calibrator_gen(self):
    """Calibrating with wrongly-typed samples raises ValueError."""
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')
    # Read through a context manager so the file handle is closed promptly.
    with open(model_path, 'rb') as model_file:
      float_model = model_file.read()
    quantizer = _calibrator.Calibrator(float_model)

    # Input generator with an incorrect type: each sample is a bare int32
    # ndarray.
    def input_gen():
      for _ in range(10):
        yield np.ones(shape=(1, 5, 5, 3), dtype=np.int32)

    with self.assertRaises(ValueError):
      quantizer.calibrate_and_quantize(input_gen)
Пример #21
0
  def test_calibration_with_quantization(self):
    """Calibrating with valid float32 samples yields a non-None model."""
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')
    # Read through a context manager so the file handle is closed promptly.
    with open(model_path, 'rb') as model_file:
      float_model = model_file.read()
    quantizer = _calibrator.Calibrator(float_model)

    # Input generator for the model: ten samples, each a one-element list
    # holding a (1, 5, 5, 3) float32 array.
    def input_gen():
      for _ in range(10):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

    quantized_model = quantizer.calibrate_and_quantize(input_gen)
    self.assertIsNotNone(quantized_model)
Пример #22
0
  def test_invalid_shape_calibrator_gen(self):
    """Calibrating with mis-shaped samples raises a 'Dimension mismatch'."""
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')
    # Read through a context manager so the file handle is closed promptly.
    with open(model_path, 'rb') as model_file:
      float_model = model_file.read()
    quantizer = _calibrator.Calibrator(float_model)

    # Input generator with incorrect shape: (1, 2, 2, 3) float32 samples.
    def input_gen():
      for _ in range(10):
        yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)]

    with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'):
      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
                                       constants.FLOAT, False)
Пример #23
0
def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initialize 'ref' with all zeros, ref tensor should be uninitialized.

  If already initialized, you will get ValueError. This op is intended to
  save memory during initialization.

  Args:
    ref: ref of the tensor need to be zero initialized.
    use_locking: accepted for signature compatibility; this function does not
      read it.
    name: optional name for this operation.

  Returns:
    ref that initialized.

  Raises:
    ValueError: If ref tensor is initialized.
  """
  # Load the op library before calling the generated wrapper.
  ops_so_path = resource_loader.get_path_to_datafile("_variable_ops.so")
  loader.load_op_library(ops_so_path)
  return gen_variable_ops.zero_initializer(ref, name=name)
Пример #24
0
def get_image(size):
  """Returns an image loaded into an np.ndarray with dims [1, size, size, 3].

  Args:
    size: Size of image; used for both elements of the load target size.

  Returns:
    np.ndarray with a new leading axis of length 1.
  """
  jpeg_path = _resource_loader.get_path_to_datafile(
      "testdata/grace_hopper.jpg")
  loaded = image.load_img(jpeg_path, target_size=(size, size))
  pixels = image.img_to_array(loaded)
  # Prepend the leading axis.
  return np.expand_dims(pixels, axis=0)
Пример #25
0
  def test_calibration_with_quantization_multiple_inputs(self):
    """Calibrating a four-input model yields a non-None quantized model."""
    # Load multi add model from test data.
    # This model has 4 inputs of size (1, 8, 8, 3).
    model_path = resource_loader.get_path_to_datafile(
        '../../testdata/multi_add.bin')
    # Read through a context manager so the file handle is closed promptly.
    with open(model_path, 'rb') as model_file:
      float_model = model_file.read()
    quantizer = _calibrator.Calibrator(float_model)

    # Input generator for the model: ten samples, each a list of four arrays.
    def input_gen():
      for _ in range(10):
        yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]

    quantized_model = quantizer.calibrate_and_quantize(input_gen)
    self.assertIsNotNone(quantized_model)
Пример #26
0
  def _initObjectDetectionArgs(self):
    """Sets the graph file and array attributes for the detection model.

    Raises:
      IOError: If the graph file exists at neither candidate location.
    """
    # The model file is saved in a different location internally and
    # externally, so try the test-data path first and then the fallback.
    graph_file = resource_loader.get_path_to_datafile(
        'testdata/tflite_graph.pb')
    if not os.path.exists(graph_file):
      resources_root = resource_loader.get_root_dir_with_all_resources()
      graph_file = os.path.join(
          resources_root,
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(graph_file):
        raise IOError("File '{0}' does not exist.".format(graph_file))

    input_name = 'normalized_input_image_tensor'
    self._graph_def_file = graph_file
    self._input_arrays = [input_name]
    self._output_arrays = [
        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
    ]
    self._input_shapes = {input_name: [1, 300, 300, 3]}
Пример #27
0
  def _run(self, sess, in_tensor, out_tensor, should_succeed):
    """Use toco binary to check conversion from graphdef to tflite.

    The model flags, toco flags and graph def are serialized into temporary
    files and handed to the `toco_from_protos` binary on its command line.

    Args:
      sess: Active TensorFlow session containing graph.
      in_tensor: TensorFlow tensor to use as input.
      out_tensor: TensorFlow tensor to use as output.
      should_succeed: Whether this is a valid conversion.
    """
    # Build all protos and extract graphdef
    graph_def = sess.graph_def
    # Conversion settings: GraphDef in, TFLite out, float types, custom ops
    # allowed.
    toco_flags = toco_flags_pb2.TocoFlags()
    toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF
    toco_flags.output_format = toco_flags_pb2.TFLITE
    toco_flags.inference_input_type = types_pb2.FLOAT
    toco_flags.inference_type = types_pb2.FLOAT
    toco_flags.allow_custom_ops = True
    # Model description: one input array (name and static shape) and one
    # output array.
    model_flags = model_flags_pb2.ModelFlags()
    input_array = model_flags.input_arrays.add()
    input_array.name = TensorName(in_tensor)
    input_array.shape.dims.extend(map(int, in_tensor.get_shape()))
    model_flags.output_arrays.append(TensorName(out_tensor))
    # Shell out to run toco (in case it crashes)
    with tempfile.NamedTemporaryFile() as fp_toco, \
           tempfile.NamedTemporaryFile() as fp_model, \
           tempfile.NamedTemporaryFile() as fp_input, \
           tempfile.NamedTemporaryFile() as fp_output:
      fp_model.write(model_flags.SerializeToString())
      fp_toco.write(toco_flags.SerializeToString())
      fp_input.write(graph_def.SerializeToString())
      # Flush so the child process sees the complete files on disk.
      fp_model.flush()
      fp_toco.flush()
      fp_input.flush()
      tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos")
      # Argument order: binary, model flags, toco flags, input graph, output.
      cmdline = " ".join([
          tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name
      ])
      exitcode = os.system(cmdline)
      if exitcode == 0:
        stuff = fp_output.read()
        # NOTE(review): file.read() returns a bytes object (possibly empty),
        # never None, so this effectively asserts `should_succeed` is True
        # and does not check that any output was produced - confirm intent.
        self.assertEqual(stuff is not None, should_succeed)
      else:
        # A non-zero exit status is only acceptable for invalid conversions.
        self.assertFalse(should_succeed)
Пример #28
0
def load_trt_ops():
  """Load TF-TRT op libraries so if it hasn't been loaded already.

  The loaded library is cached in the module-level `_trt_ops_so`; once that is
  truthy, later calls return immediately. The work is done while holding
  `_module_lock`.

  Raises:
    RuntimeError: If running on Windows.
    ImportError: If the generated TRT op wrappers cannot be imported.
    errors.NotFoundError: If loading `_trt_ops.so` fails with that error.
  """
  global _trt_ops_so

  if platform.system() == "Windows":
    raise RuntimeError("Windows platforms are not supported")

  with _module_lock:
    # Already loaded by an earlier call.
    if _trt_ops_so:
      return

    try:
      # pylint: disable=g-import-not-at-top,unused-variable
      # This registers the TRT ops, it doesn't require loading TRT library.
      from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
      # pylint: enable=g-import-not-at-top,unused-variable
    except ImportError as e:
      # Print a hint for the user, then propagate the original error.
      print("**** Failed to import TF-TRT ops. This is because the binary was "
            "not built with CUDA or TensorRT enabled. ****")
      raise e

    # TODO(laigd): we should load TF-TRT kernels here as well after removing the
    # swig binding.
    try:
      # pylint: disable=g-import-not-at-top
      from tensorflow.python.framework import load_library
      from tensorflow.python.platform import resource_loader
      # pylint: enable=g-import-not-at-top

      _trt_ops_so = load_library.load_op_library(
          resource_loader.get_path_to_datafile("_trt_ops.so"))
    except errors.NotFoundError as e:
      # Print installation guidance, then propagate the original error.
      no_trt_message = (
          "**** Failed to initialize TensorRT. This is either because the "
          "TensorRT installation path is not in LD_LIBRARY_PATH, or because "
          "you do not have it installed. If not installed, please go to "
          "https://developer.nvidia.com/tensorrt to download and install "
          "TensorRT ****")
      print(no_trt_message)
      raise e
Пример #29
0
def load_op_library(path):
  """Loads a contrib op library from the given path.

  NOTE(mrry): On Windows, we currently assume that contrib op
  libraries are statically linked into the main TensorFlow Python
  extension DLL.

  Args:
    path: An absolute path to a shared object file.

  Returns:
    A Python module containing the Python wrappers for Ops defined in the
    plugin. On Windows nothing is loaded and None is returned.
  """
  if os.name == 'nt':
    # NOTE(mrry): no dynamic load on Windows; see the docstring note.
    return None

  resolved_path = resource_loader.get_path_to_datafile(path)
  library = load_library.load_op_library(resolved_path)
  assert library, 'Could not load %s' % resolved_path
  return library
  def testUint8(self):
    """Checks metadata and output of the uint8 permute model."""
    model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    # The model path is passed positionally here.
    interpreter = interpreter_wrapper.Interpreter(model_file)
    interpreter.allocate_tensors()

    # One uint8 input called 'input' with shape [1, 4].
    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    in_spec = input_details[0]
    self.assertEqual('input', in_spec['name'])
    self.assertEqual(np.uint8, in_spec['dtype'])
    self.assertTrue(([1, 4] == in_spec['shape']).all())

    # One uint8 output called 'output' with shape [1, 4].
    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    out_spec = output_details[0]
    self.assertEqual('output', out_spec['name'])
    self.assertEqual(np.uint8, out_spec['dtype'])
    self.assertTrue(([1, 4] == out_spec['shape']).all())

    # The expected output is the input row in reversed order.
    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    interpreter.set_tensor(in_spec['index'], test_input)
    interpreter.invoke()

    actual_output = interpreter.get_tensor(out_spec['index'])
    self.assertTrue((expected_output == actual_output).all())
Пример #31
0
    def add_summary(self, stamp_token, column, example_weights):
        """Adds quantile summary to its stream in resource.

        Builds one summary from `column` and `example_weights` and passes it,
        with this accumulator's handle, to the add-summaries op.
        """
        new_summary = self._make_summary(column, example_weights)
        # The op takes parallel lists of handles and summaries.
        handles = [self._quantile_accumulator_handle]
        add_summaries = gen_quantile_ops.quantile_accumulator_add_summaries
        return add_summaries(
            quantile_accumulator_handles=handles,
            stamp_token=stamp_token,
            summaries=[new_summary])

    def schedule_add_summary(self, stamp_token, column, example_weights):
        """Schedules to add a quantile summary to its stream in resource.

        Returns a ScheduledStampedResourceOp wrapping the add-summaries op.
        `stamp_token` is accepted but not forwarded by this method.
        """
        new_summary = self._make_summary(column, example_weights)
        scheduled_kwargs = dict(
            op=gen_quantile_ops.quantile_accumulator_add_summaries,
            resource_handle=self._quantile_accumulator_handle,
            summaries=new_summary)
        return batch_ops_utils.ScheduledStampedResourceOp(**scheduled_kwargs)

    def flush(self, stamp_token, next_stamp_token):
        """Finalizes quantile summary stream and resets it for next iteration.

        Forwards both stamp tokens and this accumulator's handle to the
        flush op and returns the op's result.
        """
        flush_op = gen_quantile_ops.quantile_accumulator_flush
        accumulator = self._quantile_accumulator_handle
        return flush_op(
            quantile_accumulator_handle=accumulator,
            stamp_token=stamp_token,
            next_stamp_token=next_stamp_token)


# Conditionally load ops, they might already be statically linked in.
# A NotFoundError or IOError during the load is reported on stdout and
# otherwise ignored; `_quantile_ops` is left unassigned in that case.
try:
    _quantile_ops = loader.load_op_library(
        resource_loader.get_path_to_datafile("_quantile_ops.so"))
except (errors.NotFoundError, IOError):
    print("Error loading _quantile_ops.so")
Пример #32
0
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom ops used by tensorforest."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.tensor_forest.python.ops.gen_tensor_forest_ops import *
# pylint: enable=wildcard-import
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the tensor forest op library at import time; the .so path is resolved
# through resource_loader and the result is kept in a module-level handle.
_tensor_forest_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile('_tensor_forest_ops.so'))
Пример #33
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom Aggregator op is for collecting numeric metrics from the given input."""

from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

# Load the custom aggregator op library at import time; the .so path is
# resolved through resource_loader.
_custom_aggregator_op = load_library.load_op_library(
    resource_loader.get_path_to_datafile('_custom_aggregator_op.so'))


def custom_aggregator(input_tensor, tensor_id: str):
  """Creates custom aggregator op that collects numeric metrics from the tensor.

  Args:
    input_tensor: Tensor to be scanned through this operator. This tensor will
      be bypassed to the output tensor of this operator.
    tensor_id: String, the identity of the tensor to be scanned.

  Returns:
    A `Tensor` of the same value as `input_tensor`.

  Raises:
    ValueError: If the given type of `input_tensor` is not float32.
Пример #34
0
from six import iteritems
from six import string_types

from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.util import loader
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader

# Load the Bigtable op library at import time; the .so path is resolved
# through resource_loader.
_bigtable_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_bigtable.so"))


class BigtableClient(object):
    """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.

  BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
  `table` method to open a Bigtable table.
  """
    def __init__(self,
                 project_id,
                 instance_id,
                 connection_pool_size=None,
                 max_receive_message_size=None):
        """Creates a BigtableClient that can be used to open connections to tables.
from tensorflow.contrib.image.ops import gen_image_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader

# Load the image op library at import time; the .so path is resolved through
# resource_loader.
_image_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_image_ops.so"))

# Dtypes collected for the image ops in this module.
_IMAGE_DTYPES = {
    dtypes.uint8,
    dtypes.int32,
    dtypes.int64,
    dtypes.float16,
    dtypes.float32,
    dtypes.float64,
}

# Register common_shapes.call_cpp_shape_fn as the shape function of each op.
ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransformV2")(
    common_shapes.call_cpp_shape_fn)


# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
# used by PIL, maybe more readable) mode, which determines the correct
# output_shape and translation for the transform.
Пример #36
0
# limitations under the License.
# =============================================================================
"""Encoding and decoding audio using FFmpeg."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py
from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the FFmpeg op library at import time; the .so path is resolved through
# resource_loader.
_ffmpeg_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('ffmpeg.so'))


def decode_audio(contents, file_format=None, samples_per_second=None,
                 channel_count=None):
  """Create an op that decodes the contents of an audio file.

  Note that ffmpeg is free to select the "best" audio track from an mp4.
  https://trac.ffmpeg.org/wiki/Map

  Args:
    contents: The binary contents of the audio file to decode. This is a
        scalar.
    file_format: A string or scalar string tensor specifying which
        format the contents will conform to. This can be mp3, mp4, ogg,
        or wav.
Пример #37
0
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python helper for loading kinesis ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the dataset op library at import time; the .so is looked up two
# directory levels above the location resource_loader resolves against.
# NOTE(review): the module docstring mentions kinesis ops while the file
# loaded here is `_dataset_ops.so` - confirm the docstring is still accurate.
_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
Пример #38
0
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow.python.platform import resource_loader
import sys

# Load the op library; the .so path is resolved through resource_loader.
ops = tf.load_op_library(resource_loader.get_path_to_datafile("distribute.so"))

# Re-export every attribute of the loaded library from this module, skipping
# any name that contains a double underscore.
module = sys.modules[__name__]
for name, value in ops.__dict__.items():
  if "__" not in name:
    setattr(module, name, value)
Пример #39
0
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Range coding operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
from tensorflow_compression.python.ops import namespace_helper

# Load the `_range_coding_ops.so` op library and pass the result through
# `namespace_helper.get_ops`.
ops = namespace_helper.get_ops(
    load_library.load_op_library(
        resource_loader.get_path_to_datafile("../../../_range_coding_ops.so")))

# Publish every entry of `ops` as a module-level name and declare those names
# as the public API (presumably `ops` is a name -> op mapping; confirm against
# namespace_helper).
globals().update(ops)
__all__ = list(ops)
Пример #40
0
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.training.tracking import tracking
from tensorflow_text.python.ops.tokenization import Detokenizer
from tensorflow_text.python.ops.tokenization import TokenizerWithOffsets

from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# Load the `_sentencepiece_tokenizer.so` op library, located via
# resource_loader.
gen_sentencepiece_tokenizer = load_library.load_op_library(
    resource_loader.get_path_to_datafile('_sentencepiece_tokenizer.so'))  # pylint: disable=g-bad-import-order

# Monitoring counter registered under the metric name below; per its
# description it counts SentencepieceTokenizers created in Python.
_tf_text_sentencepiece_tokenizer_op_create_counter = monitoring.Counter(
    "/nlx/api/python/sentencepiece_tokenizer_create_counter",
    "Counter for number of SentencepieceTokenizers created in Python.")


class _SentencepieceModelResource(tracking.TrackableResource):
    """Utility to track the model resource tensor (for SavedModel support)."""
    def __init__(self, model, name):
        super(_SentencepieceModelResource, self).__init__()
        self._model = model
        self._name = name
        _ = self.resource_handle  # Accessing this property creates the resource.

    def _create_resource(self):
algorithm.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor

from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# Load the `_constrained_sequence_op.so` op library, located via
# resource_loader.
gen_constrained_sequence_op = load_library.load_op_library(
    resource_loader.get_path_to_datafile('_constrained_sequence_op.so'))


def viterbi_constrained_sequence(scores,
                                 sequence_length=None,
                                 allowed_transitions=None,
                                 transition_weights=None,
                                 use_log_space=False,
                                 use_start_and_end_states=True,
                                 name=None):
    """Performs greedy constrained sequence on a batch of examples.

  Constrains a set of predictions based on a set of legal transitions
  and/or a set of transition weights, returning the legal sequence that
  maximizes the product of the state scores and the transition weights
  according to the Viterbi algorithm. If use_log_space is True, the Viterbi
Пример #42
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys

from tensorflow.python.platform import resource_loader

# Schema to use for flatbuffers
# NOTE(review): this first value is never read; it is overwritten by the
# resource_loader lookup a few lines below.
_SCHEMA = "third_party/tensorflow/lite/schema/schema.fbs"

# TODO(angerson): fix later when rules are simplified..
_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs")
_BINARY = resource_loader.get_path_to_datafile("../../../flatbuffers/flatc")
# Account for different package positioning internal vs. external.
# If flatc is not found at the first location, retry one directory higher.
if not os.path.exists(_BINARY):
    _BINARY = resource_loader.get_path_to_datafile(
        "../../../../flatbuffers/flatc")

# Fail as soon as these statements run if either the schema file or the flatc
# binary is missing on disk.
if not os.path.exists(_SCHEMA):
    raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA)
if not os.path.exists(_BINARY):
    raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY)

# A CSS description for making the visualizer
_CSS = """
<html>
<head>
Пример #43
0
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
# pylint: enable=unused-import

from tensorflow.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking

# Load the `_model_ops.so` op library, located via resource_loader.
_model_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_model_ops.so"))

# Register each of the op types below as not differentiable.
ops.NotDifferentiable("TreeVariable")
ops.NotDifferentiable("TreeSerialize")
ops.NotDifferentiable("TreeDeserialize")
ops.NotDifferentiable("TreeSize")
ops.NotDifferentiable("TreePredictionsV4")
ops.NotDifferentiable("FeatureUsageCounts")


class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for TreeVariable."""
    def __init__(self, params, tree_handle, stats_handle, create_op, name):
        """Creates a TreeVariableSavable object.

    Args:
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensorflow op performing fused conv2d bias_add and relu."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the `_fused_conv2d_bias_activation_op.so` op library, located via
# resource_loader.
_fused_conv2d_bias_activation_op_so = loader.load_op_library(
    resource_loader.get_path_to_datafile(
        "_fused_conv2d_bias_activation_op.so"))


# pylint: disable=redefined-builtin
def fused_conv2d_bias_activation(conv_input,
                                 filter,
                                 bias,
                                 strides=None,
                                 padding=None,
                                 conv_input_scale=1.0,
                                 side_input_scale=0.0,
                                 side_input=None,
                                 activation_mode="Relu",
                                 data_format=None,
                                 filter_format=None,
                                 name=None):
Пример #45
0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

# Load the `_dot_based_interact_ops.so` op library and expose its
# `dot_based_interact` attribute as a module-level name.
dot_based_interact_ops = load_library.load_op_library(
    resource_loader.get_path_to_datafile('_dot_based_interact_ops.so'))
dot_based_interact = dot_based_interact_ops.dot_based_interact


@ops.RegisterGradient("DotBasedInteract")
def dot_based_interact_grad(op, grad):
    """Gradient function registered for the `DotBasedInteract` op type.

    Args:
      op: The forward op; only its first input is read.
      grad: The incoming gradient handed to this function.

    Returns:
      The result of `dot_based_interact_ops.dot_based_interact_grad` applied
      to the forward op's first input and `grad`.
    """
    # Named `forward_input` so the builtin `input` is not shadowed.
    forward_input = op.inputs[0]
    return dot_based_interact_ops.dot_based_interact_grad(forward_input, grad)
Пример #46
0
import abc

from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import resource_loader

# Load the `_lstm_ops.so` op library, located via resource_loader.
_lstm_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_lstm_ops.so"))


# pylint: disable=invalid-name
def _lstm_block_cell(x,
                     cs_prev,
                     h_prev,
                     w,
                     b,
                     wci=None,
                     wcf=None,
                     wco=None,
                     forget_bias=None,
                     cell_clip=None,
                     use_peephole=None,
                     name=None):
Пример #47
0
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Use feature hash ops in python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

# Load the `_feature_hash_ops.so` op library and expose its `feature_hash_op`
# attribute as a module-level name.
feature_hash_ops = load_library.load_op_library(
    resource_loader.get_path_to_datafile('_feature_hash_ops.so'))
feature_hash_op = feature_hash_ops.feature_hash_op
Пример #48
0
 def testThreads_NegativeValue(self):
     """Constructing an Interpreter with num_threads=-1 must raise ValueError.

     The error message is expected to match 'num_threads should >= 1'.
     """
     # NOTE(review): `assertRaisesRegexp` is the deprecated unittest spelling
     # of `assertRaisesRegex` -- confirm the base TestCase still provides it.
     with self.assertRaisesRegexp(ValueError, 'num_threads should >= 1'):
         model_file = resource_loader.get_path_to_datafile(
             'testdata/permute_float.tflite')
         interpreter_wrapper.Interpreter(model_path=model_file, num_threads=-1)
Пример #49
0
def _load_ctypes_dll(name):
    """Open the shared library `name` with ctypes and return the handle.

    Args:
      name: Resource path handed to `resource_loader.get_path_to_datafile`.

    Returns:
      A `ctypes.CDLL` object opened with `mode=ctypes.RTLD_GLOBAL`.
    """
    return ctypes.CDLL(
        resource_loader.get_path_to_datafile(name), mode=ctypes.RTLD_GLOBAL)
Пример #50
0
# limitations under the License.
# ==============================================================================
"""Wrappers for sparse cross operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader

# Load the `_sparse_feature_cross_op.so` op library, located via
# resource_loader.
_sparse_feature_cross_op = load_library.load_op_library(
    resource_loader.get_path_to_datafile("_sparse_feature_cross_op.so"))
# Sanity check that the loader returned a truthy object.
# NOTE(review): `assert` statements are skipped when Python runs with -O.
assert _sparse_feature_cross_op, "Could not load _sparse_feature_cross_op.so."


def sparse_feature_cross(inputs,
                         hashed_output=False,
                         num_buckets=0,
                         name=None):
    """Crosses a list of Tensor or SparseTensor objects.

  See sparse_feature_cross_kernel.cc for more details.

  Args:
    inputs: List of `SparseTensor` or `Tensor` to be crossed.
    hashed_output: If true, returns the hash of the cross instead of the string.
      This will allow us avoiding string manipulations.
import tensorflow as tf

from tensorflow.python.platform import resource_loader
from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata.python import metadata as _metadata
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer
from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils

# Test fixture names and locations. Only _LABEL_FILE is resolved through
# resource_loader here; the other paths are kept as relative strings.
_MODEL = "../testdata/mobilenet_v2_1.0_224_quant.tflite"
_MULTI_INPUTS_MODEL = "../testdata/question_answerer/mobilebert_float.tflite"
_MULTI_OUTPUTS_MODEL = "../testdata/audio_classifier/two_heads.tflite"
_MODEL_NAME = "mobilenet_v2_1.0_224_quant"
_INPUT_NAME = "image"
_OUTPUT_NAME = "probability"
_LABEL_FILE = resource_loader.get_path_to_datafile("../testdata/labels.txt")
_EXPECTED_DUMMY_JSON = "../testdata/mobilenet_v2_1.0_224_quant_dummy.json"
_EXPECTED_META_INFO_JSON = "../testdata/mobilenet_v2_1.0_224_quant_meta_info_.json"
_EXPECTED_DEFAULT_JSON = "../testdata/mobilenet_v2_1.0_224_quant_default.json"
# Before being populated into the model, the metadata does not have the
# version string.
_EXPECTED_DUMMY_NO_VERSION_JSON = "../testdata/mobilenet_v2_1.0_224_quant_dummy_no_version.json"
_EXPECTED_MULTI_INPUTS_JSON = "../testdata/multi_inputs.json"
_EXPECTED_MULTI_OUTPUTS_JSON = "../testdata/multi_outputs.json"


class MetadataWriterTest(tf.test.TestCase):

  def test_populate_from_metadata_should_succeed(self):
    model_buffer = test_utils.load_file(_MODEL)
    model_metadata, input_metadata, output_metadata = (
        self._create_dummy_metadata())
Пример #52
0
from tensorflow.python.util.tf_export import tf_export as _tf_export

# The TOCO wrapper module is loaded lazily because some of the performance
# benchmark skylark rules break dependencies.
_toco_python = LazyLoader(
    "tensorflow_wrap_toco",
    globals(),
    "tensorflow.lite.toco.python.tensorflow_wrap_toco")
del LazyLoader

# Locate the toco_from_protos binary. Under bazel it is found through the
# resource loader; in a pip install console_scripts already provides the
# toco_from_protos tool. When the TOCO API is used directly no binary is
# needed, so the path stays empty.
_toco_from_proto_bin = (
    "" if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY else
    _resource_loader.get_path_to_datafile("../toco/python/toco_from_protos"))

# A non-empty path that does not exist on disk falls back to the bare tool
# name.
if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
    _toco_from_proto_bin = "toco_from_protos"


def _try_convert_to_unicode(output):
    if output is None:
        return u""

    if isinstance(output, bytes):
        try:
            return output.decode()
        except UnicodeDecodeError:
            pass
    return output
Пример #53
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow_text.python.ops.tokenization import TokenizerWithOffsets

# pylint: disable=g-bad-import-order
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# Load the `_split_merge_tokenizer.so` op library, located via resource_loader.
gen_split_merge_tokenizer = load_library.load_op_library(resource_loader.get_path_to_datafile('_split_merge_tokenizer.so'))


class SplitMergeTokenizer(TokenizerWithOffsets):
  """Tokenizes a tensor of UTF-8 string into words according to labels."""

  def tokenize(self,
               input,  # pylint: disable=redefined-builtin
               labels,
               force_split_at_break_character=True):
    """Tokenizes a tensor of UTF-8 strings according to labels.

    ### Example:
    ```python
    >>> strings = ["HelloMonday", "DearFriday"],
    >>> labels = [[0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
Пример #54
0
import collections
import numbers

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# pylint: disable=wildcard-import,undefined-variable
from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
# pylint: enable=wildcard-import
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.platform import resource_loader

# Load the `_factorization_ops.so` op library, located via resource_loader.
_factorization_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_factorization_ops.so"))


class WALSModel(object):
    r"""A model for Weighted Alternating Least Squares matrix factorization.

  It minimizes the following loss function over U, V:
   \\( ||W \odot (A - U V^T) ||_F^2 + \lambda (||U||_F^2 + ||V||_F^2) )\\
    where,
    A: input matrix,
    W: weight matrix,
    U, V: row_factors and column_factors matrices,
    \\(\lambda)\\: regularization.
  Also we assume that W is of the following special form:
  \\( W_{ij} = W_0 + R_i * C_j )\\  if \\(A_{ij} \ne 0)\\,
  \\(W_{ij} = W_0)\\ otherwise.
Пример #55
0
"""TensorFlow dataset for Riegeli/records files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.version
from riegeli.tensorflow.ops import gen_riegeli_dataset_ops
import tensorflow as tf

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

# Load the `_riegeli_dataset_ops.so` op library. The return value is
# discarded, so the call is made only for its side effects.
load_library.load_op_library(
    resource_loader.get_path_to_datafile('_riegeli_dataset_ops.so'))

# Names exported by `from <module> import *`.
__all__ = ('RiegeliDataset', )


class RiegeliDataset(dataset_ops.DatasetSource):
    """A `Dataset` comprising records from one or more Riegeli/records files."""

    __slots__ = ('_filenames', )

    def __init__(self, filenames):
        """Creates a `RiegeliDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
    """
Пример #56
0
from tensorflow.python.framework.load_library import load_op_library
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.framework.ops import op_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.platform import resource_loader

# Names exported by `from <module> import *`.
__all__ = ['SdcaModel']

# Load the `_sdca_ops.so` op library, located via resource_loader.
_sdca_ops = load_op_library(
    resource_loader.get_path_to_datafile('_sdca_ops.so'))
# Sanity check that the loader returned a truthy object.
# NOTE(review): `assert` statements are skipped when Python runs with -O.
assert _sdca_ops, 'Could not load _sdca_ops.so'


# TODO(sibyl-Aix6ihai): add op_scope to appropriate methods.
class SdcaModel(object):
    """Stochastic dual coordinate ascent solver for linear models.

    This class currently only supports a single machine (multi-threaded)
    implementation. We expect the weights and duals to fit in a single machine.

    Loss functions supported:
     * Binary logistic loss
     * Squared loss
     * Hinge loss
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.embedding_ops import embedding_lookup
from tensorflow.python.platform import resource_loader

# Load the `_clustering_ops.so` op library, located via resource_loader.
_clustering_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile('_clustering_ops.so'))

# Euclidean distance between vectors U and V is defined as ||U - V||_F which is
# the square root of the sum of the absolute squares of the elements difference.
SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
# Cosine distance between vectors U and V is defined as
# 1 - (U \dot V) / (||U||_F ||V||_F)
COSINE_DISTANCE = 'cosine'

# String identifiers for initialization methods (presumably the values
# accepted by KMeans for its initial-clusters choice; confirm against the
# class).
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'


class KMeans(object):
    """Creates the graph for k-means clustering."""
    def __init__(self,
Пример #58
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper of TOPTEngineOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import platform

# Windows is rejected up front; nothing below runs on that platform.
if platform.system() == "Windows":
  raise RuntimeError("Windows platforms are not supported")

# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.contrib.tensoropt.ops.gen_topt_engine_op import *

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top

# Load the `_topt_engine_op.so` op library, located via resource_loader.
_topt_engine_op = loader.load_op_library(
    resource_loader.get_path_to_datafile("_topt_engine_op.so"))
Пример #59
0
from __future__ import print_function

import copy
import os
import shutil
import tempfile
import warnings
import zipfile

from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers
from tensorflow.python.platform import resource_loader

# Path to `metadata_schema.fbs` as resolved by resource_loader.
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
    "metadata_schema.fbs")


# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
    """Packs metadata and associated files into TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect list of files that have been packed into the model or are supposed to
  be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs
Пример #60
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

from tensorflow.contrib.input_pipeline.ops import gen_input_pipeline_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader

# Load the `_input_pipeline_ops.so` op library, located via resource_loader.
_input_pipeline_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_input_pipeline_ops.so"))


def obtain_next(string_list_tensor, counter):
    """Thin Python wrapper that forwards to the generated `obtain_next` op.

    Args:
      string_list_tensor: A tensor holding a list of strings.
      counter: An int64 ref tensor tracking which element is returned.

    Returns:
      An op yielding the list element at position counter + 1, cycling
      through the list round-robin style.
    """
    next_element_op = gen_input_pipeline_ops.obtain_next(
        string_list_tensor, counter)
    return next_element_op