예제 #1
0
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                     return_vars):
    """Run an if/else as a Paddle `control_flow.cond` op.

    Args:
        pred: The branch condition; passed through `cast_bool_if_necessary`
            before being handed to `control_flow.cond`.
        true_fn: Callable invoked as `true_fn(*true_args)` for the true branch.
        false_fn: Callable invoked as `false_fn(*false_args)` for the false
            branch.
        true_args: Positional arguments for `true_fn`.
        false_args: Positional arguments for `false_fn`.
        return_vars: Variables the branches return. Any branch argument that
            is one of these objects (compared by identity) is converted with
            `to_static_variable`; other arguments are passed through as-is.

    Returns:
        Whatever `control_flow.cond` returns for the two branch callables.
    """
    # NOTE 1: Returned vars of Paddle op `control_flow.cond` must be Paddle Tensors.
    # NOTE 2: Membership is tested on id(var), not var, because `var in return_vars`
    #  uses operator `==`, which calls `fluid.layers.equal` and causes an error
    #  when a var in return_vars is not initialized. A set keeps each lookup O(1).
    return_var_ids = {id(var) for var in return_vars}

    def _to_static_if_returned(args):
        # Convert only the arguments that are also returned by the branches.
        return [
            to_static_variable(var) if id(var) in return_var_ids else var
            for var in args
        ]

    true_args = _to_static_if_returned(true_args)
    false_args = _to_static_if_returned(false_args)

    pred = cast_bool_if_necessary(pred)
    return control_flow.cond(pred, lambda: true_fn(*true_args),
                             lambda: false_fn(*false_args))
예제 #2
0
def _run_paddle_while_loop(cond, body, loop_vars):
    """Run a loop as a Paddle `control_flow.while_loop` op.

    Every entry of `loop_vars` is first converted with `to_static_variable`,
    because the loop variables of `control_flow.while_loop` must be Paddle
    Tensors.

    Returns:
        The loop variables returned by `control_flow.while_loop`.
    """
    static_vars = list(map(to_static_variable, loop_vars))
    return control_flow.while_loop(cond, body, static_vars)