def _truediv_adjust_dtypes(self, me, other):
  """Pre-cast matching integer operands of a true division to float32.

  TF truediv in Python3 produces float64 when both inputs are int32 or
  int64. When `dtypes.is_allow_float64()` is False that is unwanted, so in
  that case both operands are cast to `tf.float32` up front.

  `self` is not read by this method.

  Args:
    me: Left operand; only its `dtype` attribute is inspected here.
    other: Right operand; only its `dtype` attribute is inspected here.

  Returns:
    A `(me, other)` tuple. Both entries are cast to `tf.float32` when
    float64 is disallowed and the operands share an int32 or int64 dtype;
    otherwise the inputs are returned unchanged.
  """
  # Nothing to adjust when float64 results are acceptable.
  if dtypes.is_allow_float64():
    return me, other

  operand_dtype = me.dtype
  # Only the "same integer dtype on both sides" case is adjusted.
  if operand_dtype == other.dtype and operand_dtype in (tf.int32, tf.int64):
    me, other = (
        tf.cast(me, dtype=tf.float32),
        tf.cast(other, dtype=tf.float32),
    )
  return me, other
def f(x1, x2):
  """Compute the true division `x1 / x2` with dtype adjustments.

  Boolean operands are first cast to the default float type. When float64
  is not allowed, the operands are additionally passed through
  `_avoid_float64` before dividing.

  Args:
    x1: Dividend; must expose a `dtype` attribute.
    x2: Divisor; must be boolean whenever `x1` is boolean.

  Returns:
    The result of `tf.math.truediv` on the (possibly adjusted) operands.

  Raises:
    AssertionError: If `x1` is boolean but `x2` is not.
  """
  if x1.dtype == tf.bool:
    assert x2.dtype == tf.bool
    target_float = dtypes.default_float_type()
    x1, x2 = tf.cast(x1, target_float), tf.cast(x2, target_float)

  if dtypes.is_allow_float64():
    return tf.math.truediv(x1, x2)

  # tf.math.truediv in Python3 produces float64 when both inputs are int32
  # or int64. We want to avoid that when is_allow_float64() is False.
  # NOTE(review): _avoid_float64 is defined elsewhere; presumably it
  # down-casts the operands -- confirm against its definition.
  x1, x2 = _avoid_float64(x1, x2)
  return tf.math.truediv(x1, x2)