def test_commingle_loop_vars():
    """Verify commingle_loop_vars renames while_loop block inputs to match its outputs."""
    def loop_body(a, b):
        # b is a loop invariant
        return mb.add(x=a, y=b), b

    def loop_cond(a, b):
        mean_a = mb.reduce_mean(x=a, axes=[0, 1])
        mean_b = mb.reduce_mean(x=b, axes=[0, 1])
        return mb.less(x=mean_a, y=mean_b)

    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(a, b):
        return mb.while_loop(_cond=loop_cond, _body=loop_body, loop_vars=(a, b))

    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    # Before the pass the block inputs carry their own derived names.
    assert while_op.blocks[0].inputs[0].name == "a.x"
    assert while_op.blocks[0].inputs[1].name == "b.x"

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["nn_backend::commingle_loop_vars"](prog)
    assert_same_output_names(baseline, prog)

    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    # After the pass each block input shares its name with the matching output.
    assert while_op.blocks[0].inputs[0].name == while_op.outputs[0].name
    assert while_op.blocks[0].inputs[1].name == while_op.outputs[1].name

    prog.validate()
def test_remove_vacuous_cond():
    """A cond whose branches are both a single identity should collapse to identity."""
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, ), dtype=types.bool),
        mb.TensorSpec(shape=(2, 3)),
    ])
    def prog(a, b):
        def then_branch():
            return mb.identity(x=b)

        def else_branch():
            return mb.identity(x=b)

        pred = mb.squeeze(x=a)
        return mb.cond(pred=pred, _true_fn=then_branch, _false_fn=else_branch)

    cond_op = prog.find_ops(op_type="cond", exactly_one=True)[0]
    original_cond_op_name = cond_op.name
    # Each branch block holds exactly one op, and it is an identity.
    assert len(cond_op.blocks[0].operations) == 1
    assert cond_op.blocks[0].operations[0].op_type == "identity"
    assert len(cond_op.blocks[1].operations) == 1
    assert cond_op.blocks[1].operations[0].op_type == "identity"

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["tensorflow2::remove_vacuous_cond"](prog)
    assert_same_output_names(baseline, prog)

    # The cond op is gone; an identity carrying its name prefix remains.
    assert len(prog.find_ops(op_type="cond")) == 0
    identity_op = prog.find_ops(prefix=original_cond_op_name, exactly_one=True)[0]
    assert identity_op.op_type == "identity"

    if validate_model:
        assert_model_is_valid(prog, {"a": (1, ), "b": (2, 3)})
def test_const_elimination():
    """const_elimination should fold constant-only computations into const ops."""
    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        const_val = np.random.rand(2, 4).astype(np.float32)
        # add of two constants -- foldable by the pass
        doubled = mb.add(x=const_val, y=const_val)
        return mb.add(x=x, y=doubled)

    assert_op_count_match(prog, expect=2, op="const")

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["common::const_elimination"](prog)
    assert_same_output_names(baseline, prog)

    # The folded add materializes one additional const.
    assert_op_count_match(prog, expect=3, op="const")

    if validate_model:
        assert_model_is_valid(prog, {"x": (2, 4)})
def test_loop_invariant_elimination2():
    """
    Invariant pattern: Block outputs var from outside of the block

    Fix: replaced the ``if …: raise AssertionError`` checks with plain
    ``assert`` statements — pytest rewrites asserts to produce useful failure
    messages, and this matches the style of test_loop_invariant_elimination1.
    """
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(a, b):
        def body(a, bx):
            # Returns the outer `b`, not the block input `bx`, which makes
            # the second loop variable invariant.
            return mb.add(x=a, y=b), b

        def cond(a, bx):
            a_mean = mb.reduce_mean(x=a, axes=[0, 1])
            b_mean = mb.reduce_mean(x=bx, axes=[0, 1])
            return mb.less(x=a_mean, y=b_mean)

        # b is loop invariant
        return mb.while_loop(_cond=cond, _body=body, loop_vars=(a, b))

    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    assert len(while_op.blocks[0].inputs) == 2
    assert len(while_op.outputs) == 2
    assert len(while_op.loop_vars) == 2
    assert while_op.blocks[0].inputs[0].name == "a.x"
    assert while_op.blocks[0].inputs[1].name == "b.x"

    prev_prog = copy.deepcopy(prog)
    PASS_REGISTRY["common::loop_invariant_elimination"](prog)
    assert_same_output_names(prev_prog, prog)

    # The invariant var is dropped from the loop's inputs/outputs.
    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    assert len(while_op.blocks[0].inputs) == 1
    assert len(while_op.outputs) == 1
    assert len(while_op.loop_vars) == 1
    assert while_op.blocks[0].inputs[0].name == "a.x"

    if validate_model:
        assert_model_is_valid(prog, {"a": (1, 2), "b": (1, 2)})
def test_handle_unused_inputs():
    """The pass must attach an identity op to any input with no consumers."""
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(unused_input):
        # Deliberately never touches `unused_input`.
        return mb.const(val=[3, 2])

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["nn_backend::handle_unused_inputs"](prog)
    assert_same_output_names(baseline, prog)

    # Assert that input var is consumed by an identity op.
    id_op = prog.find_ops(op_type="identity", exactly_one=True)[0]
    assert id_op in prog["main"].inputs["unused_input"].child_ops

    assert_model_is_valid(prog, {"unused_input": (1, 2)})
def test_backfill_make_list_elem_type():
    """
    Verify the pass replaces tf_make_list (no element type) with make_list
    carrying a concrete elem_type.

    Fix: replaced the ``if …: raise AssertionError`` checks with plain
    ``assert`` statements for pytest's assertion rewriting, matching the
    style used by the other tests in this file.
    """
    # The while_loop appends [1, 2]*i to `ls` for each iteration
    # i = 0, ... num_iters-1.
    elem_shape = (2, )

    @mb.program(input_specs=[
        mb.TensorSpec(shape=elem_shape),
    ])
    def prog(update):
        def body(i, ls):
            return mb.add(x=i, y=1), mb.list_write(ls=ls, index=i, value=update)

        def cond(i, ls):
            return mb.less(x=i, y=num_iters)

        i = 0
        ls = mb.tf_make_list(init_length=1)
        num_iters = 3
        _, final_tensor_list = mb.while_loop(_cond=cond, _body=body,
                                             loop_vars=(i, ls))
        list_len = mb.list_length(ls=final_tensor_list)
        indices = mb.range_1d(start=0, end=list_len, step=1)
        return mb.list_gather(ls=final_tensor_list, indices=indices)

    # tf_make_list has no elem_type info
    make_list_op = prog.find_ops(op_type="tf_make_list", exactly_one=True)[0]
    assert make_list_op.outputs[0].elem_type == types.unknown

    prev_prog = copy.deepcopy(prog)
    PASS_REGISTRY["tensorflow::backfill_make_list_elem_type"](prog)
    assert_same_output_names(prev_prog, prog)
    prog.validate()

    # tf_make_list is replaced with make_list and should have elem_type now
    make_list_op = prog.find_ops(op_type="make_list", exactly_one=True)[0]
    assert make_list_op.outputs[0].elem_type.get_shape() == elem_shape

    assert_model_is_valid(prog, {"update": elem_shape})
def test_handle_return_return_inputs_as_outputs():
    """
    Pass-through input `b` returned as an output must get an identity op.

    Fix: removed the unused local ``prev_main_output_names``.
    NOTE(review): the doubled word in this test's name looks like a typo of
    ``test_handle_return_inputs_as_outputs`` (which also exists in this
    file); kept as-is so the collected test id does not change.
    """
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(a, b):
        return mb.mul(x=a, y=2), b

    assert prog["main"].outputs[1].op is None  # output comes from input

    prev_prog = copy.deepcopy(prog)
    PASS_REGISTRY["nn_backend::handle_return_inputs_as_outputs"](prog)
    assert_same_output_names(prev_prog, prog)

    assert prog["main"].outputs[1].op is not None  # output comes from an op
    assert prog["main"].outputs[1].op.op_type == "identity"

    assert_model_is_valid(prog, {"a": (1, 2), "b": (1, 2)})
def test_dead_code_elimination():
    """DCE must drop ops whose results never reach a program output."""
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(2, 4)),
        mb.TensorSpec(shape=(2, 4)),
    ])
    def program0(x, y):
        # following three unused op should be eliminated
        a = mb.const(val=np.zeros(shape=(1, )), mode="immediate_value")
        b = mb.const(val=np.zeros(shape=(1, )), mode="immediate_value")
        _ = mb.add(x=a, y=b)
        return mb.add(x=x, y=y)

    assert_op_count_match(program0, expect=4)
    snapshot0 = copy.deepcopy(program0)
    PASS_REGISTRY["common::dead_code_elimination"](program0)
    assert_same_output_names(snapshot0, program0)
    # Only the live add(x, y) survives.
    assert_op_count_match(program0, expect=1)

    if validate_model:
        assert_model_is_valid(program0, {"x": (2, 4), "y": (2, 4)})

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def program1(x):
        weights_val = np.random.rand(2, 4).T.astype(np.float32)
        weights = mb.const(val=weights_val, mode="immediate_value")
        bias_val = np.random.rand(4).astype(np.float32)
        bias = mb.const(val=bias_val, mode="immediate_value")

        # unused op and its inputs should be eliminated
        mb.matmul(x=x, y=weights)

        return mb.linear(x=x, weight=weights, bias=bias)

    assert_op_count_match(program1, expect=6)
    snapshot1 = copy.deepcopy(program1)
    PASS_REGISTRY["common::dead_code_elimination"](program1)
    assert_same_output_names(snapshot1, program1)
    # The dangling matmul and its private consts are removed.
    assert_op_count_match(program1, expect=3)

    if validate_model:
        assert_model_is_valid(program1, {"x": (2, 4)})
def test_fuse_matmul_weight_bias():
    """matmul followed by a const-bias add should fuse into a single linear op."""
    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        weights_val = np.random.rand(2, 4).T.astype(np.float32)
        weights = mb.const(val=weights_val, mode="immediate_value")
        bias_val = np.random.rand(2).astype(np.float32)
        bias = mb.const(val=bias_val, mode="immediate_value")
        matmul = mb.matmul(x=x, y=weights)
        return mb.add(x=matmul, y=bias)

    # Before the pass: one matmul, no linear.
    assert_op_count_match(prog, expect=1, op="matmul")
    assert_op_count_match(prog, expect=0, op="linear")

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["common::fuse_matmul_weight_bias"](prog)
    assert_same_output_names(baseline, prog)

    # After the pass: matmul+add replaced by linear.
    assert_op_count_match(prog, expect=0, op="matmul")
    assert_op_count_match(prog, expect=1, op="linear")

    if validate_model:
        assert_model_is_valid(prog, {"x": (2, 4)})
def test_loop_invariant_elimination1():
    """
    Invariant pattern: Block input vars are returned as block output vars.
    """
    def loop_body(a, b):
        return mb.add(x=a, y=b), b

    def loop_cond(a, b):
        mean_a = mb.reduce_mean(x=a, axes=[0, 1])
        mean_b = mb.reduce_mean(x=b, axes=[0, 1])
        return mb.less(x=mean_a, y=mean_b)

    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(a, b):
        # b is loop invariant
        return mb.while_loop(_cond=loop_cond, _body=loop_body, loop_vars=(a, b))

    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    # Both loop vars are still present before the pass runs.
    assert len(while_op.blocks[0].inputs) == 2
    assert len(while_op.outputs) == 2
    assert len(while_op.loop_vars) == 2
    assert while_op.blocks[0].inputs[0].name == "a.x"
    assert while_op.blocks[0].inputs[1].name == "b.x"

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["common::loop_invariant_elimination"](prog)
    assert_same_output_names(baseline, prog)

    while_op = prog.find_ops(op_type="while_loop", exactly_one=True)[0]
    # The invariant var `b` has been hoisted out of the loop.
    assert len(while_op.blocks[0].inputs) == 1
    assert len(while_op.outputs) == 1
    assert len(while_op.loop_vars) == 1
    assert while_op.blocks[0].inputs[0].name == "a.x"

    if validate_model:
        assert_model_is_valid(prog, {"a": (1, 2), "b": (1, 2)})
def test_divide_to_multiply():
    """Every real_div by a constant should be rewritten as a mul by its reciprocal."""
    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        denom_a = np.random.rand(2, 4).astype(np.float32)
        const_a = mb.const(val=denom_a, mode="immediate_value")
        denom_b = np.random.rand(2, 4).astype(np.float32)
        const_b = mb.const(val=denom_b, mode="immediate_value")
        quotient = mb.real_div(x=x, y=const_a)
        return mb.real_div(x=quotient, y=const_b)

    # Two divisions, no multiplications before the pass.
    assert_op_count_match(prog, expect=2, op="real_div")
    assert_op_count_match(prog, expect=0, op="mul")

    baseline = copy.deepcopy(prog)
    PASS_REGISTRY["common::divide_to_multiply"](prog)
    assert_same_output_names(baseline, prog)

    # Both divisions became multiplications.
    assert_op_count_match(prog, expect=0, op="real_div")
    assert_op_count_match(prog, expect=2, op="mul")

    if validate_model:
        assert_model_is_valid(prog, {"x": (2, 4)})
def test_handle_return_inputs_as_outputs():
    """
    Pass-through input `b` returned as an output must get an identity op.

    Fixes:
    - The ``pytest.raises`` context previously contained only ``pass``, so it
      always failed with "DID NOT RAISE", while the ``assert_model_is_valid``
      call that should trigger the ValueError sat *after* the block and let
      the exception escape. The call now lives inside the context manager.
    - Removed the unused local ``prev_main_output_names``.
    """
    @mb.program(input_specs=[
        mb.TensorSpec(shape=(1, 2)),
        mb.TensorSpec(shape=(1, 2)),
    ])
    def prog(a, b):
        return mb.mul(x=a, y=2), b

    assert prog["main"].outputs[1].op is None  # output comes from input

    prev_prog = copy.deepcopy(prog)
    PASS_REGISTRY["nn_backend::handle_return_inputs_as_outputs"](prog)
    assert_same_output_names(prev_prog, prog)

    assert prog["main"].outputs[1].op is not None  # output comes from an op
    assert prog["main"].outputs[1].op.op_type == "identity"

    with pytest.raises(ValueError,
                       match='used both as function\'s input and output'):
        # prog has input and output names 'b' that refer to different vars
        # This program can pass if we disable 'dedup_op_and_var_names'
        assert_model_is_valid(prog, {"a": (1, 2), "b": (1, 2)})
def test_op_removal_and_insertion():
    """
    Remove a transpose pair and materialize one transpose before another op

    Given:
        %x1 = transpose(%x)
        %x2 = relu(%x1)
        %out1 = avg_pool(%x2)
        %x3 = transpose(%x2)
        %out2 = log(%x3)

    After removing both transposes:
        %x2 = relu(%x)
        %out1 = avg_pool(%x2)
        %out2 = log(%x2)

    After inserting a transpose:
        %x2 = relu(%x)
        %x4 = transpose(%x2)
        %out1 = avg_pool(%x4)
        %out2 = log(%x2)
    """
    @mb.program(input_specs=[mb.TensorSpec(shape=(1, 2, 6, 6))])
    def prog(x):
        x1 = mb.transpose(x=x, perm=[0, 2, 3, 1])
        x2 = mb.relu(x=x1)
        out1 = mb.avg_pool(x=x2, kernel_sizes=[1, 1], strides=[1, 1], pad_type="valid")
        x3 = mb.transpose(x=x2, perm=[0, 3, 1, 2])
        out2 = mb.log(x=x3)
        return out1, out2

    prev_prog = copy.deepcopy(prog)
    print("before:\n{}".format(prog))
    assert get_op_types_in_program(prog) == [
        "transpose",
        "relu",
        "avg_pool",
        "transpose",
        "log",
    ]

    block = prog.functions["main"]

    def remove_transpose(block):
        # Drop the first transpose found: reroute all uses of its output back
        # to its input, then delete the now-dangling op.
        # no_check_var_types=True because the replacement intentionally changes
        # the var's shape (the transpose's effect is being undone).
        op = block.find_ops(op_type="transpose")[0]
        block.replace_uses_of_var_after_op(
            anchor_op=op.inputs["x"].op,
            old_var=op.outputs[0],
            new_var=op.inputs["x"],
            no_check_var_types=True,
        )
        block.remove_ops([op])

    # remove 1st transpose
    remove_transpose(block)
    assert get_op_types_in_program(prog) == ["relu", "avg_pool", "transpose", "log"]

    # remove 2nd transpose
    remove_transpose(block)
    assert get_op_types_in_program(prog) == ["relu", "avg_pool", "log"]

    print("after transpose ops removal:\n{}".format(prog))

    # insert transpose before pool
    pool_op = block.find_ops(op_type="avg_pool")[0]
    with block:
        y = mb.transpose(x=pool_op.inputs["x"], perm=[0, 2, 3, 1], before_op=pool_op)

    # end_op=pool_op limits the rerouting so only avg_pool consumes the new
    # transpose; the later log op keeps reading the original var.
    block.replace_uses_of_var_after_op(
        anchor_op=y.op,
        end_op=pool_op,
        old_var=pool_op.inputs["x"],
        new_var=y,
        no_check_var_types=True,
    )

    print("after transpose insertion:\n{}".format(prog))
    assert get_op_types_in_program(prog) == ["relu", "transpose", "avg_pool", "log"]

    # Re-run type/value inference since var shapes changed during the surgery
    # above (the edits bypassed normal type checking).
    for op in block.operations:
        op.type_value_inference(overwrite_output=True)

    assert_same_output_names(prev_prog, prog)
    assert_same_output_shapes(prev_prog, prog)