Beispiel #1
0
 def test_checks(self):
     """Check is_tensor on graph values and on plain Python/NumPy values.

     A constant, a Variable and a placeholder are expected to be reported as
     tensors; a Python int and a NumPy array are expected not to be.
     """
     constant = constant_op.constant([1, 2, 3])
     self.assertTrue(type_check.is_tensor(constant))

     variable = test_util.variables.Variable([1, 2, 3])
     self.assertTrue(type_check.is_tensor(variable))

     placeholder = test_util.array_ops.placeholder(test_util.dtypes.float32)
     self.assertTrue(type_check.is_tensor(placeholder))

     python_int = 3
     self.assertFalse(type_check.is_tensor(python_int))

     numpy_array = numpy.eye(3)
     self.assertFalse(type_check.is_tensor(numpy_array))
 def test_checks(self):
   """Check is_tensor on graph values and on plain Python/NumPy values.

   A constant, a Variable and a placeholder are expected to be reported as
   tensors; a Python int and a NumPy array are expected not to be.
   """
   constant = constant_op.constant([1, 2, 3])
   self.assertTrue(type_check.is_tensor(constant))

   variable = test_util.variables.Variable([1, 2, 3])
   self.assertTrue(type_check.is_tensor(variable))

   placeholder = test_util.array_ops.placeholder(test_util.dtypes.float32)
   self.assertTrue(type_check.is_tensor(placeholder))

   python_int = 3
   self.assertFalse(type_check.is_tensor(python_int))

   numpy_array = numpy.eye(3)
   self.assertFalse(type_check.is_tensor(numpy_array))
Beispiel #3
0
def run_while(cond_fn, body_fn, init_args):
    """Type-dependent functional while loop.

    Args:
      cond_fn: A Python callable implementing the stop conditions of the loop.
      body_fn: A Python callable implementing the body of the loop.
      init_args: The initial values of the arguments that will be passed to
        both cond_fn and body_fn. Must be a non-empty list or tuple.

    Returns:
      result: A list of values with the same shape and type as init_args. If
      any of the init_args, or any variables closed-over in cond_fn are
      Tensors, tf.while_loop will be used, otherwise a Python while loop will
      be run.

    Raises:
      ValueError: if init_args is not a tuple or list with one or more
        elements.
    """
    if not isinstance(init_args, (tuple, list)) or not init_args:
        # Wrap the offending value in a 1-tuple: a bare tuple on the right of
        # `%` is unpacked as the format arguments, so an empty tuple would
        # raise TypeError instead of the documented ValueError.
        raise ValueError(
            'init_args must be a non-empty list or tuple, found %s' %
            (init_args,))

    # TODO(alexbw): statically determine all active variables in cond_fn,
    # and pass them directly
    # get_function_closure may return None for functions without a closure,
    # hence the `or []` fallback.
    closure_vars = tuple(
        c.cell_contents for c in six.get_function_closure(cond_fn) or [])
    possibly_tensors = tuple(init_args) + closure_vars
    if is_tensor(*possibly_tensors):
        return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
    return py_while_loop(cond_fn, body_fn, init_args)
def run_while(cond_fn, body_fn, init_args):
  """Type-dependent functional while loop.

  Args:
    cond_fn: A Python callable implementing the stop conditions of the loop.
    body_fn: A Python callable implementing the body of the loop.
    init_args: The initial values of the arguments that will be passed to both
      cond_fn and body_fn. Must be a non-empty list or tuple.

  Returns:
    result: A list of values with the same shape and type as init_args. If any
    of the init_args, or any variables closed-over in cond_fn are Tensors,
    tf.while_loop will be used, otherwise a Python while loop will be run.

  Raises:
    ValueError: if init_args is not a tuple or list with one or more elements.
  """
  if not isinstance(init_args, (tuple, list)) or not init_args:
    # Wrap the offending value in a 1-tuple: a bare tuple on the right of `%`
    # is unpacked as the format arguments, so an empty tuple would raise
    # TypeError instead of the documented ValueError.
    raise ValueError(
        'init_args must be a non-empty list or tuple, found %s' % (init_args,))

  # TODO(alexbw): statically determine all active variables in cond_fn,
  # and pass them directly
  # get_function_closure may return None for functions without a closure,
  # hence the `or []` fallback.
  closure_vars = tuple(
      c.cell_contents for c in six.get_function_closure(cond_fn) or [])
  possibly_tensors = tuple(init_args) + closure_vars
  if is_tensor(*possibly_tensors):
    return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
  return py_while_loop(cond_fn, body_fn, init_args)
Beispiel #5
0
def dynamic_range(start_or_stop, stop=None, step=None):
  """Implementation of range using dynamic dispatch.

  Forwards to `math_ops.range` when `type_check.is_tensor` is truthy for the
  three arguments, and to the builtin `range` otherwise. Only the arguments up
  to the last one supplied are forwarded: all three when `step` is given, the
  first two when only `stop` is given, else just `start_or_stop`.
  """
  # Decide how many positional arguments to forward.
  if step is not None:
    range_args = (start_or_stop, stop, step)
  elif stop is not None:
    range_args = (start_or_stop, stop)
  else:
    range_args = (start_or_stop,)

  use_graph_range = type_check.is_tensor(start_or_stop, stop, step)
  range_impl = math_ops.range if use_graph_range else range
  return range_impl(*range_args)
def run_cond(condition, true_fn, false_fn):
  """Run a conditional whose implementation depends on the condition's type.

  Args:
    condition: A Tensor or Python bool.
    true_fn: A Python callable implementing the true branch of the conditional.
    false_fn: A Python callable implementing the false branch of the
      conditional.

  Returns:
    The value produced by the selected implementation:
    `control_flow_ops.cond` when `is_tensor(condition)` is truthy, and
    `py_cond` otherwise.
  """
  cond_impl = control_flow_ops.cond if is_tensor(condition) else py_cond
  return cond_impl(condition, true_fn, false_fn)
Beispiel #7
0
def run_cond(condition, true_fn, false_fn):
    """Run a conditional whose implementation depends on the condition's type.

    Args:
      condition: A Tensor or Python bool.
      true_fn: A Python callable implementing the true branch of the
        conditional.
      false_fn: A Python callable implementing the false branch of the
        conditional.

    Returns:
      The value produced by the selected implementation:
      `control_flow_ops.cond` when `is_tensor(condition)` is truthy, and
      `py_cond` otherwise.
    """
    cond_impl = control_flow_ops.cond if is_tensor(condition) else py_cond
    return cond_impl(condition, true_fn, false_fn)
Beispiel #8
0
def dynamic_is_not(left, right):
    """Dispatch the `is not` operator on the operand types.

    When `is_tensor(left, right)` is truthy the operands' `name` attributes
    are compared with `math_ops.not_equal`; otherwise the Python identity
    test `left is not right` is returned.
    """
    if not is_tensor(left, right):
        return left is not right
    return math_ops.not_equal(left.name, right.name)
Beispiel #9
0
def dynamic_is(left, right):
    """Dispatch the `is` operator on the operand types.

    When `is_tensor(left, right)` is truthy the operands' `name` attributes
    are compared with `math_ops.equal`; otherwise the Python identity test
    `left is right` is returned.
    """
    if not is_tensor(left, right):
        return left is right
    return math_ops.equal(left.name, right.name)
def dynamic_is_not(left, right):
  """Dispatch the `is not` operator on the operand types.

  When `is_tensor(left, right)` is truthy the operands' `name` attributes are
  compared with `math_ops.not_equal`; otherwise the Python identity test
  `left is not right` is returned.
  """
  if not is_tensor(left, right):
    return left is not right
  return math_ops.not_equal(left.name, right.name)
def dynamic_is(left, right):
  """Dispatch the `is` operator on the operand types.

  When `is_tensor(left, right)` is truthy the operands' `name` attributes are
  compared with `math_ops.equal`; otherwise the Python identity test
  `left is right` is returned.
  """
  if not is_tensor(left, right):
    return left is right
  return math_ops.equal(left.name, right.name)