Ejemplo n.º 1
0
  def testWeightedMovingAverageBfloat16(self):
    """Runs two weighted-moving-average updates on bfloat16 placeholders.

    After each update the fetched result is compared with a reference value
    worked out in plain Python floats: a running numerator and denominator
    that are each scaled by `decay` and then incremented by `(1 - decay)`
    times the new `val * weight` and `weight` respectively.
    """
    to_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
    with self.cached_session() as sess:
      decay = 0.5
      keep = 1.0 - decay
      weight = array_ops.placeholder(dtypes.bfloat16, [])
      val = array_ops.placeholder(dtypes.bfloat16, [])

      wma = moving_averages.weighted_moving_average(val, decay, weight)
      variables.global_variables_initializer().run()

      # First update: the reference accumulators start from nothing, so they
      # only contain the contribution of the first (val, weight) pair.
      first_val, first_weight = 3.0, 4.0
      expected_num = first_val * first_weight * keep
      expected_den = first_weight * keep
      feeds = {val: first_val, weight: first_weight}
      observed = sess.run(wma, feed_dict=feeds)
      self.assertAllClose(expected_num / expected_den, observed)

      # Second update: decay the previous reference accumulators, then add the
      # contribution of the second pair. The reference quotient is converted
      # with the bfloat16 type before being compared.
      second_val, second_weight = 11.0, 22.0
      expected_num = expected_num * decay + second_val * second_weight * keep
      expected_den = expected_den * decay + second_weight * keep
      feeds = {val: second_val, weight: second_weight}
      observed = sess.run(wma, feed_dict=feeds)
      self.assertAllClose(to_bfloat16(expected_num / expected_den), observed)
Ejemplo n.º 2
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of dtypes (Tensor element types)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util.tf_export import tf_export

# Module-private handle to the object returned by
# `pywrap_tensorflow.TF_bfloat16_type()`, fetched once when this module is
# imported.
# NOTE(review): presumably the NumPy-compatible bfloat16 scalar type -- confirm
# against the pywrap_tensorflow bindings.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()


@tf_export("DType")
class DType(object):
    """Represents the type of the elements in a `Tensor`.

  The following `DType` objects are defined:

  * `tf.float16`: 16-bit half-precision floating-point.
  * `tf.float32`: 32-bit single-precision floating-point.
  * `tf.float64`: 64-bit double-precision floating-point.
  * `tf.bfloat16`: 16-bit truncated floating-point.
  * `tf.complex64`: 64-bit single-precision complex.
  * `tf.complex128`: 128-bit double-precision complex.
  * `tf.int8`: 8-bit signed integer.