def testWeightedMovingAverageBfloat16(self):
  """Checks `weighted_moving_average` on bfloat16 inputs over two updates.

  The expected value at each step is a running numerator / denominator
  computed in Python. On each update both are first multiplied by `decay`.
  Then `val * weight * (1 - decay)` is added to the numerator and
  `weight * (1 - decay)` to the denominator.
  """
  to_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
  with self.cached_session() as sess:
    decay = 0.5
    weight = array_ops.placeholder(dtypes.bfloat16, [])
    val = array_ops.placeholder(dtypes.bfloat16, [])
    wma = moving_averages.weighted_moving_average(val, decay, weight)
    variables.global_variables_initializer().run()

    def evaluate(value, wt):
      # Feed one (value, weight) pair and return the op's output.
      return sess.run(wma, feed_dict={val: value, weight: wt})

    # First update.
    first_val, first_weight = 3.0, 4.0
    result = evaluate(first_val, first_weight)
    num = first_val * first_weight * (1.0 - decay)
    den = first_weight * (1.0 - decay)
    self.assertAllClose(num / den, result)

    # Second update: fold the new term into the running numerator/denominator.
    second_val, second_weight = 11.0, 22.0
    result = evaluate(second_val, second_weight)
    num = num * decay + second_val * second_weight * (1.0 - decay)
    den = den * decay + second_weight * (1.0 - decay)
    # The quotient 124/12 is not exactly representable in bfloat16, so the
    # Python-side expectation is rounded to bfloat16 before comparing
    # (presumably to match the precision of the bfloat16-fed op).
    self.assertAllClose(to_bfloat16(num / den), result)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Library of dtypes (Tensor element types).""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from tensorflow.core.framework import types_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.util.tf_export import tf_export _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() @tf_export("DType") class DType(object): """Represents the type of the elements in a `Tensor`. The following `DType` objects are defined: * `tf.float16`: 16-bit half-precision floating-point. * `tf.float32`: 32-bit single-precision floating-point. * `tf.float64`: 64-bit double-precision floating-point. * `tf.bfloat16`: 16-bit truncated floating-point. * `tf.complex64`: 64-bit single-precision complex. * `tf.complex128`: 128-bit double-precision complex. * `tf.int8`: 8-bit signed integer.