def maybe_adjust_terminator(self):
  """May change the last block's terminator instruction to a return.

  If the terminator meant "exit this control flow graph", change it to
  "return from this function".

  Raises:
    ValueError: If the terminator was a `BranchOp` that directly exited,
      because there is no "conditional indirect goto" instruction in the
      target IR.
  """
  op = self.cur_block().terminator
  if inst.is_return_op(op):
    self.cur_block().terminator = inst.IndirectGotoOp()
  if (isinstance(op, inst.BranchOp)
      and (op.true_block is None or op.false_block is None)):
    # Why not?  Because there is no "conditional indirect goto" instruction
    # in the target IR.  One solution is to canonicalize away directly
    # exiting branches, by replacing them with a branch to a fresh empty
    # block that just exits.
    raise ValueError('Cannot lower exiting BranchOp {}.'.format(op))
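# NOTE: Illustrative sketch only, not part of the original module.  It shows
# the canonicalization suggested in the comment above: rather than emitting a
# `BranchOp` whose false arm exits the graph directly, branch to a fresh block
# whose only job is to leave the function.  The helper name and the direct use
# of `IndirectGotoOp` (i.e., the terminator `maybe_adjust_terminator` would
# otherwise have produced from an exit) are assumptions for illustration.
def _branch_without_direct_exit(cond_var, true_block):
  """Builds a `BranchOp` whose false arm goes through a fresh returning block.

  Returns the branch op together with the fresh block, so the caller can add
  the block to its control flow graph.
  """
  exit_block = inst.Block()
  # The fresh block does nothing but return.
  exit_block.assign_instructions([inst.IndirectGotoOp()])
  return inst.BranchOp(cond_var, true_block, exit_block), exit_block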
def pea_nuts_program(latent_shape, choose_depth, step_state):
  """Synthetic program usable for benchmarking VM performance.

  This program is intended to resemble the control flow and scaling
  parameters of the NUTS algorithm, without any of the complexity.
  Hence the name.

  Each batch member looks like:

    state = ...  # shape latent_shape

    def recur(depth, state):
      if depth > 0:
        state1 = recur(depth - 1, state)
        state2 = state1 + 1
        state3 = recur(depth - 1, state2)
        ans = state3 + 1
      else:
        ans = step_state(state)  # To simulate NUTS, something heavy
      return ans

    while count > 0:
      count = count - 1
      depth = choose_depth(count)
      state = recur(depth, state)

  Args:
    latent_shape: Python `tuple` of `int` giving the event shape of the
      latent state.
    choose_depth: Python `Tensor -> Tensor` callable.  The input `Tensor`
      will have shape `[batch_size]` (i.e., scalar event shape), and gives
      the iteration of the outer while loop each thread is in.  The
      `choose_depth` function must return a `Tensor` of shape `[batch_size]`
      giving the depth, for each thread, to which to call `recur` in this
      iteration.
    step_state: Python `Tensor -> Tensor` callable.  The input and output
      `Tensor`s will have shape `[batch_size] + latent_shape`.  This
      function is expected to update the state, and represents the "real
      work" against which the VM overhead is being measured.

  Returns:
    program: `instructions.Program` that runs the above benchmark.
  """
  entry = instructions.Block()
  top_body = instructions.Block()
  finish_body = instructions.Block()
  enter_recur = instructions.Block()
  recur_body_1 = instructions.Block()
  recur_body_2 = instructions.Block()
  recur_body_3 = instructions.Block()
  recur_base_case = instructions.Block()
  # pylint: disable=bad-whitespace
  entry.assign_instructions([
      instructions.prim_op(
          ["count"], "cond", lambda count: count > 0),     # cond = count > 0
      instructions.BranchOp(
          "cond", top_body, instructions.halt()),          # if cond
  ])
  top_body.assign_instructions([
      instructions.PopOp(["cond"]),                        # done with cond now
      instructions.prim_op(
          ["count"], "ctm1", lambda count: count - 1),     # ctm1 = count - 1
      instructions.PopOp(["count"]),                       # done with count now
      instructions.push_op(["ctm1"], ["count"]),           # count = ctm1
      instructions.PopOp(["ctm1"]),                        # done with ctm1
      instructions.prim_op(
          ["count"], "depth", choose_depth),               # depth = choose_depth(count)
      instructions.push_op(
          ["depth", "state"], ["depth", "state"]),         # state = recur(depth, state)
      instructions.PopOp(["depth", "state"]),              # done with depth, state
      instructions.PushGotoOp(finish_body, enter_recur),
  ])
  finish_body.assign_instructions([
      instructions.push_op(["ans"], ["state"]),            # ...
      instructions.PopOp(["ans"]),                         # pop callee's "ans"
      instructions.GotoOp(entry),                          # end of while body
  ])
  # Definition of recur begins here
  enter_recur.assign_instructions([
      instructions.prim_op(
          ["depth"], "cond1", lambda depth: depth > 0),    # cond1 = depth > 0
      instructions.BranchOp(
          "cond1", recur_body_1, recur_base_case),         # if cond1
  ])
  recur_body_1.assign_instructions([
      instructions.PopOp(["cond1"]),                       # done with cond1 now
      instructions.prim_op(
          ["depth"], "dm1", lambda depth: depth - 1),      # dm1 = depth - 1
      instructions.PopOp(["depth"]),                       # done with depth
      instructions.push_op(
          ["dm1", "state"], ["depth", "state"]),           # state1 = recur(dm1, state)
      instructions.PopOp(["state"]),                       # done with state
      instructions.PushGotoOp(recur_body_2, enter_recur),
  ])
  recur_body_2.assign_instructions([
      instructions.push_op(["ans"], ["state1"]),           # ...
      instructions.PopOp(["ans"]),                         # pop callee's "ans"
      instructions.prim_op(
          ["state1"], "state2", lambda state: state + 1),  # state2 = state1 + 1
      instructions.PopOp(["state1"]),                      # done with state1
      instructions.push_op(
          ["dm1", "state2"], ["depth", "state"]),          # state3 = recur(dm1, state2)
      instructions.PopOp(["dm1", "state2"]),               # done with dm1, state2
      instructions.PushGotoOp(recur_body_3, enter_recur),
  ])
  recur_body_3.assign_instructions([
      instructions.push_op(["ans"], ["state3"]),           # ...
      instructions.PopOp(["ans"]),                         # pop callee's "ans"
      instructions.prim_op(
          ["state3"], "ans", lambda state: state + 1),     # ans = state3 + 1
      instructions.PopOp(["state3"]),                      # done with state3
      instructions.IndirectGotoOp(),                       # return ans
  ])
  recur_base_case.assign_instructions([
      instructions.PopOp(["cond1", "depth"]),              # done with cond1, depth
      instructions.prim_op(["state"], "ans", step_state),  # ans = step_state(state)
      instructions.PopOp(["state"]),                       # done with state
      instructions.IndirectGotoOp(),                       # return ans
  ])
  pea_nuts_graph = instructions.ControlFlowGraph([
      entry,
      top_body,
      finish_body,
      enter_recur,
      recur_body_1,
      recur_body_2,
      recur_body_3,
      recur_base_case,
  ])
  # pylint: disable=bad-whitespace
  pea_nuts_vars = {
      "count": instructions.single_type(np.int64, ()),
      "cond": instructions.single_type(np.bool_, ()),
      "cond1": instructions.single_type(np.bool_, ()),
      "ctm1": instructions.single_type(np.int64, ()),
      "depth": instructions.single_type(np.int64, ()),
      "dm1": instructions.single_type(np.int64, ()),
      "state": instructions.single_type(np.float32, latent_shape),
      "state1": instructions.single_type(np.float32, latent_shape),
      "state2": instructions.single_type(np.float32, latent_shape),
      "state3": instructions.single_type(np.float32, latent_shape),
      "ans": instructions.single_type(np.float32, latent_shape),
  }
  return instructions.Program(
      pea_nuts_graph, [], pea_nuts_vars, ["count", "state"], "state")
def fibonacci_program():
  """More complicated program: computes fib(n), with fib(0) = fib(1) = 1.

  Returns:
    program: Full-powered `instructions.Program` that computes fib(n).
  """
  entry = instructions.Block(name="entry")
  enter_fib = instructions.Block(name="enter_fib")
  recur1 = instructions.Block(name="recur1")
  recur2 = instructions.Block(name="recur2")
  recur3 = instructions.Block(name="recur3")
  finish = instructions.Block(name="finish")
  # pylint: disable=bad-whitespace
  entry.assign_instructions([
      instructions.PushGotoOp(instructions.halt(), enter_fib),
  ])
  # Definition of the fibonacci function starts here
  enter_fib.assign_instructions([
      instructions.prim_op(["n"], "cond", lambda n: n > 1),  # cond = n > 1
      instructions.BranchOp("cond", recur1, finish),         # if cond
  ])
  recur1.assign_instructions([
      instructions.PopOp(["cond"]),                          # done with cond now
      instructions.prim_op(["n"], "nm1", lambda n: n - 1),   # nm1 = n - 1
      instructions.push_op(["nm1"], ["n"]),                  # fibm1 = fibonacci(nm1)
      instructions.PopOp(["nm1"]),                           # done with nm1
      instructions.PushGotoOp(recur2, enter_fib),
  ])
  recur2.assign_instructions([
      instructions.push_op(["ans"], ["fibm1"]),              # ...
      instructions.PopOp(["ans"]),                           # pop callee's "ans"
      instructions.prim_op(["n"], "nm2", lambda n: n - 2),   # nm2 = n - 2
      instructions.PopOp(["n"]),                             # done with n
      instructions.push_op(["nm2"], ["n"]),                  # fibm2 = fibonacci(nm2)
      instructions.PopOp(["nm2"]),                           # done with nm2
      instructions.PushGotoOp(recur3, enter_fib),
  ])
  recur3.assign_instructions([
      instructions.push_op(["ans"], ["fibm2"]),              # ...
      instructions.PopOp(["ans"]),                           # pop callee's "ans"
      instructions.prim_op(
          ["fibm1", "fibm2"], "ans", lambda x, y: x + y),    # ans = fibm1 + fibm2
      instructions.PopOp(["fibm1", "fibm2"]),                # done with fibm1, fibm2
      instructions.IndirectGotoOp(),                         # return ans
  ])
  finish.assign_instructions([                               # else:
      instructions.PopOp(["n", "cond"]),                     # done with n, cond
      instructions.prim_op([], "ans", lambda: 1),            # ans = 1
      instructions.IndirectGotoOp(),                         # return ans
  ])
  fibonacci_blocks = [entry, enter_fib, recur1, recur2, recur3, finish]
  # pylint: disable=bad-whitespace
  fibonacci_vars = {
      "n": instructions.single_type(np.int64, ()),
      "cond": instructions.single_type(np.bool_, ()),
      "nm1": instructions.single_type(np.int64, ()),
      "fibm1": instructions.single_type(np.int64, ()),
      "nm2": instructions.single_type(np.int64, ()),
      "fibm2": instructions.single_type(np.int64, ()),
      "ans": instructions.single_type(np.int64, ()),
  }
  return instructions.Program(
      instructions.ControlFlowGraph(fibonacci_blocks), [], fibonacci_vars,
      ["n"], "ans")
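# Plain-Python reference for the program above, added here for illustration
# and not part of the original module: the block structure in
# `fibonacci_program` renders this recursion with explicit push/pop and goto
# operations, so the two should agree on every input.
def _fibonacci_reference(n):
  """Same function as `fibonacci_program`: fib(0) = fib(1) = 1."""
  if n > 1:
    return _fibonacci_reference(n - 1) + _fibonacci_reference(n - 2)
  return 1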