Ejemplo n.º 1
0
def smart_constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Arguments:
      pred: A scalar: a Python bool, the integer 1 or 0, or a tensor.

    Returns:
      True or False if `pred` is a Python bool or 1/0. For a Tensor, the
      statically folded value is returned without conversion to bool, or
      None if it could not be determined.

    Raises:
      TypeError: If `pred` is not a Tensor, a Python bool, or 1 or 0.
    """
    if isinstance(pred, bool):
        # Checked first so this branch is actually reachable: `True in {0, 1}`
        # is also true, which previously shadowed the bool case.
        return pred
    if isinstance(pred, ops.Tensor):
        pred_value = tensor_util.constant_value(pred)
        # TODO(skyewm): consider folding this into tensor_util.constant_value.
        if pred_value is None:
            pred_value = c_api.TF_TryEvaluateConstant_wrapper(
                pred.graph._c_graph, pred._as_tf_output())  # pylint: disable=protected-access
        return pred_value
    try:
        # Accept 1/0 as valid boolean values.
        is_zero_or_one = pred in {0, 1}
    except TypeError:
        # Unhashable inputs (lists, ndarrays, ...) can never be 0 or 1; report
        # them with the descriptive error below rather than "unhashable type".
        is_zero_or_one = False
    if not is_zero_or_one:
        raise TypeError(
            "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
            "Found instead: %s" % type(pred))
    return bool(pred)
Ejemplo n.º 2
0
def _get_static_value(pred):
  """Helper function for getting static values from maybe-tensor objects.

  Args:
    pred: a Python value or a tensor-like object.

  Returns:
    `pred` unchanged if it is not a tensor. For a tensor, its constant value
    if one can be determined, otherwise None.
  """
  if not tf.is_tensor(pred):
    return pred
  pred_value = tf.get_static_value(tf.convert_to_tensor(pred))

  # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
  # Only graph tensors expose `.graph` and `._as_tf_output()`; tensor-likes
  # such as KerasTensor raise AttributeError on `.graph`, so skip the C-API
  # fallback for them instead of crashing.
  if (not JAX_MODE and pred_value is None and hasattr(pred, 'graph') and
      hasattr(pred, '_as_tf_output')):
    # pylint: disable=protected-access
    pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
                                                      pred._as_tf_output())
    # pylint: enable=protected-access
  return pred_value
Ejemplo n.º 3
0
def try_evaluate_constant(tensor):  # pylint: disable=invalid-name
    """Attempts to fold a symbolic tensor into a concrete value.

    Args:
      tensor: a symbolic Tensor.

    Returns:
      ndarray if the evaluation succeeds, or None if it fails.
    """
    # pylint: disable=protected-access
    graph_wrapper = tensor.graph._c_graph
    # Hold the underlying C graph for the duration of the evaluation.
    with graph_wrapper.get() as raw_graph:
        tf_output = tensor._as_tf_output()
        return c_api.TF_TryEvaluateConstant_wrapper(raw_graph, tf_output)
Ejemplo n.º 4
0
def _BroadcastToGrad(op, grad):
  """Gradient for BroadcastTo: reduce `grad` back to the input's shape.

  Args:
    op: the BroadcastTo op; `op.inputs[0]` is the broadcast value and
      `op.inputs[1]` is the target shape.
    grad: the incoming gradient with respect to the op's output.

  Returns:
    `[gradient w.r.t. the input value, None]`; the shape input gets no
    gradient.
  """
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  # BroadcastGradientArgs requires both shape operands to share one dtype.
  # The target shape may be int32 or int64, so build the input's shape (and
  # any folded constant) with that same dtype instead of a hard-coded int32,
  # which mismatched an int64 shape whenever it was not folded below.
  shape_dtype = broadcast_shape.dtype
  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    # If the symbolic shape can be folded to a fully-defined constant,
    # substitute that constant for it.
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=shape_dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
Ejemplo n.º 5
0
def _get_static_value(pred):
    """Helper function for getting static values from maybe-tensor objects.

    Args:
      pred: a Python value, array, or tensor-like object.

    Returns:
      In JAX mode, `np.asarray(pred)`, or None if that conversion raises.
      Otherwise `pred` unchanged when it is not a tensor; for a tensor, its
      constant value if one can be determined, else None.
    """
    if JAX_MODE:
        # JAX sometimes raises a raw Exception in __array__. Catch Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            return np.asarray(pred)
        except Exception:  # pylint: disable=broad-except
            return None
    if not tf.is_tensor(pred):
        return pred
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))

    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    # Only graph tensors expose `.graph` and `._as_tf_output()`; tensor-likes
    # such as KerasTensor raise AttributeError on `.graph`, so skip the
    # C-API fallback for them instead of crashing.
    if (pred_value is None and hasattr(pred, 'graph') and
            hasattr(pred, '_as_tf_output')):
        # pylint: disable=protected-access
        pred_value = c_api.TF_TryEvaluateConstant_wrapper(
            pred.graph._c_graph, pred._as_tf_output())
        # pylint: enable=protected-access
    return pred_value
Ejemplo n.º 6
0
def _get_static_value(pred):
    """Helper function for getting static values from maybe-tensor objects.

    Args:
      pred: a Python value, array, or tensor-like object.

    Returns:
      In JAX mode, `np.asarray(pred)`, or None if that conversion raises.
      Otherwise `pred` unchanged when it is not a tensor; for a tensor, its
      constant value if one can be determined, else None.
    """
    if JAX_MODE:
        # JAX sometimes raises a raw Exception in __array__. Catch Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            return np.asarray(pred)
        except Exception:  # pylint: disable=broad-except
            return None
    if not tf.is_tensor(pred):
        return pred
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))

    # Explicitly check for ops.Tensor, to avoid an AttributeError
    # when requesting `KerasTensor.graph`.
    if pred_value is None and isinstance(pred, ops.Tensor):
        if hasattr(tensor_util, 'try_evaluate_constant'):
            pred_value = tensor_util.try_evaluate_constant(pred)
        else:
            # TODO(feyu): remove this branch after try_evaluate_constant is in
            # tf-nightly.
            pred_value = c_api.TF_TryEvaluateConstant_wrapper(
                pred.graph._c_graph, pred._as_tf_output())  # pylint: disable=protected-access
    return pred_value