示例#1
0
def for_stmt(target, iter_, body, orelse, local_writes):
    """Functional form of a ``for`` statement.

    If ``iter_`` is a tensor, the loop is staged through ``_tf_while_stmt``.
    The staged loop walks an integer index from 0 while it is below
    ``_tf_len(iter_)``. Any other iterable is delegated to
    ``py_defaults.for_stmt``.

    Args:
      target: holder whose ``val`` attribute receives each element of
        ``iter_``.
      iter_: the object being iterated.
      body: callable executing one loop iteration.
      orelse: the loop's ``else`` clause. It is only forwarded on the
        non-tensor path.
      local_writes: holders (objects with a ``val`` attribute) for variables
        the body writes. On the tensor path their final values are written
        back after the loop.
    """
    if not tf.is_tensor(iter_):
        py_defaults.for_stmt(target, iter_, body, orelse, local_writes)
        return

    # Variables whose value is still undefined are excluded from loop state.
    loop_vars = [
        holder for holder in local_writes
        if not py_defaults.is_undefined(holder.val)
    ]
    length = _tf_len(iter_)

    def loop_cond(i, *_):
        """Keeps looping while the index is below the iterable's length."""
        return i < length

    def loop_body(iterate_index, *state):
        """Loads state, binds the current element and runs the body once."""
        for holder, value in zip(loop_vars, state):
            holder.val = value
        target.val = iter_[iterate_index]
        new_values, _ = staging.execute_isolated(body, loop_vars)
        # The index is carried as the first loop-state element.
        return [iterate_index + 1] + new_values

    initial_state = [0]
    initial_state.extend(holder.val for holder in loop_vars)
    final_state = _tf_while_stmt(loop_cond, loop_body, initial_state)

    # Skip the index slot; the remaining entries map onto loop_vars.
    for holder, value in zip(loop_vars, final_state[1:]):
        holder.val = value
示例#2
0
def while_stmt(cond, body, orelse, local_writes):
    """Functional form of a ``while`` statement.

    ``cond`` is called once up front. If that result is a tensor, the loop is
    staged through ``_tf_while_stmt``. Otherwise it is handed to
    ``staging.run_python_while`` together with the already computed
    condition value.

    Args:
      cond: callable evaluating the loop condition.
      body: callable executing one loop iteration.
      orelse: the loop's ``else`` clause. It is only forwarded on the
        non-tensor path.
      local_writes: holders (objects with a ``val`` attribute) for variables
        the loop writes. On the tensor path their final values are written
        back after the loop.
    """
    first_cond = cond()

    if not tf.is_tensor(first_cond):
        staging.run_python_while(cond, body, orelse, first_cond)
        return

    # Variables whose value is still undefined are excluded from loop state.
    loop_vars = [
        holder for holder in local_writes
        if not py_defaults.is_undefined(holder.val)
    ]

    def restore(state):
        """Copies loop-state values back into the variable holders."""
        for holder, value in zip(loop_vars, state):
            holder.val = value

    def loop_cond(*state):
        """Evaluates ``cond`` against the given loop state."""
        restore(state)
        _, cond_values = staging.execute_isolated(cond, loop_vars)
        return cond_values

    def loop_body(*state):
        """Runs ``body`` once against the given loop state."""
        restore(state)
        new_values, _ = staging.execute_isolated(body, loop_vars)
        return new_values

    final_state = _tf_while_stmt(
        loop_cond, loop_body, [holder.val for holder in loop_vars])
    restore(final_state)
示例#3
0
def while_stmt(cond, body, _, local_writes):
    """Functional form of a ``while`` statement, staged with ``lax.while_loop``.

    The third positional argument (the ``else`` clause) is not used.

    Args:
      cond: callable evaluating the loop condition.
      body: callable executing one loop iteration.
      local_writes: holders (objects with a ``val`` attribute) for variables
        the loop writes. Their final values are written back after the loop.

    Returns:
      The final loop-state values returned by ``lax.while_loop``.
    """
    # Variables whose value is still undefined are excluded from loop state.
    carried = [
        holder for holder in local_writes
        if not py_defaults.is_undefined(holder.val)
    ]

    def load(state):
        """Copies loop-state values into the variable holders."""
        for holder, value in zip(carried, state):
            holder.val = value

    def loop_cond(state):
        """Evaluates ``cond`` against the given loop state."""
        load(state)
        _, cond_values = staging.execute_isolated(cond, carried)
        return cond_values

    def loop_body(state):
        """Runs ``body`` once and returns the updated loop state."""
        load(state)
        new_values, _ = staging.execute_isolated(body, carried)
        return new_values

    final_state = lax.while_loop(
        loop_cond, loop_body, [holder.val for holder in carried])
    load(final_state)
    return final_state
示例#4
0
    def protected_func():
        """Calls function and raises an error if undefined symbols are returned.

        Returns:
          The result of ``func()``, unchanged, when nothing in it is
          undefined.

        Raises:
          ValueError: if ``func()`` returns an undefined value, or a tuple
            containing undefined values. The message names the offending
            symbols and the branch (``branch_name``).
        """
        results = func()
        undefined_symbols = None
        if isinstance(results, tuple):
            # Multiple return values: collect the names of undefined ones.
            undefined_symbols = _filter_undefined(results)
        elif py_defaults.is_undefined(results):
            # Single return value
            undefined_symbols = results.symbol_name

        if undefined_symbols:
            # Fill both fields in one %-format pass. Formatting branch_name
            # with % and then calling str.format on the result would treat
            # any brace characters in branch_name as format fields.
            message = (
                'The following symbols must also be initialized in the %s '
                'branch: %s. Alternatively, you may initialize them before '
                'the if statement.') % (branch_name, undefined_symbols)
            raise ValueError(message)
        return results
示例#5
0
def for_stmt(target, iter_, body, orelse, modified_vars):
    """Functional form of a ``for`` statement, staged with ``lax.fori_loop``.

    The loop index runs from 0 to ``len(iter_)``. Each iteration binds
    ``iter_[index]`` to ``target.val``.

    Args:
      target: holder whose ``val`` attribute receives each element of
        ``iter_``.
      iter_: indexable object being iterated. ``len()`` must work on it.
      body: callable executing one loop iteration.
      orelse: the loop's ``else`` clause. It is discarded here.
      modified_vars: holders (objects with a ``val`` attribute) for variables
        the body writes. Their final values are written back after the loop.
    """
    del orelse

    # Variables whose value is still undefined are excluded from loop state.
    carried = [
        holder for holder in modified_vars
        if not py_defaults.is_undefined(holder.val)
    ]

    def loop_body(idx, state):
        """Loads state, binds the current element and runs the body once."""
        for holder, value in zip(carried, state):
            holder.val = value
        target.val = iter_[idx]
        new_values, _ = staging.execute_isolated(body, carried)
        return new_values

    trip_count = len(iter_)
    final_state = lax.fori_loop(
        0, trip_count, loop_body, [holder.val for holder in carried])
    for holder, value in zip(carried, final_state):
        holder.val = value
示例#6
0
def _filter_undefined(all_symbols):
    """Returns the names of undefined symbols contained in all_symbols.

    Args:
      all_symbols: iterable of values. Some may be undefined placeholders,
        as judged by ``py_defaults.is_undefined``.

    Returns:
      A list with the ``symbol_name`` of each undefined value, in input order.
    """
    # Undefined placeholders carry their name in ``symbol_name``. This is the
    # attribute the single-return-value path of ``protected_func`` reads.
    # The previous ``s.name`` disagreed with it.
    return [
        s.symbol_name for s in all_symbols if py_defaults.is_undefined(s)
    ]