def smart_constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Arguments:
      pred: A scalar, either a Tensor, a Python bool, or a value equal to 1
        or 0.

    Returns:
      True or False if `pred` has a constant boolean value, None otherwise.
      For a Tensor the result is whatever `tensor_util.constant_value` or,
      failing that, `c_api.TF_TryEvaluateConstant_wrapper` reports.

    Raises:
      TypeError: If `pred` is not a Tensor, a Python bool, or 1 or 0.
    """
    if isinstance(pred, ops.Tensor):
        pred_value = tensor_util.constant_value(pred)
        # TODO(skyewm): consider folding this into tensor_util.constant_value.
        if pred_value is None:
            # pylint: disable=protected-access
            pred_value = c_api.TF_TryEvaluateConstant_wrapper(
                pred.graph._c_graph, pred._as_tf_output())
            # pylint: enable=protected-access
        return pred_value

    # Set membership hashes `pred`; unhashable objects (lists, dicts, ...)
    # would otherwise escape with an unrelated "unhashable type" TypeError
    # instead of the message below.
    try:
        is_zero_or_one = pred in {0, 1}
    except TypeError:
        is_zero_or_one = False

    # True/False hash and compare equal to 1/0, so this single check also
    # covers Python bools; no separate isinstance(pred, bool) branch is needed.
    if is_zero_or_one:
        return bool(pred)

    raise TypeError(
        "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
        "Found instead: %s" % type(pred))
def smart_constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Arguments:
      pred: A scalar, either a Tensor, a Python bool, or a value equal to 1
        or 0.

    Returns:
      True or False if `pred` has a constant boolean value, None otherwise.

    Raises:
      TypeError: If `pred` is not a Tensor, a Python bool, or 1 or 0.
    """
    # Set membership hashes `pred`. Objects that refuse to be hashed must not
    # abort here: they may still be Tensors handled below, or they should get
    # the descriptive TypeError at the end rather than "unhashable type".
    try:
        is_zero_or_one = pred in {0, 1}
    except TypeError:
        is_zero_or_one = False

    # True/False hash and compare equal to 1/0, so Python bools land here too.
    if is_zero_or_one:
        return bool(pred)

    if isinstance(pred, ops.Tensor):
        pred_value = tensor_util.constant_value(pred)
        # TODO(skyewm): consider folding this into tensor_util.constant_value
        # when _USE_C_API is removed (there may be performance and correctness
        # bugs, so I wanted to limit the change hidden behind _USE_C_API).
        # pylint: disable=protected-access
        if pred_value is None and ops._USE_C_API:
            with errors.raise_exception_on_not_ok_status() as status:
                pred_value = c_api.TF_TryEvaluateConstant_wrapper(
                    pred.graph._c_graph, pred._as_tf_output(), status)
        # pylint: enable=protected-access
        return pred_value

    # Wrap `pred` in a 1-tuple: a bare tuple on the right of `%` is unpacked
    # as the argument list, which garbles or breaks the message.
    raise TypeError(
        "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
        "Found instead: %s" % (pred,))
def _TransposeGrad(op, grad):
    """Returns unshuffle(grad).

    The permutation is read from `op.inputs[1]`. Outside eager execution the
    permutation is first run through `TF_TryEvaluateConstant_wrapper`; when
    that yields a value, a constant of the same dtype replaces the original
    tensor. `grad` is then transposed by the inverted permutation.

    Returns:
      A two-element list: the transposed gradient, and None for the
      permutation input.
    """
    perm = op.inputs[1]

    if not context.executing_eagerly():
        # pylint: disable=protected-access
        static_perm = pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
            perm.graph._c_graph, perm._as_tf_output())
        # pylint: enable=protected-access
        if static_perm is not None:
            perm = constant_op.constant(static_perm, dtype=perm.dtype)

    inverse_perm = array_ops.invert_permutation(perm)
    return [array_ops.transpose(grad, inverse_perm), None]
def _get_static_value(pred):
    """Helper function for getting static values from maybe-tensor objects.

    Args:
      pred: Any object. Objects for which `tf.is_tensor` is false are returned
        unchanged.

    Returns:
      `pred` itself when it is not a tensor. Otherwise the result of
      `tf.get_static_value` on the converted tensor, or, when that is None,
      whatever `c_api.TF_TryEvaluateConstant_wrapper` returns for it.
    """
    if not tf.is_tensor(pred):
        return pred

    # Convert once and use the converted tensor for *both* lookups. The
    # original fell back to `pred.graph` / `pred._as_tf_output()` on the raw
    # object, which is only valid when `pred` is already a plain Tensor.
    tensor = tf.convert_to_tensor(pred)
    static_value = tf.get_static_value(tensor)
    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    if static_value is None:
        # pylint: disable=protected-access
        static_value = c_api.TF_TryEvaluateConstant_wrapper(
            tensor.graph._c_graph, tensor._as_tf_output())
        # pylint: enable=protected-access
    return static_value
def _BroadcastToGrad(op, grad):
    """Gradient function taking `op.inputs[0]` (value) and `op.inputs[1]` (shape).

    `grad` is summed (keeping dims) over the axes that
    `broadcast_gradient_args` reports for the pair (broadcast shape, input
    shape), then reshaped to the dynamic shape of the input value. Outside
    eager execution the broadcast shape is first evaluated through
    `TF_TryEvaluateConstant_wrapper`; if the resulting `TensorShape` is fully
    defined, an int32 constant replaces the shape tensor.

    Returns:
      A two-element list: the gradient for the input value, and None for the
      shape input.
    """
    source = op.inputs[0]
    target_shape = op.inputs[1]
    source_shape = array_ops.shape(source)

    if not context.executing_eagerly():
        # pylint: disable=protected-access
        evaluated = pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
            target_shape.graph._c_graph, target_shape._as_tf_output())
        # pylint: enable=protected-access
        static_target_shape = tensor_shape.TensorShape(evaluated)
        if static_target_shape.is_fully_defined():
            target_shape = constant_op.constant(
                static_target_shape.as_list(), dtype=dtypes.int32)

    _, axes_to_sum = gen_array_ops.broadcast_gradient_args(
        target_shape, source_shape)
    summed = math_ops.reduce_sum(grad, axis=axes_to_sum, keepdims=True)
    return [array_ops.reshape(summed, source_shape), None]
def _get_static_predicate(pred): """Helper function for statically evaluating predicates in `cond`.""" if pred in {0, 1}: # Accept 1/0 as valid boolean values pred_value = bool(pred) elif isinstance(pred, bool): pred_value = pred elif isinstance(pred, tf.Tensor): pred_value = tf.get_static_value(pred) # TODO(jamieas): remove the dependency on `pywrap_tensorflow`. # pylint: disable=protected-access if pred_value is None: pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph, pred._as_tf_output()) # pylint: enable=protected-access else: raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. " "Found instead: %s" % pred) return pred_value
def _get_static_predicate(pred):
    """Helper function for statically evaluating predicates in `cond`.

    Args:
      pred: Anything accepted by `tf.is_tensor`, or a value equal to 1 or 0
        (including Python bools).

    Returns:
      A Python bool when the predicate resolves to something equal to 0 or 1.
      For tensor inputs that do not resolve to such a value, the raw result
      of the static lookup is returned unchanged.

    Raises:
      TypeError: If `pred` is neither a tensor nor equal to 0 or 1.
    """
    if tf.is_tensor(pred):
        # Convert once and use the converted tensor for *both* lookups. The
        # original fell back to `pred.graph` / `pred._as_tf_output()` on the
        # raw object, which is only valid when `pred` is a plain Tensor.
        tensor = tf.convert_to_tensor(pred)
        pred_value = tf.get_static_value(tensor)

        # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
        if pred_value is None:
            # pylint: disable=protected-access
            pred_value = c_api.TF_TryEvaluateConstant_wrapper(
                tensor.graph._c_graph, tensor._as_tf_output())
            # pylint: enable=protected-access

        # Normalize 0/1-like static values to a genuine Python bool.
        if pred_value in (0, 1, True, False):
            pred_value = bool(pred_value)
    elif pred in (0, 1, True, False):
        # Accept 1/0 as valid boolean values.
        # This branch also casts np.array(False), tf.EagerTensor(True), etc.
        pred_value = bool(pred)
    else:
        raise TypeError('`pred` must be a Tensor, or a Python bool, or 1 or 0. '
                        'Found instead: {}'.format(pred))
    return pred_value