Exemplo n.º 1
0
 def test_len(self):
   """Checks len_ on a list, a constant, a TensorArray and a tensor list."""
   python_list_len = py_builtins.len_([1, 2, 3])
   self.assertEqual(python_list_len, 3)

   with self.test_session() as sess:
     # The constant's length is compared directly, without running the session.
     column = constant_op.constant([[1], [2], [3]])
     column_len = py_builtins.len_(column)
     self.assertEqual(column_len, 3)

     # For the TensorArray and the tensor list the result is run first.
     array = tensor_array_ops.TensorArray(dtypes.int32, size=5)
     array_len = py_builtins.len_(array)
     self.assertEqual(sess.run(array_len), 5)

     tensor_list = data_structures.tf_tensor_list_new([3, 4, 5])
     tensor_list_len = py_builtins.len_(tensor_list)
     self.assertEqual(sess.run(tensor_list_len), 3)
Exemplo n.º 2
0
    def test_len_dynamic_shape(self):
        """Checks len_ on a placeholder created with `shape=None`."""
        with self.cached_session() as sess:
            unknown = array_ops.placeholder(dtype=dtypes.int32, shape=None)

            length = py_builtins.len_(unknown)
            vector_feed = {unknown: [1, 2, 3]}
            self.assertEqual(sess.run(length, vector_feed), 3)

            # Feeding a scalar is expected to raise InvalidArgumentError.
            with self.assertRaises(errors_impl.InvalidArgumentError):
                length = py_builtins.len_(unknown)
                scalar_feed = {unknown: 1}
                sess.run(length, scalar_feed)
Exemplo n.º 3
0
 def test_len(self):
   """Checks len_ on a list, a constant, a TensorArray and a tensor list."""
   python_list_len = py_builtins.len_([1, 2, 3])
   self.assertEqual(python_list_len, 3)

   with self.cached_session() as sess:
     # The constant's length is compared directly, without evaluating it.
     column = constant_op.constant([[1], [2], [3]])
     column_len = py_builtins.len_(column)
     self.assertEqual(column_len, 3)

     # For the TensorArray and the tensor list the result is evaluated first.
     array = tensor_array_ops.TensorArray(dtypes.int32, size=5)
     array_len = py_builtins.len_(array)
     self.assertEqual(self.evaluate(array_len), 5)

     tensor_list = data_structures.tf_tensor_list_new([3, 4, 5])
     tensor_list_len = py_builtins.len_(tensor_list)
     self.assertEqual(self.evaluate(tensor_list_len), 3)
Exemplo n.º 4
0
  def test_len_dynamic_shape(self):
    """Checks len_ on a placeholder created with `shape=None`."""
    with self.test_session() as sess:
      unknown = array_ops.placeholder(dtype=dtypes.int32, shape=None)

      length = py_builtins.len_(unknown)
      vector_feed = {unknown: [1, 2, 3]}
      self.assertEqual(sess.run(length, vector_feed), 3)

      # Feeding a scalar is expected to raise InvalidArgumentError.
      with self.assertRaises(errors_impl.InvalidArgumentError):
        length = py_builtins.len_(unknown)
        scalar_feed = {unknown: 1}
        sess.run(length, scalar_feed)
Exemplo n.º 5
0
def _known_len_tf_for_stmt(iter_, extra_test, body, init_state):
  """Stages a for loop over an indexable `iter_` as a bounded while loop.

  Args:
    iter_: object accepted by `py_builtins.len_` and indexable by the loop
      counter.
    extra_test: optional callable over the loop state; when given, it is
      and-ed with the bounds check.
    body: callable taking the current element followed by the loop state and
      returning the new state as a tuple (or a falsy value).
    init_state: tuple with the initial loop state.

  Returns:
    The results of `_tf_while_stmt`; when they form a tuple or list with more
    than one entry the leading iteration index is dropped. An empty tuple is
    returned when the results are not a tuple or list.
  """
  length = py_builtins.len_(iter_)

  def while_body(index, *loop_state):
    updated = body(iter_[index], *loop_state)
    advanced = (index + 1,)
    # A falsy `updated` means only the counter is carried over.
    return advanced + updated if updated else advanced

  def while_cond(index, *loop_state):
    in_range = index < length
    if extra_test is None:
      return in_range
    return gen_math_ops.logical_and(in_range, extra_test(*loop_state))

  results = _tf_while_stmt(
      while_cond,
      while_body,
      init_state=(0,) + init_state,
      opts={'maximum_iterations': length})

  # Dropping the iteration index because it's not syntactically visible.
  # TODO(mdan): Don't.
  if not isinstance(results, (tuple, list)):
    return ()

  assert len(results) >= 1  # Has at least the iterate.
  return results[1:] if len(results) > 1 else results
Exemplo n.º 6
0
def _known_len_for_stmt(iter_, extra_test, body, init_state):
  """Stages a for loop over an indexable `iter_` as a bounded while loop.

  Args:
    iter_: object accepted by `py_builtins.len_` and indexable by the loop
      counter.
    extra_test: callable over the loop state, and-ed with the bounds check.
    body: callable taking the current element followed by the loop state and
      returning the new state as a tuple.
    init_state: tuple with the initial loop state.

  Returns:
    The loop results without the leading iteration index. A single remaining
    result is returned unwrapped.
  """
  length = py_builtins.len_(iter_)

  def while_body(index, *loop_state):
    updated = body(iter_[index], *loop_state)
    return (index + 1,) + updated

  def while_cond(index, *loop_state):
    in_range = index < length
    return gen_math_ops.logical_and(in_range, extra_test(*loop_state))

  loop_vars = while_stmt(
      while_cond,
      while_body,
      init_state=(0,) + init_state,
      extra_deps=(iter_,),
      opts={'maximum_iterations': length})

  # Dropping the iteration index because it's not syntactically visible.
  visible = loop_vars[1:]

  # TODO(mdan): Remove this special case.
  return visible[0] if len(visible) == 1 else visible
Exemplo n.º 7
0
def _known_len_for_stmt(iter_, extra_test, body, init_state):
    """Stages a for loop over an indexable `iter_` as a bounded while loop.

    Args:
      iter_: object accepted by `py_builtins.len_` and indexable by the loop
        counter.
      extra_test: callable over the loop state, and-ed with the bounds check.
      body: callable taking the current element followed by the loop state
        and returning the new state as a tuple (or a falsy value).
      init_state: tuple with the initial loop state.

    Returns:
      An empty tuple when `while_stmt` does not return a tuple or list.
      Otherwise the results with the leading iteration index dropped (when
      there is more than one); a single remaining result is unwrapped.
    """
    length = py_builtins.len_(iter_)

    def while_body(index, *loop_state):
        updated = body(iter_[index], *loop_state)
        next_index = index + 1
        # A falsy `updated` means only the bare counter is returned.
        if not updated:
            return next_index
        return (next_index, ) + updated

    def while_cond(index, *loop_state):
        in_range = index < length
        return gen_math_ops.logical_and(in_range, extra_test(*loop_state))

    results = while_stmt(while_cond,
                         while_body,
                         init_state=(0, ) + init_state,
                         extra_deps=(iter_, ),
                         opts={'maximum_iterations': length})

    # Dropping the iteration index because it's not syntactically visible.
    # TODO(mdan): Don't.
    if not isinstance(results, (tuple, list)):
        return ()

    assert len(results) >= 1  # Has at least the iterate.
    visible = results[1:] if len(results) > 1 else results
    if len(visible) == 1:
        # TODO(mdan): Remove this special case.
        return visible[0]
    return visible
Exemplo n.º 8
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           init_vars, basic_symbol_names,
                           composite_symbol_names, opts):
    """Stages a for loop over a TF entity of known length as a while loop.

    Args:
      iter_: the entity iterated over; it is unstacked into a TensorArray.
      extra_test: optional callable over the loop variables; when given, it
        is only evaluated while the index is in range.
      body: callable taking the current element followed by the loop
        variables and returning the new variables as a tuple (or a falsy
        value).
      get_state: forwarded unchanged to `_tf_while_stmt`.
      set_state: forwarded unchanged to `_tf_while_stmt`.
      init_vars: tuple of initial loop variable values; undefined values are
        rejected up front.
      basic_symbol_names: tuple of symbol names, prefixed here with the name
        of the internal iterate.
      composite_symbol_names: forwarded unchanged to `_tf_while_stmt`.
      opts: dict of loop options; `maximum_iterations` is set in place when
        the default graph or one of its parents is in an XLA context.

    Returns:
      The results of `_tf_while_stmt`; when they form a tuple or list with
      more than one entry the leading entry is dropped. An empty tuple is
      returned when the results are not a tuple or list.
    """
    _disallow_undefs_into_loop(*init_vars)

    length = py_builtins.len_(iter_)
    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    elements = tensor_array_ops.TensorArray(iter_.dtype,
                                            size=length).unstack(iter_)

    def while_body(index, *current_vars):
        """Main loop body."""
        updated_vars = body(elements.read(index), *current_vars)
        next_vars = (index + 1, )
        # A falsy `updated_vars` means only the counter is carried over.
        return next_vars + updated_vars if updated_vars else next_vars

    def while_cond(index, *current_vars):
        in_range = index < length
        if extra_test is None:
            return in_range
        return control_flow_ops.cond(in_range,
                                     lambda: extra_test(*current_vars),
                                     lambda: False)

    # TODO(b/134181679): Let the op itself handle optimizations.
    graph = ops.get_default_graph()
    if control_flow_util.GraphOrParentsInXlaContext(graph):
        opts['maximum_iterations'] = length

    results = _tf_while_stmt(
        while_cond,
        while_body,
        get_state,
        set_state,
        (array_ops.zeros_like(length), ) + init_vars,
        ('<internal iterate>', ) + basic_symbol_names,
        composite_symbol_names,
        opts,
    )

    # Note: the iteration index is not returned by the while loop, however
    # if a symbol with the same name exists outside the loop, it will be captured
    # by the loop variables and ultimately updated correctly.
    if not isinstance(results, (tuple, list)):
        return ()

    assert len(results) >= 1  # Has at least the iterate.
    return results[1:] if len(results) > 1 else results
Exemplo n.º 9
0
    def test_len_dataset(self):
        """Checks len_ on a three-element dataset, directly and in a function."""
        direct_dataset = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
        direct_len = py_builtins.len_(direct_dataset)
        self.assertEqual(self.evaluate(direct_len), 3)

        # graph mode
        @def_function.function(autograph=False)
        def test_fn():
            traced_dataset = dataset_ops.DatasetV2.from_tensor_slices(
                [3, 2, 1])
            return py_builtins.len_(traced_dataset)

        traced_len = test_fn()
        self.assertEqual(self.evaluate(traced_len), 3)
Exemplo n.º 10
0
    def test_len_dataset_infinite(self):
        """Checks that len_ on a repeated dataset raises InvalidArgumentError."""
        endless = dataset_ops.DatasetV2.range(5).repeat().batch(2)
        with self.assertRaises(errors_impl.InvalidArgumentError):
            endless_len = py_builtins.len_(endless)
            _ = self.evaluate(endless_len)

        # graph mode
        @def_function.function
        def test_fn():
            traced = dataset_ops.DatasetV2.range(5).repeat().batch(2)
            return py_builtins.len_(traced)

        with self.assertRaises(errors_impl.InvalidArgumentError):
            traced_len = test_fn()
            self.evaluate(traced_len)
Exemplo n.º 11
0
  def test_len_dataset_unknown(self):
    """Checks that len_ on a filtered dataset raises InvalidArgumentError."""
    filtered = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
    with self.assertRaises(errors_impl.InvalidArgumentError):
      filtered_len = py_builtins.len_(filtered)
      _ = self.evaluate(filtered_len)

    # graph mode
    @def_function.function(autograph=False)
    def test_fn():
      traced = dataset_ops.DatasetV2.range(5).filter(lambda _: True).batch(2)
      return py_builtins.len_(traced)

    with self.assertRaises(errors_impl.InvalidArgumentError):
      traced_len = test_fn()
      self.evaluate(traced_len)
Exemplo n.º 12
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           symbol_names, opts):
    """Stages a for loop over a TF entity of known length as a while loop.

    The iteration index is kept in a `compat_util.BasicRef` and is prepended
    to the state exchanged with `_tf_while_stmt`.

    Args:
      iter_: the entity iterated over; it is unstacked into a TensorArray.
      extra_test: optional argument-less callable; when given, it is only
        evaluated while the index is in range.
      body: callable taking the current element.
      get_state: callable returning the loop variables as a tuple.
      set_state: callable receiving the loop variables without the index.
      symbol_names: tuple of symbol names, prefixed here with the name of
        the internal iterate.
      opts: dict of loop options; `maximum_iterations` is set in place unless
        the default graph or one of its parents is in an XLA context.
    """
    length = py_builtins.len_(iter_)

    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    elements = tensor_array_ops.TensorArray(iter_.dtype,
                                            size=length).unstack(iter_)

    counter = compat_util.BasicRef(0)

    def aug_get_state():
        return (counter.value, ) + get_state()

    def aug_set_state(aug_loop_vars):
        # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
        index, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
        counter.value = index
        # The iteration index is not "output" by the for loop. If the iterate
        # is used outside the loop, it will appear in the loop vars separately.
        set_state(loop_vars)

    def aug_body():
        current = elements.read(counter.value)
        body(current)
        counter.value += 1

    def aug_test():
        in_range = counter.value < length
        if extra_test is None:
            return in_range
        return control_flow_ops.cond(in_range, extra_test, lambda: False)

    # TODO(b/159186914): Remove.
    graph = ops.get_default_graph()
    if not control_flow_util.GraphOrParentsInXlaContext(graph):
        opts['maximum_iterations'] = length

    augmented_names = ('<internal iterate>', ) + symbol_names
    _tf_while_stmt(
        aug_test,
        aug_body,
        aug_get_state,
        aug_set_state,
        augmented_names,
        opts,
    )
Exemplo n.º 13
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           init_vars):
  """Stages a for loop over a TF entity of known length as a while loop.

  Args:
    iter_: the entity iterated over; it is unstacked into a TensorArray.
    extra_test: optional callable over the loop variables; when given, it is
      only evaluated while the index is in range.
    body: callable taking the current element followed by the loop variables
      and returning the new variables as a tuple (or a falsy value).
    get_state: forwarded unchanged to `_tf_while_stmt`.
    set_state: forwarded unchanged to `_tf_while_stmt`.
    init_vars: tuple of initial loop variable values; undefined values are
      rejected up front.

  Returns:
    The results of `_tf_while_stmt`; when they form a tuple or list with more
    than one entry the leading iteration index is dropped. An empty tuple is
    returned when the results are not a tuple or list.
  """
  _disallow_undefs_into_loop(*init_vars)

  length = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  elements = tensor_array_ops.TensorArray(
      iter_.dtype, size=length).unstack(iter_)

  def while_body(index, *current_vars):
    updated_vars = body(elements.read(index), *current_vars)
    next_vars = (index + 1,)
    # A falsy `updated_vars` means only the counter is carried over.
    return next_vars + updated_vars if updated_vars else next_vars

  def while_cond(index, *current_vars):
    in_range = index < length
    if extra_test is None:
      return in_range
    return control_flow_ops.cond(
        in_range, lambda: extra_test(*current_vars), lambda: False)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      init_vars=(0,) + init_vars,
      opts={'maximum_iterations': length})

  # Dropping the iteration index because it's not syntactically visible.
  # TODO(mdan): Don't.
  if not isinstance(results, (tuple, list)):
    return ()

  assert len(results) >= 1  # Has at least the iterate.
  return results[1:] if len(results) > 1 else results
Exemplo n.º 14
0
def _known_len_tf_for_stmt(iter_, extra_test, body, init_state):
  """Stages a for loop over a TF entity of known length as a while loop.

  Args:
    iter_: the entity iterated over; it is unstacked into a TensorArray.
    extra_test: optional callable over the loop state; when given, it is only
      evaluated while the index is in range.
    body: callable taking the current element followed by the loop state and
      returning the new state as a tuple (or a falsy value).
    init_state: tuple with the initial loop state; undefined values are
      rejected up front.

  Returns:
    The results of `_tf_while_stmt`; when they form a tuple or list with more
    than one entry the leading iteration index is dropped. An empty tuple is
    returned when the results are not a tuple or list.
  """
  _disallow_undefs_into_loop(*init_state)

  length = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  elements = tensor_array_ops.TensorArray(
      iter_.dtype, size=length).unstack(iter_)

  def while_body(index, *loop_state):
    updated = body(elements.read(index), *loop_state)
    advanced = (index + 1,)
    # A falsy `updated` means only the counter is carried over.
    return advanced + updated if updated else advanced

  def while_cond(index, *loop_state):
    in_range = index < length
    if extra_test is None:
      return in_range
    return control_flow_ops.cond(
        in_range,
        lambda: extra_test(*loop_state),
        lambda: False)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      init_state=(0,) + init_state,
      opts={'maximum_iterations': length})

  # Dropping the iteration index because it's not syntactically visible.
  # TODO(mdan): Don't.
  if not isinstance(results, (tuple, list)):
    return ()

  assert len(results) >= 1  # Has at least the iterate.
  return results[1:] if len(results) > 1 else results
Exemplo n.º 15
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           symbol_names, opts):
    """Overload of for_stmt that iterates over TF entities that admit a length.

    The loop is staged through `_tf_while_stmt`. The iteration index lives in
    a local variable shared by the nested closures and is prepended to the
    state exchanged with the while loop.

    Args:
      iter_: the entity iterated over; it is unstacked into a TensorArray.
      extra_test: optional argument-less callable; when given, it is only
        evaluated while the index is in range.
      body: callable taking the current element.
      get_state: callable returning the loop variables as a tuple.
      set_state: callable receiving the loop variables without the index.
      symbol_names: tuple of symbol names, prefixed here with the name of
        the internal iterate.
      opts: loop options, passed to `_add_max_iterations_hint` together with
        the length and then forwarded to `_tf_while_stmt`.
    """
    n = py_builtins.len_(iter_)

    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
    iter_ = ta.unstack(iter_)

    # Loop counter; rebound by the closures below via `nonlocal`.
    iterate_index = 0

    def aug_get_state():
        # The counter always travels as the first state entry.
        return (iterate_index, ) + get_state()

    def aug_set_state(aug_loop_vars):
        nonlocal iterate_index
        # TODO(b/171479293): Drop the lint override.
        # Starred unpacking: `loop_vars` is a list of the remaining entries.
        iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
        # The iteration index is not "output" by the for loop. If the iterate
        # is used outside the loop, it will appear in the loop vars separately.
        set_state(loop_vars)

    def aug_body():
        nonlocal iterate_index
        # Run the user body on the current element, then advance the counter.
        body(iter_.read(iterate_index))
        iterate_index += 1

    def aug_test():
        main_test = iterate_index < n
        if extra_test is not None:
            # `extra_test` is only evaluated in the branch where `main_test`
            # holds; otherwise the condition is False.
            return control_flow_ops.cond(main_test, extra_test, lambda: False)
        return main_test

    _add_max_iterations_hint(opts, n)

    _tf_while_stmt(
        aug_test,
        aug_body,
        aug_get_state,
        aug_set_state,
        ('<internal iterate>', ) + symbol_names,
        opts,
    )
Exemplo n.º 16
0
def _known_len_tf_for_stmt(iter_, extra_test, body, init_state):
  """Stages a for loop over an indexable `iter_` as a bounded while loop.

  Args:
    iter_: object accepted by `py_builtins.len_` and indexable by the loop
      counter.
    extra_test: optional callable over the loop state; when given, it is only
      evaluated while the index is in range.
    body: callable taking the current element followed by the loop state and
      returning the new state as a tuple (or a falsy value).
    init_state: tuple with the initial loop state; undefined values are
      rejected up front.

  Returns:
    The results of `_tf_while_stmt`; when they form a tuple or list with more
    than one entry the leading iteration index is dropped. An empty tuple is
    returned when the results are not a tuple or list.
  """
  _disallow_undefs_into_loop(*init_state)

  length = py_builtins.len_(iter_)

  def while_body(index, *loop_state):
    updated = body(iter_[index], *loop_state)
    advanced = (index + 1,)
    # A falsy `updated` means only the counter is carried over.
    return advanced + updated if updated else advanced

  def while_cond(index, *loop_state):
    in_range = index < length
    if extra_test is None:
      return in_range
    return control_flow_ops.cond(
        in_range,
        lambda: extra_test(*loop_state),
        lambda: False)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      init_state=(0,) + init_state,
      opts={'maximum_iterations': length})

  # Dropping the iteration index because it's not syntactically visible.
  # TODO(mdan): Don't.
  if not isinstance(results, (tuple, list)):
    return ()

  assert len(results) >= 1  # Has at least the iterate.
  return results[1:] if len(results) > 1 else results
Exemplo n.º 17
0
 def test_len_scalar(self):
   """Checks that len_ on a scalar constant raises ValueError."""
   with self.assertRaises(ValueError):
     scalar = constant_op.constant(1)
     py_builtins.len_(scalar)
Exemplo n.º 18
0
 def test_len_scalar(self):
     """Checks that len_ on a scalar constant raises ValueError."""
     with self.assertRaises(ValueError):
         scalar = constant_op.constant(1)
         py_builtins.len_(scalar)
Exemplo n.º 19
0
 def test_fn():
     """Returns len_ of a filtered, batched range dataset."""
     keep_all = dataset_ops.DatasetV2.range(5).filter(lambda _: True)
     batched = keep_all.batch(2)
     return py_builtins.len_(batched)
Exemplo n.º 20
0
 def test_fn():
     """Returns len_ of a repeated, batched range dataset."""
     repeated = dataset_ops.DatasetV2.range(5).repeat()
     batched = repeated.batch(2)
     return py_builtins.len_(batched)
Exemplo n.º 21
0
 def test_fn():
     """Returns len_ of a dataset built from a three-element list."""
     slices = dataset_ops.DatasetV2.from_tensor_slices([3, 2, 1])
     length = py_builtins.len_(slices)
     return length