def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool, the integers 1 or 0 (or values
      equal to them), or a tensor.

  Returns:
    For a Python bool or 1/0: the corresponding `True`/`False`.
    For a tensor: the value found by `tensor_util.constant_value`, falling
    back to `TF_TryEvaluateConstant_wrapper`; None if neither can determine
    it. (NOTE(review): for tensors this is whatever those helpers return,
    presumably a NumPy value rather than a Python bool — confirm.)

  Raises:
    TypeError: If `pred` is not a Tensor, a bool, or 1/0.
  """
  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    # pylint: disable=protected-access
    if pred_value is None:
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(
          pred.graph._c_graph, pred._as_tf_output())
    # pylint: enable=protected-access
    return pred_value

  # Check bools first: they also compare equal to 1/0, which previously made
  # a dedicated bool branch unreachable.
  if isinstance(pred, bool):
    return pred

  # Accept 1/0 (and values equal to them) as valid boolean values. The set
  # membership test requires a hashable `pred`; unhashable inputs such as
  # lists or NumPy arrays would otherwise escape with an unrelated
  # "unhashable type" TypeError instead of the descriptive one below.
  try:
    is_zero_or_one = pred in {0, 1}
  except TypeError:
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  raise TypeError(
      "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
      "Found instead: %s" % type(pred))
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  Args:
    pred: A Python/NumPy value or a tensor-like object.

  Returns:
    If `tf.is_tensor(pred)`: the value from `tf.get_static_value`; outside
    JAX mode, graph `tf.Tensor`s fall back to constant evaluation. None if
    the value cannot be determined. Any other `pred` is returned unchanged.
  """
  if not tf.is_tensor(pred):
    return pred
  pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
  # Only real `tf.Tensor`s expose `.graph` and `_as_tf_output()`; other
  # tensor-likes (e.g. KerasTensor) raise AttributeError when asked for
  # `.graph`, so they skip the fallback and yield None.
  # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
  if not JAX_MODE and pred_value is None and isinstance(pred, tf.Tensor):
    # pylint: disable=protected-access
    pred_value = c_api.TF_TryEvaluateConstant_wrapper(
        pred.graph._c_graph, pred._as_tf_output())
    # pylint: enable=protected-access
  return pred_value
def try_evaluate_constant(tensor):  # pylint: disable=invalid-name
  """Attempts to compute the value of a symbolic tensor as a constant.

  The tensor's graph is entered through its C-graph handle's `get()` context
  manager, and the C API constant evaluator is run on the tensor's output.

  Args:
    tensor: a symbolic Tensor.

  Returns:
    ndarray if the evaluation succeeds, or None if it fails.
  """
  # pylint: disable=protected-access
  graph_handle = tensor.graph._c_graph
  with graph_handle.get() as raw_graph:
    evaluated = c_api.TF_TryEvaluateConstant_wrapper(
        raw_graph, tensor._as_tf_output())
  # pylint: enable=protected-access
  return evaluated
def _BroadcastToGrad(op, grad):
  """Gradient function for BroadcastTo.

  Sums `grad` over the reduction axes reported by `broadcast_gradient_args`
  (computed from the target shape and the input's shape), then reshapes the
  result back to the input's shape. The shape input receives no gradient.

  Args:
    op: The forward op; `op.inputs[0]` is the value that was broadcast and
      `op.inputs[1]` is the target shape.
    grad: Gradient with respect to the op's output.

  Returns:
    A list `[gradient for op.inputs[0], None]`.
  """
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  # Build every shape tensor in the target shape's own dtype so that
  # `broadcast_gradient_args` receives matching operand types. Hard-coding
  # int32 breaks int64 target shapes that are dynamic (graph) or eager.
  # NOTE(review): assumes BroadcastTo's shape input is int32 or int64, per
  # its op definition — confirm if new index types are added.
  shape_dtype = broadcast_shape.dtype
  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    # In graph mode, try to fold the target shape; if it is fully known,
    # replace it with a constant of the same dtype.
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=shape_dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  # keepdims=True keeps the rank so the reshape below only has to drop any
  # leading broadcast dimensions.
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  Args:
    pred: A Python/NumPy value or a tensor-like object.

  Returns:
    In JAX mode: `np.asarray(pred)`, or None if that conversion raises.
    Otherwise, if `tf.is_tensor(pred)`: the value from `tf.get_static_value`,
    with graph `tf.Tensor`s falling back to constant evaluation; None if the
    value cannot be determined. Any other `pred` is returned unchanged.
  """
  if JAX_MODE:
    # JAX sometimes raises a raw `Exception` from `__array__`. Catch
    # `Exception` rather than using a bare except so that KeyboardInterrupt
    # and SystemExit still propagate.
    try:
      return np.asarray(pred)
    except Exception:  # pylint: disable=broad-except
      return None
  if tf.is_tensor(pred):
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
    # Only real `tf.Tensor`s expose `.graph` and `_as_tf_output()`; other
    # tensor-likes (e.g. KerasTensor) raise AttributeError when asked for
    # `.graph`, so they skip the fallback and yield None.
    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    if pred_value is None and isinstance(pred, tf.Tensor):
      # pylint: disable=protected-access
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(
          pred.graph._c_graph, pred._as_tf_output())
      # pylint: enable=protected-access
    return pred_value
  return pred
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  Args:
    pred: A Python/NumPy value or a tensor-like object.

  Returns:
    In JAX mode: `np.asarray(pred)`, or None if that conversion raises.
    Otherwise, if `tf.is_tensor(pred)`: the value from `tf.get_static_value`,
    with `ops.Tensor`s falling back to constant evaluation; None if the value
    cannot be determined. Any other `pred` is returned unchanged.
  """
  if JAX_MODE:
    # JAX sometimes raises a raw `Exception` from `__array__`. Catch
    # `Exception` rather than using a bare except so that KeyboardInterrupt
    # and SystemExit still propagate.
    try:
      return np.asarray(pred)
    except Exception:  # pylint: disable=broad-except
      return None
  if tf.is_tensor(pred):
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
    # Explicitly check for ops.Tensor, to avoid an AttributeError
    # when requesting `KerasTensor.graph`.
    if pred_value is None and isinstance(pred, ops.Tensor):
      if hasattr(tensor_util, 'try_evaluate_constant'):
        pred_value = tensor_util.try_evaluate_constant(pred)
      else:
        # TODO(feyu): remove this branch after try_evaluate_constant is in
        # tf-nightly.
        pred_value = c_api.TF_TryEvaluateConstant_wrapper(
            pred.graph._c_graph, pred._as_tf_output())  # pylint: disable=protected-access
    return pred_value
  return pred