示例#1
0
def Load():
    """Return the training ops library module, loading it on the first call.

    The check-and-load runs while holding ``_ops_lock``. The loaded module is
    cached in the module-level ``_training_ops``, so later calls return the
    cached module instead of loading the library again.

    Returns:
      The module produced by ``loader.load_op_library`` for the path of
      ``TRAINING_OPS_FILE``.
    """
    global _training_ops
    with _ops_lock:
        already_loaded = bool(_training_ops)
        if not already_loaded:
            library_path = resource_loader.get_path_to_datafile(
                TRAINING_OPS_FILE)
            logging.info('data path: %s', library_path)
            _training_ops = loader.load_op_library(library_path)
            # Sanity check on the freshly loaded module (stripped under -O).
            assert _training_ops, 'Could not load _training_ops.so'
    return _training_ops
示例#2
0
"""Python wrapper for input_pipeline_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

from astronet.contrib.input_pipeline.ops import gen_input_pipeline_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader

# Load the `_input_pipeline_ops.so` custom-op library at import time. The
# handle is not referenced in the visible code; presumably it is kept so the
# library's ops are registered for gen_input_pipeline_ops -- confirm.
_input_pipeline_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_input_pipeline_ops.so"))


def obtain_next(string_list_tensor, counter):
    """Build an ObtainNext op over a list of strings.

    Thin wrapper that forwards both arguments unchanged to
    ``gen_input_pipeline_ops.obtain_next``.

    Args:
      string_list_tensor: A tensor that is a list of strings.
      counter: An int64 ref tensor that keeps track of which element is
        returned.

    Returns:
      An op that produces the element at counter + 1 in the list, cycling
      through the list round-robin style.
    """
    next_element_op = gen_input_pipeline_ops.obtain_next(
        string_list_tensor, counter)
    return next_element_op
示例#3
0
import csv

from astronet.contrib import lookup
from astronet.contrib.text.python.ops import gen_skip_gram_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import input as input_ops

# Load the skip-gram custom-op library at import time.
# NOTE(review): the handle is named `_checkpoint_ops_so` although it holds
# `_skip_gram_ops.so`; the name looks copy-pasted from another module --
# consider renaming if nothing else references it.
_checkpoint_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_skip_gram_ops.so"))

# Mark the SkipGramGenerateCandidates op as not differentiable.
ops.NotDifferentiable("SkipGramGenerateCandidates")


def skip_gram_sample(input_tensor,
                     min_skips=1,
                     max_skips=5,
                     start=0,
                     limit=-1,
                     emit_self_as_target=False,
                     vocab_freq_table=None,
                     vocab_min_count=None,
                     vocab_subsampling=None,
                     corpus_size=None,
                     batch_size=None,
示例#4
0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam Search helper ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.seq2seq.ops import gen_beam_search_ops
from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the beam-search custom-op library at import time; the handle itself is
# not used in the visible code.
_beam_search_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_beam_search_ops.so"))

# Expose the generated `gather_tree` op wrapper under this module's namespace.
gather_tree = gen_beam_search_ops.gather_tree
示例#5
0
# ==============================================================================
"""Wrappers for sparse cross operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.framework import deprecated_arg_values
from astronet.contrib.layers.ops import gen_sparse_feature_cross_op
from astronet.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader

# Load the sparse-feature-cross custom-op library at import time.
_sparse_feature_cross_op = loader.load_op_library(
    resource_loader.get_path_to_datafile("_sparse_feature_cross_op.so"))

# Default hash key for FingerprintCat64, which concatenates the feature
# fingerprints when crossing features.
SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY = 0xDECAFCAFFE


@deprecated_arg_values(
    "2016-11-20",
    "The default behavior of sparse_feature_cross is changing, the default\n"
    "value for hash_key will change to SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY.\n"
    "From that point on sparse_feature_cross will always use FingerprintCat64\n"
    "to concatenate the feature fingerprints. And the underlying\n"
    "_sparse_feature_cross_op.sparse_feature_cross operation will be marked\n"
    "as deprecated.",
    hash_key=None)
def sparse_feature_cross(inputs,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensorflow op performing fused conv2d bias_add and relu."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op
from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the fused conv2d + bias + activation custom-op library at import time.
_fused_conv2d_bias_activation_op_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))


# pylint: disable=redefined-builtin
def fused_conv2d_bias_activation(conv_input,
                                 filter,
                                 bias,
                                 strides=None,
                                 padding=None,
                                 conv_input_scale=1.0,
                                 side_input_scale=0.0,
                                 side_input=None,
                                 activation_mode="Relu",
                                 data_format=None,
                                 filter_format=None,
                                 name=None):
示例#7
0
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python helper for loading IGFS ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the IGFS/Ignite ops library at import time. The datafile path climbs
# two directories ("../../"); presumably it is resolved relative to this
# module's location by get_path_to_datafile -- confirm against the build
# layout.
_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_ignite_ops.so"))
示例#8
0
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads the _boosted_trees_ops.so when the binary is not statically linked."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.util import loader
from tensorflow.python.framework import errors
from tensorflow.python.platform import resource_loader

# Conditionally load ops: they might already be statically linked into the
# binary, so a failure to load the .so is not treated as fatal. A
# NotFoundError or IOError is reported on stdout and import continues.
try:
    loader.load_op_library(
        resource_loader.get_path_to_datafile('_boosted_trees_ops.so'))
except (errors.NotFoundError, IOError):
    # Deliberate best effort: report and keep going instead of raising.
    print('Error loading _boosted_trees_ops.so')
示例#9
0
# limitations under the License.
# =============================================================================
"""Inter-process communication using MPI."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from astronet.contrib.mpi_collectives.ops import gen_mpi_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the MPI collectives custom-op library at import time.
_mpi_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('_mpi_ops.so'))


def size(name=None):
    """Create an op that returns the number of MPI processes.

    Equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)`, i.e. it
    reports the size of the global communicator.

    Args:
      name: Optional op name, passed through to ``gen_mpi_ops.mpi_size``.

    Returns:
      An integer scalar containing the number of MPI processes.
    """
    size_op = gen_mpi_ops.mpi_size(name=name)
    return size_op


# Mark the MPISize op as not differentiable.
ops.NotDifferentiable('MPISize')
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import resource_loader
from tensorflow.python.util.compat import collections_abc

# Load the factorization custom-op library at import time.
_factorization_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_factorization_ops.so"))


class WALSModel(object):
    r"""A model for Weighted Alternating Least Squares matrix factorization.

  It minimizes the following loss function over U, V:
  $$
   \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)
  $$
    where,
    A: input matrix,
    W: weight matrix. Note that the (element-wise) square root of the weights
      is used in the objective function.
    U, V: row_factors and column_factors matrices,
    \\(\lambda)\\: regularization.
示例#11
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrappers for nearest neighbor operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the nearest-neighbor custom-op library at import time.
_nearest_neighbor_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_nearest_neighbor_ops.so"))


def hyperplane_lsh_probes(point_hyperplane_product,
                          num_tables,
                          num_hyperplanes_per_table,
                          num_probes,
                          name=None):
    """Computes probes for the hyperplane hash.

  The op supports multiprobing, i.e., the number of requested probes can be
  larger than the number of tables. In that case, the same table can be probed
  multiple times.

  The first `num_tables` probes are always the primary hashes for each table.
# limitations under the License.
# ==============================================================================
"""Python layer for distort_image_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.image.ops import gen_distort_image_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import resource_loader

# Load the distort-image custom-op library at import time.
_distort_image_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile('_distort_image_ops.so'))


# pylint: disable=invalid-name
def random_hsv_in_yiq(image,
                      max_delta_hue=0,
                      lower_saturation=1,
                      upper_saturation=1,
                      lower_value=1,
                      upper_value=1,
                      seed=None):
    """Adjust hue, saturation, value of an RGB image randomly in YIQ color space.

  Equivalent to `adjust_yiq_hsv()` but uses a `delta_h` randomly
  picked in the interval `[-max_delta_hue, max_delta_hue]`, a `scale_saturation`
  randomly picked in the interval `[lower_saturation, upper_saturation]`, and
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python layer for image_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.image.ops import gen_single_image_random_dot_stereograms_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the single-image random-dot-stereogram (SIRDS) custom-op library at
# import time.
_sirds_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile(
        "_single_image_random_dot_stereograms.so"))


def single_image_random_dot_stereograms(depth_values,
                                        hidden_surface_removal=None,
                                        convergence_dots_size=None,
                                        dots_per_inch=None,
                                        eye_separation=None,
                                        mu=None,
                                        normalize=None,
                                        normalize_max=None,
                                        normalize_min=None,
                                        border_level=None,
                                        number_colors=None,
                                        output_image_shape=None,
示例#14
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for memory statistics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.memory_stats.ops import gen_memory_stats_ops
from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the memory-statistics custom-op library at import time.
_memory_stats_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_memory_stats_ops.so"))


def BytesInUse():
    """Create an op that computes the current memory usage of a device.

    Returns:
      The op built by ``gen_memory_stats_ops.bytes_in_use()``.
    """
    in_use_op = gen_memory_stats_ops.bytes_in_use()
    return in_use_op


def BytesLimit():
    """Create an op that measures the total memory (in bytes) of a device.

    Returns:
      The op built by ``gen_memory_stats_ops.bytes_limit()``.
    """
    limit_op = gen_memory_stats_ops.bytes_limit()
    return limit_op


def MaxBytesInUse():
    """Create an op that computes the peak memory usage of a device.

    Returns:
      The op built by ``gen_memory_stats_ops.max_bytes_in_use()``.
    """
    peak_op = gen_memory_stats_ops.max_bytes_in_use()
    return peak_op
示例#15
0
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom ops used by tensorforest."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from astronet.contrib.tensor_forest.python.ops.gen_tensor_forest_ops import *
# pylint: enable=wildcard-import
from astronet.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Load the tensor-forest custom-op library at import time; the Python
# wrappers come from the wildcard import of gen_tensor_forest_ops.
_tensor_forest_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile('_tensor_forest_ops.so'))
示例#16
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Libsvm decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.libsvm.ops import gen_libsvm_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import resource_loader
from tensorflow.python.util.deprecation import deprecated

# Load the libsvm decoder custom-op library at import time.
_libsvm_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_libsvm_ops.so"))


@deprecated(None,
            'tf.contrib.libsvm will be removed in 2.0, the support for libsvm '
            'format will continue to be provided in tensorflow-io: '
            'https://github.com/tensorflow/io')
def decode_libsvm(content, num_features, dtype=None, label_dtype=None):
    """Convert Libsvm records to a tensor of label and a tensor of feature.

  Args:
    content: A `Tensor` of type `string`. Each string is a record/row in
      the Libsvm format.
    num_features: The number of features.
    dtype: The type of the output feature tensor. Default to tf.float32.
    label_dtype: The type of the output label tensor. Default to tf.int64.
示例#17
0
from astronet.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
from astronet.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
from astronet.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
from astronet.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
from astronet.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
# pylint: enable=unused-import

from astronet.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import resources
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking

# Load the tensor-forest model custom-op library at import time.
_model_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_model_ops.so"))

# Mark the tree model ops below as not differentiable.
ops.NotDifferentiable("TreeVariable")
ops.NotDifferentiable("TreeSerialize")
ops.NotDifferentiable("TreeDeserialize")
ops.NotDifferentiable("TreeSize")
ops.NotDifferentiable("TreePredictionsV4")
ops.NotDifferentiable("FeatureUsageCounts")


class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for TreeVariable."""
    def __init__(self, params, tree_handle, stats_handle, create_op, name):
        """Creates a TreeVariableSavable object.

    Args:
示例#18
0
# =============================================================================
"""Encoding and decoding audio using FFmpeg."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.ffmpeg.ops import gen_decode_audio_op_py
from astronet.contrib.ffmpeg.ops import gen_decode_video_op_py
from astronet.contrib.ffmpeg.ops import gen_encode_audio_op_py
from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.util.deprecation import deprecated

# Load the FFmpeg audio/video custom-op library at import time.
_ffmpeg_so = loader.load_op_library(
    resource_loader.get_path_to_datafile('ffmpeg.so'))


@deprecated('2018-09-04',
            'tf.contrib.ffmpeg will be removed in 2.0, the support for video '
            'and audio will continue to be provided in tensorflow-io: '
            'https://github.com/tensorflow/io')
def decode_audio(contents,
                 file_format=None,
                 samples_per_second=None,
                 channel_count=None,
                 stream=None):
    """Create an op that decodes the contents of an audio file.

  Note that ffmpeg is free to select the "best" audio track from an mp4.
  https://trac.ffmpeg.org/wiki/Map
示例#19
0
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from astronet.contrib.periodic_resample.python.ops import gen_periodic_resample_op

from astronet.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad

from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import

# Load the periodic-resample custom-op library at import time.
_periodic_resample_op = loader.load_op_library(
    resource_loader.get_path_to_datafile('_periodic_resample_op.so'))

@ops.RegisterGradient("PeriodicResample")
def _periodic_resample_grad_cc(op, grad):
    """Gradient function registered for the "PeriodicResample" op.

    Args:
      op: The forward PeriodicResample operation being differentiated.
      grad: The incoming gradient with respect to the op's output.

    Returns:
      The result of ``periodic_resample_op_grad`` applied to ``grad``, the
      static shape of the op's first input, and the op's ``shape`` attribute.
    """
    input_shape = op.inputs[0].shape
    resample_shape = op.get_attr('shape')
    return periodic_resample_op_grad(grad, input_shape, resample_shape)
from __future__ import print_function

from six import iteritems
from six import string_types

from astronet.contrib.bigtable.ops import gen_bigtable_ops
from astronet.contrib.util import loader
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import resource_loader

# Load the Cloud Bigtable custom-op library at import time.
_bigtable_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_bigtable.so"))


class BigtableClient(object):
    """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.

  BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
  `table` method to open a Bigtable table.
  """
    def __init__(self,
                 project_id,
                 instance_id,
                 connection_pool_size=None,
                 max_receive_message_size=None):
        """Creates a BigtableClient that can be used to open connections to tables.
示例#21
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tensorflow op performing differentiable resampling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from astronet.contrib.resampler.ops import gen_resampler_ops
from astronet.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

# Load the differentiable resampler custom-op library at import time.
_resampler_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_resampler_ops.so"))


def resampler(data, warp, name="resampler"):
  """Resamples input data at user defined coordinates.

  The resampler currently only supports bilinear interpolation of 2D data.

  Args:
    data: Tensor of shape `[batch_size, data_height, data_width,
      data_num_channels]` containing 2D data that will be resampled.
    warp: Tensor of minimum rank 2 containing the coordinates at which
      resampling will be performed. Since only bilinear interpolation is
      currently supported, the last dimension of the `warp` tensor must be 2,
      representing the (x, y) coordinate where x is the index for width and y is
      the index for height.