Beispiel #1
0
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       symbol_names, opts):
    """Stages a for loop over a TF range as a while loop, eliding the range.

    The range bounds are read from the inputs of the op that produced `iter_`,
    and the loop index is carried as an extra, leading loop variable.

    Args:
      iter_: Value whose `op.inputs` unpack to `(start, limit, delta)`.
      extra_test: Optional callable taking no arguments. When given, it is
        evaluated through `control_flow_ops.cond` only if the range test
        passes; otherwise the loop condition is False.
      body: Callable invoked with the current index on each iteration.
      get_state: Callable returning the tuple of current loop variables.
      set_state: Callable receiving the loop variables, without the index.
      symbol_names: Tuple of loop variable names.
      opts: Dict of loop options. Its 'maximum_iterations' entry is
        overwritten with the range length, cast to int32.
    """
    range_start, range_limit, range_step = iter_.op.inputs

    # Mutable holder for the running index, shared by the closures below.
    index_ref = compat_util.BasicRef(range_start)

    def _state_with_index():
        return (index_ref.value, ) + get_state()

    def _restore_state(state_with_index):
        # TODO(mdan): Use starred assignment once we can switch to Py3-only
        # syntax.
        head, tail = state_with_index[0], state_with_index[1:]
        index_ref.value = head
        # The index is not an output of the for loop; if it is used outside
        # the loop, it shows up among the remaining loop vars on its own.
        set_state(tail)

    def _step():
        body(index_ref.value)
        index_ref.value += range_step

    def _keep_going():
        # A non-negative step counts up towards the limit, a negative one
        # counts down towards it.
        ascending = math_ops.logical_and(
            range_step >= 0, index_ref.value < range_limit)
        descending = math_ops.logical_and(
            range_step < 0, index_ref.value > range_limit)
        in_range = math_ops.logical_or(ascending, descending)
        if extra_test is None:
            return in_range
        return control_flow_ops.cond(in_range, extra_test, lambda: False)

    range_len = misc.get_range_len(range_start, range_limit, range_step)
    opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

    _tf_while_stmt(_keep_going, _step, _state_with_index, _restore_state,
                   ('<internal iterate>', ) + symbol_names, opts)
Beispiel #2
0
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Stages a for loop over a TF range as a while loop, eliding the range.

  The range bounds are read from the inputs of the op that produced `iter_`.
  The loop index lives in a local variable that the nested closures rebind,
  and it is carried as an extra, leading loop variable.

  Args:
    iter_: Value whose `op.inputs` unpack to `(start, limit, delta)`.
    extra_test: Optional callable taking no arguments. When given, it is
      evaluated through `control_flow_ops.cond` only if the range test
      passes; otherwise the loop condition is False.
    body: Callable invoked with the current index on each iteration.
    get_state: Callable returning the current loop variables.
    set_state: Callable receiving the loop variables, without the index.
    symbol_names: Tuple of loop variable names, paired positionally with the
      values returned by `get_state`.
    opts: Dict of loop options. `opts['iterate_names']` is compared against
      each symbol name; `opts` is also handed to `_add_max_iterations_hint`.
  """
  range_start, range_limit, range_step = iter_.op.inputs

  index = range_start

  def _state_with_index():
    patched = []
    for symbol, value in zip(symbol_names, get_state()):
      # An iterate symbol that is still undefined is seeded with the current
      # index.
      is_unset_iterate = (
          symbol == opts['iterate_names']
          and isinstance(value, variables.Undefined))
      patched.append(index if is_unset_iterate else value)
    return (index,) + tuple(patched)

  def _restore_state(state_with_index):
    nonlocal index
    # TODO(b/171479293): Drop the lint override.
    index, *remaining = state_with_index  # pylint:disable=unused-variable
    # The index is not an output of the for loop; if it is used outside the
    # loop, it shows up among the remaining loop vars on its own.
    set_state(remaining)

  def _step():
    nonlocal index
    body(index)
    index += range_step

  def _keep_going():
    # TODO(b/159713842): Remove once constant folding works.
    static_step = tensor_util.constant_value(range_step)
    if static_step is None:
      # Step sign is only known at run time, so test both directions.
      ascending = math_ops.logical_and(range_step >= 0, index < range_limit)
      descending = math_ops.logical_and(range_step < 0, index > range_limit)
      in_range = math_ops.logical_or(ascending, descending)
    elif static_step >= 0:
      in_range = index < range_limit
    else:
      in_range = index > range_limit

    if extra_test is None:
      return in_range
    return control_flow_ops.cond(in_range, extra_test, lambda: False)

  range_len = misc.get_range_len(range_start, range_limit, range_step)
  _add_max_iterations_hint(opts, math_ops.cast(range_len, dtypes.int32))

  augmented_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      _keep_going,
      _step,
      _state_with_index,
      _restore_state,
      augmented_names,
      opts)
Beispiel #3
0
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Stages a for loop over a TF range as a while loop, eliding the range.

  The range bounds are read from the inputs of the op that produced `iter_`.
  The loop index is kept in a `compat_util.BasicRef` so the nested closures
  can update it, and it is carried as an extra, leading loop variable.

  Args:
    iter_: Value whose `op.inputs` unpack to `(start, limit, delta)`.
    extra_test: Optional callable taking no arguments. When given, it is
      evaluated through `control_flow_ops.cond` only if the range test
      passes; otherwise the loop condition is False.
    body: Callable invoked with the current index on each iteration.
    get_state: Callable returning the current loop variables.
    set_state: Callable receiving the loop variables, without the index.
    symbol_names: Tuple of loop variable names, paired positionally with the
      values returned by `get_state`.
    opts: Dict of loop options. `opts['iterate_names']` is compared against
      each symbol name, and 'maximum_iterations' is set when the default
      graph is not in an XLA context.
  """
  lower, upper, stride = iter_.op.inputs

  cursor = compat_util.BasicRef(lower)

  def _is_unset_iterate(symbol, value):
    return (symbol == opts['iterate_names']
            and isinstance(value, variables.Undefined))

  def _snapshot():
    current = get_state()
    # An iterate symbol that is still undefined is seeded with the current
    # index.
    patched = tuple(
        cursor.value if _is_unset_iterate(symbol, value) else value
        for symbol, value in zip(symbol_names, current))
    return (cursor.value,) + patched

  def _restore(snapshot):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    head, tail = snapshot[0], snapshot[1:]
    cursor.value = head
    # The index is not an output of the for loop; if it is used outside the
    # loop, it shows up among the remaining loop vars on its own.
    set_state(tail)

  def _advance():
    body(cursor.value)
    cursor.value += stride

  def _should_continue():
    # TODO(b/159713842): Remove once constant folding works.
    static_stride = tensor_util.constant_value(stride)
    if static_stride is None:
      # Stride sign is only known at run time, so test both directions.
      within = math_ops.logical_or(
          math_ops.logical_and(stride >= 0, cursor.value < upper),
          math_ops.logical_and(stride < 0, cursor.value > upper))
    elif static_stride >= 0:
      within = cursor.value < upper
    else:
      within = cursor.value > upper

    if extra_test is None:
      return within
    return control_flow_ops.cond(within, extra_test, lambda: False)

  # TODO(b/134181679): Remove.
  # NOTE(review): the hint is skipped when running inside an XLA context;
  # confirm that this polarity is the intended one.
  in_xla_context = control_flow_util.GraphOrParentsInXlaContext(
      ops.get_default_graph())
  if not in_xla_context:
    range_len = misc.get_range_len(lower, upper, stride)
    opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

  _tf_while_stmt(
      _should_continue,
      _advance,
      _snapshot,
      _restore,
      ('<internal iterate>',) + symbol_names,
      opts)
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       init_vars, basic_symbol_names, composite_symbol_names,
                       opts):
    """Stages a for loop over a TF range as a while loop, eliding the range.

    The range bounds are read from the inputs of the op that produced `iter_`,
    and the loop index is threaded through the loop as the first loop
    variable.

    Args:
      iter_: Value whose `op.inputs` unpack to `(start, limit, delta)`.
      extra_test: Optional callable taking the loop variables (without the
        index). When given, it is evaluated through `control_flow_ops.cond`
        only if the range test passes; otherwise the condition is False.
      body: Callable taking the index followed by the loop variables and
        returning the updated loop variables (may be empty or falsy).
      get_state: Forwarded unchanged to `_tf_while_stmt`.
      set_state: Forwarded unchanged to `_tf_while_stmt`.
      init_vars: Tuple of initial loop variable values; checked with
        `_disallow_undefs_into_loop`.
      basic_symbol_names: Tuple of loop variable names.
      composite_symbol_names: Forwarded unchanged to `_tf_while_stmt`.
      opts: Dict of loop options. Its 'maximum_iterations' entry is
        overwritten with the range length, cast to int32.

    Returns:
      The loop results with the leading index dropped when `_tf_while_stmt`
      returns a tuple or list with more than one element; that value
      unchanged when it has exactly one element; `()` for any other result.
    """
    _disallow_undefs_into_loop(*init_vars)

    range_start, range_limit, range_step = iter_.op.inputs

    def _advance(index, *state):
        updated = body(index, *state)
        next_index = (index + range_step, )
        return next_index + updated if updated else next_index

    def _keep_going(index, *state):
        """Loop condition handed to the staged while loop."""
        # A non-negative step counts up towards the limit, a negative one
        # counts down towards it.
        ascending = math_ops.logical_and(
            range_step >= 0, index < range_limit)
        descending = math_ops.logical_and(
            range_step < 0, index > range_limit)
        in_range = math_ops.logical_or(ascending, descending)
        if extra_test is None:
            return in_range
        return control_flow_ops.cond(
            in_range,
            lambda: extra_test(*state),
            lambda: False,
        )

    range_len = misc.get_range_len(range_start, range_limit, range_step)
    opts['maximum_iterations'] = math_ops.cast(range_len, dtypes.int32)

    results = _tf_while_stmt(
        _keep_going,
        _advance,
        get_state,
        set_state,
        (range_start, ) + init_vars,
        ('<internal iterate>', ) + basic_symbol_names,
        composite_symbol_names,
        opts,
    )

    # The index itself is not handed back to the caller. If a symbol with the
    # same name exists outside the loop, it is captured among the loop
    # variables and updated through them.
    if not isinstance(results, (tuple, list)):
        return ()

    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
        return results[1:]
    # NOTE(review): a single-element result is returned as-is, so it still
    # contains the index - confirm callers rely on this.
    return results
Beispiel #5
0
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       init_vars, basic_symbol_names, composite_symbol_names,
                       opts):
    """Overload of for_stmt that iterates over a TF range (and elides it).

    The range bounds are read from the inputs of the op that produced `iter_`,
    and the loop index is threaded through the loop as the first loop
    variable.

    Args:
      iter_: Value whose `op.inputs` unpack to `(start, limit, delta)`.
      extra_test: Optional callable taking the loop variables (without the
        index). When given, it is evaluated through `control_flow_ops.cond`
        only if the range test passes; otherwise the condition is False.
      body: Callable taking the index followed by the loop variables and
        returning the updated loop variables (may be empty or falsy).
      get_state: Forwarded unchanged to `_tf_while_stmt`.
      set_state: Forwarded unchanged to `_tf_while_stmt`.
      init_vars: Tuple of initial loop variable values; checked with
        `_disallow_undefs_into_loop`.
      basic_symbol_names: Tuple of loop variable names.
      composite_symbol_names: Forwarded unchanged to `_tf_while_stmt`.
      opts: Dict of loop options. When the default graph is in an XLA
        context, its 'maximum_iterations' entry is set to the range length,
        cast to int32.

    Returns:
      The loop results with the leading index dropped when `_tf_while_stmt`
      returns a tuple or list with more than one element; that value
      unchanged when it has exactly one element; `()` for any other result.
    """
    _disallow_undefs_into_loop(*init_vars)

    start, limit, delta = iter_.op.inputs

    def while_body(iterate, *loop_vars):
        new_vars = body(iterate, *loop_vars)
        loop_vars = (iterate + delta, )

        if new_vars:
            loop_vars += new_vars

        return loop_vars

    def while_cond(iterate, *loop_vars):
        """Cond function for `tf.while_loop`."""
        def build_main_test():
            """Main iteration condition."""
            # TODO(b/138857806): The optimizer should handle this.
            # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
            # compile time constant.
            delta_const = tensor_util.constant_value(delta)
            if delta_const is not None:
                # Support single element arrays. `ndarray.item()` is what
                # `np.asscalar` delegated to; `np.asscalar` itself was
                # deprecated in NumPy 1.16 and removed in NumPy 1.23.
                delta_const = delta_const.item()
                if delta_const >= 0:
                    return iterate < limit
                else:
                    return iterate > limit
            else:
                # The sign of `delta` is only known at run time, so test both
                # directions.
                return math_ops.logical_or(
                    math_ops.logical_and(delta >= 0, iterate < limit),
                    math_ops.logical_and(delta < 0, iterate > limit))

        main_test = build_main_test()
        if extra_test is not None:
            return control_flow_ops.cond(
                main_test,
                lambda: extra_test(*loop_vars),
                lambda: False,
            )
        return main_test

    # TODO(b/134181679): The op should handle this optimizations.
    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
        # This specific dtype is required by while_loop.
        opts['maximum_iterations'] = math_ops.cast(
            misc.get_range_len(start, limit, delta), dtypes.int32)

    results = _tf_while_stmt(
        while_cond,
        while_body,
        get_state,
        set_state,
        (start, ) + init_vars,
        ('<internal iterate>', ) + basic_symbol_names,
        composite_symbol_names,
        opts,
    )

    # Note: the iteration index is not returned by the while loop, however
    # if a symbol with the same name exists outside the loop, it will be captured
    # by the loop variables and ultimately updated correctly.
    if isinstance(results, (tuple, list)):
        assert len(results) >= 1  # Has at least the iterate.
        if len(results) > 1:
            results = results[1:]
    else:
        results = ()

    return results