def pattern_add(x):
    """
    Pattern template: a linear op whose output feeds an add with a const.

    Original:
        %4 = linear(x=%1, weight=%2, bias=%3)
        # %2 is a rank-2 const tensor (weight)
        # %3 is a rank-1 const tensor (bias)
        ...
        %6 = add(x=%4, y=%5)  # %5 is a const tensor with same shape as %3

    Result:
        %8 = linear(x=%1, weight=%2, bias=%7)
        # where %7 is a new const tensor with value %7 = %3 + %6
    """
    linear_out = mb.linear(
        x=x,
        weight=arbitrary_weight,
        bias=arbitrary_bias,
        name="linear",
    )
    return mb.add(x=linear_out, y=arbitrary_bias, name="add_or_sub")
def pattern_sub(x):
    """
    Pattern template: a linear op whose output is the second operand of a sub.

    Original:
        %4 = linear(x=%1, weight=%2, bias=%3)
        # %2 is a rank-2 const tensor (weight)
        # %3 is a rank-1 const tensor (bias)
        ...
        %6 = sub(x=%5, y=%4)
        # %5 is a const tensor with a shape broadcastable with %3,
        # i.e. if %3 has shape (Dout), %5 could be (1, Dout).

    Result:
        %9 = linear(x=%1, weight=%7, bias=%8)
        # where %7 is a new const tensor with value %7 = -%2
        # and %8 = %5 - %3
    """
    linear_out = mb.linear(
        x=x,
        weight=arbitrary_weight,
        bias=arbitrary_bias,
        name="linear",
    )
    return mb.sub(x=linear_out, y=arbitrary_bias, name="add_or_sub")
def transform_pattern(pattern):
    """Fuse a matched linear + add/sub pair into a single linear op.

    Folds the const operand of the add/sub into the linear op's bias (and,
    for ``const - linear(x)``, negates the weight), then rewires consumers
    to the new linear op and removes the matched ops.

    Parameters
    ----------
    pattern : matched-pattern object exposing ``linear``, ``add_or_sub``,
        ``block``, and ``op_list()`` — assumed to come from the pattern
        templates above (TODO confirm against the pass registration).
    """
    # is_sub: the elementwise op is a sub (not an add).
    # is_first_input: presumably whether the linear output is the FIRST
    # input of the add/sub — verify against _get_is_sub_and_is_first_input.
    is_sub, is_first_input = _get_is_sub_and_is_first_input(pattern)
    # linear_bias: the existing bias of the linear op; bias: the const
    # operand of the add/sub; Dout: output dimension of the linear op.
    linear_bias, bias, Dout = _get_linear_bias_bias_Dout(
        pattern, is_first_input)
    # Flatten the add/sub const to rank 1 so it lines up with linear's bias
    # (e.g. a (1, Dout) const becomes (Dout,)).
    bias = np.reshape(bias, (Dout, ))

    # linear(x) - c  =>  bias' = linear_bias + (-c)
    if is_sub and is_first_input:
        bias = -bias
    # c - linear(x)  =>  bias' = (-linear_bias) + c  (weight also negated below)
    if is_sub and not is_first_input:
        linear_bias = -linear_bias
    new_bias = linear_bias + bias

    # compute the new weight
    if is_sub and not is_first_input:
        # c - (Wx + b) = (-W)x + (c - b): the weight flips sign too.
        new_weight = -pattern.linear.weight.val
    else:
        new_weight = pattern.linear.weight.val

    # create a new linear op with the new weight, bias value, copying rest of the attributes
    # Reuse the add/sub output name so downstream name references stay valid.
    out_name = pattern.add_or_sub.outputs[0].name
    linear_kargs = {
        "weight": new_weight,
        "bias": new_bias,
        "name": out_name,
        "before_op": pattern.linear
    }
    # Carry over every other input of the original linear op (e.g. x).
    linear_kargs.update({
        k: v
        for k, v in pattern.linear.inputs.items() if k not in ["weight", "bias"]
    })
    x = mb.linear(**linear_kargs)

    # Redirect all uses of the add/sub output to the fused linear output.
    pattern.add_or_sub.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=pattern.add_or_sub,
        old_var=pattern.add_or_sub.outputs[0],
        new_var=x)
    # Remove all the ops at once
    pattern.block.remove_ops(pattern.op_list())
def linear_prog(input):
    """Build a single-linear-op program with a random (100, 5000) weight."""
    weight_const = mb.const(
        val=np.random.rand(100, 5000),
        name="const_W",
    )
    return mb.linear(x=input, weight=weight_const, name="output")
def linear_prog(input):
    """Build a single-linear-op program with a fixed all-ones (10, 3) weight."""
    # np.float was a deprecated alias for the builtin float (i.e. float64)
    # and was removed in NumPy 1.24; np.float64 is the exact equivalent.
    W = np.ones((10, 3), dtype=np.float64)
    out = mb.linear(x=input, weight=W, name="output")
    return out