def slice_tensor_array(array, start, end):
    """Slice ``array`` along axis 0 over ``[start, end)``, via a ``cond`` switch.

    When ``start == end`` the result is a freshly created ``"float32"`` array
    from ``create_array`` (presumably an empty tensor array — confirm against
    ``create_array``). Otherwise it is ``slice(array, starts=[start],
    ends=[end], axes=[0])``.

    Args:
        array: The array to slice (a Paddle tensor array — assumed, confirm).
        start: Start index along axis 0.
        end: End index along axis 0 (exclusive, per ``slice`` semantics).

    Returns:
        Whatever ``cond`` yields from the selected branch.
    """

    def _empty_branch():
        # Degenerate range: build a new float32 array instead of slicing.
        return create_array("float32")

    def _slice_branch():
        # NOTE: `slice` here is the framework op in scope, not the builtin.
        return slice(array, starts=[start], ends=[end], axes=[0])

    return cond(start == end, _empty_branch, _slice_branch)
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                     return_vars):
    """Run ``true_fn``/``false_fn`` through Paddle's ``control_flow.cond``.

    Every argument in ``true_args``/``false_args`` that is also one of
    ``return_vars`` is first converted with ``to_static_variable``. The
    predicate is normalised with ``cast_bool_if_necessary`` before dispatch.

    NOTE(review): another ``def _run_paddle_cond`` appears later in this
    module (without the ``return_vars`` conversion). At import time the later
    one shadows this one — confirm which definition is intended.

    Args:
        pred: Branch predicate.
        true_fn: Callable run when ``pred`` holds; called with ``*true_args``.
        false_fn: Callable run otherwise; called with ``*false_args``.
        true_args: Positional arguments for ``true_fn``.
        false_args: Positional arguments for ``false_fn``.
        return_vars: Variables the branches return.

    Returns:
        The result of ``control_flow.cond``.
    """
    # NOTE 1: Paddle's `control_flow.cond` requires returned vars to be Paddle
    # Tensors, so any argument that is also a returned var gets promoted.
    # NOTE 2: Membership is tested by id(var) rather than `var in return_vars`,
    # because `in` uses `==`, which calls `fluid.layers.equal` and fails when
    # a var in return_vars is not initialized.
    returned_ids = {id(var) for var in return_vars}

    def _promote(args):
        # Convert only the arguments that are also returned vars.
        return [
            to_static_variable(var) if id(var) in returned_ids else var
            for var in args
        ]

    promoted_true = _promote(true_args)
    promoted_false = _promote(false_args)
    bool_pred = cast_bool_if_necessary(pred)
    return control_flow.cond(
        bool_pred,
        lambda: true_fn(*promoted_true),
        lambda: false_fn(*promoted_false),
    )
def convert_while_loop(cond, body, loop_vars):
    """
    A function representation of a Python ``while`` statement.

    Args:
        cond(Callable): A callable that takes ``loop_vars`` as arguments and
            returns a boolean value deciding whether to run the loop body.
        body(Callable): A callable that takes the same arguments as ``cond``
            (``loop_vars``) and returns a tuple or list of variables.
        loop_vars(list|tuple): A list or tuple of variables passed to
            ``cond`` and ``body``.

    Returns:
        A list or tuple of variables returned by ``body``.
    """
    # NOTE: cond is evaluated once up front only to choose the backend. This
    # may be slower if cond is expensive, but cond is usually O(1). If cond
    # mutated loop_vars this would be a bug; the current logical_and /
    # logical_not / ... helpers do not.
    probe = cond(*loop_vars)
    if isinstance(probe, Variable):
        # Static-graph predicate: build a Paddle while loop.
        runner = _run_paddle_while_loop
    else:
        # Plain Python predicate: run an ordinary Python loop.
        runner = _run_py_while
    return runner(cond, body, loop_vars)
def _run_py_while(cond, body, loop_vars): while cond(*loop_vars): loop_vars = body(*loop_vars) return loop_vars
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
                     return_vars):
    """Dispatch ``true_fn``/``false_fn`` through Paddle's ``control_flow.cond``.

    The predicate is normalised with ``cast_bool_if_necessary``. Each branch
    is called with its own positional arguments unchanged.

    NOTE(review): ``return_vars`` is accepted but not used here. This module
    also contains an earlier ``def _run_paddle_cond`` that does use it (to
    promote returned vars via ``to_static_variable``). This later definition
    shadows that one at import time — confirm which is intended.

    Returns:
        The result of ``control_flow.cond``.
    """
    bool_pred = cast_bool_if_necessary(pred)

    def _true_branch():
        return true_fn(*true_args)

    def _false_branch():
        return false_fn(*false_args)

    return control_flow.cond(bool_pred, _true_branch, _false_branch)