Example #1
def arange_inference_rule(n: Node, symbols, constraints, counter):
    start = 0
    step = 1

    if len(n.args) == 1:
        end = symbols[n.args[0]]
    else:
        raise NotImplementedError('Not yet implemented')

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # if any of the parameters is Dyn, the resulting dimension d1 is also Dyn
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
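# A self-contained sketch (no constraint machinery) of the size formula the
# rule encodes for concrete arguments: d1 = int((end - start) / step). Note
# this matches torch.arange exactly only when step divides (end - start).
def arange_length(end, start=0, step=1):
    return int((end - start) / step)

assert arange_length(5) == 5         # torch.arange(5) has shape [5]
assert arange_length(10, 2, 4) == 2  # torch.arange(2, 10, 4) -> [2, 6]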
Example #2
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    We constrain the second argument to a vector or Dyn.
    The output replaces the input with the shape of the vector
    at the position given by the index (first argument)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    dims, counter = gen_tensor_dims(1, counter)

    # equality constraint
    is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)

    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
                                for i in range(MAX_TENSOR_RANK)])])
    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
                             for i in range(MAX_TENSOR_RANK)])])

    return [Disj([c2, c3])], counter
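# A plain-Python sketch of the shape transformation the IndexSelect
# constraints encode: the dimension at the given position is replaced by the
# length of the index vector; all other dimensions are preserved.
def index_select_shape(shape, dim, index_len):
    out = list(shape)
    out[dim] = index_len
    return out

assert index_select_shape([2, 3, 4], 1, 7) == [2, 7, 4]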
Example #3
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Input and output shapes should be equal
    We should verify that the index is valid
    """
    assert isinstance(n.args[0], Node)
    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(arg_1, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)
    c1 = Conj([input_dyn, output_dyn])
    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
                          [range_check(arg_1, i)] + nat_constraints)

        c2.append(c_tensor_i)
    dyn_or_tensor = Disj([c1, Disj(c2)])
    return [dyn_or_tensor], counter
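# A sketch of the check range_check(arg_1, i) is assumed to perform: for a
# rank-i tensor, a dimension argument is valid iff it lies in [-i, i).
def dim_is_valid(dim, rank):
    return -rank <= dim < rank

assert dim_is_valid(-1, 3) and dim_is_valid(2, 3)
assert not dim_is_valid(3, 3)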
Example #4
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output sizes should be the same except for the last dimension
    If the input is Dyn, then so should the output
    """
    assert isinstance(n.args[0], Node)
    linear_output, counter = gen_tvar(counter)
    symbols[n] = linear_output
    linear_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
                          add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, module_instance) +
                          nat_constraints)
        c2.append(c_tensor_i)


    return [Disj([c1, Disj(c2)])], counter
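# A plain-Python sketch of what add_linear_constraints is assumed to encode,
# with None standing in for Dyn: every dimension except the last is preserved;
# the last input dimension must be consistent with the module's in_features,
# and the last output dimension becomes out_features.
def linear_shape(in_shape, in_features, out_features):
    *batch, last = in_shape
    assert last is None or last == in_features  # consistency with in_features
    return [*batch, out_features]

assert linear_shape([8, 128, 20], 20, 30) == [8, 128, 30]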
Example #5
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):

    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([
            BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
            BinConstraintT(embedding_output,
                           TensorType(new_dims + [embedding_dim]), op_eq)
        ] + nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
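# The shape rule these constraints express, concretely: for any input rank up
# to MAX_TENSOR_RANK - 1, the output keeps the input shape and appends
# embedding_dim as a new last dimension.
def embedding_shape(in_shape, embedding_dim):
    return list(in_shape) + [embedding_dim]

assert embedding_shape([4, 16], 300) == [4, 16, 300]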
Example #6
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Can be considered as a sequence of two index selects, so we generate constraints accordingly
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)

    output, counter = gen_tvar(counter)
    symbols[n] = output

    from_arg = symbols[n.args[0]]
    assert isinstance(from_arg, TVar)

    # input and output are dyn
    is_dyn = Conj([
        BinConstraintT(from_arg, Dyn, op_eq),
        BinConstraintT(output, Dyn, op_eq)
    ])

    # or input is a tensor and we actually do the replacement
    c3 = Disj([
        Transpose(i + 1, from_arg, n.args[1], n.args[2], output)
        for i in range(MAX_TENSOR_RANK)
    ])

    return [Disj([is_dyn, c3])], counter
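# A plain-Python sketch of what each Transpose constraint encodes for a
# concrete rank: the two given dimensions swap places, all others stay put.
def transpose_shape(shape, dim0, dim1):
    out = list(shape)
    out[dim0], out[dim1] = out[dim1], out[dim0]
    return out

assert transpose_shape([2, 3, 4], 0, 2) == [4, 3, 2]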
Example #7
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension
    """
    assert isinstance(n.args[0], Node)

    embedding_dim = module_instance.embedding_dim  # number

    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
                           BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
Example #8
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    Input should be consistent with the normalized_shape
    """
    assert isinstance(n.args[0], Node)
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims_rhs)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
                          add_layer_norm_constraints(new_dims_rhs, list(module_instance.normalized_shape)) +
                          nat_constraints)
        c2.append(c_tensor_i)


    return [Disj([c1, Disj(c2)])], counter
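# A sketch of what add_layer_norm_constraints is assumed to require, with
# None standing in for Dyn: the trailing dimensions of the input must be
# consistent with normalized_shape.
def layer_norm_ok(shape, normalized_shape):
    if len(shape) < len(normalized_shape):
        return False
    tail = shape[-len(normalized_shape):]
    return all(a is None or a == b for a, b in zip(tail, normalized_shape))

assert layer_norm_ok([8, 5, 10], [5, 10])
assert not layer_norm_ok([8, 5, 10], [4, 10])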
Example #9
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    assert isinstance(n.args[0], Node)

    # dimension output case
    if isinstance(n.args[1], int):
        # create and store the new dimension variable
        get_item_output, counter = gen_dvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        get_item_arg = symbols[n.args[0]]
        assert isinstance(get_item_arg, TVar)

        # if the input is dynamic, we accept any index and return
        # a dynamic dimension as output
        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
        output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
        c1 = Conj([input_dyn, output_dyn])

        # if the input is a tensor,
        # generate a getItem constraint which will be expanded based on the
        # tensor dimension.

        c2 = [
            GetItem(i + 1, n.args[1], get_item_output, get_item_arg)
            for i in range(MAX_TENSOR_RANK)
        ]

        # since the output is a dimension, we make sure it's a natural number
        # added as a conjunction to the disjunction of c2
        c3 = BinConstraintD(0, get_item_output, op_leq)
        return [Disj([c1, Conj([Disj(c2), c3])])], counter

    # tensor output case
    elif isinstance(n.args[1], tuple):
        # create and store the new tensor variable
        get_item_output, counter = gen_tvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        get_item_arg = symbols[n.args[0]]
        assert isinstance(get_item_arg, TVar)

        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
        output_dyn = BinConstraintT(get_item_output, Dyn,
                                    op_eq)  # type: ignore[assignment]
        c1 = Conj([input_dyn, output_dyn])

        c2 = [
            GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)
            for i in range(MAX_TENSOR_RANK)
        ]  # type: ignore[misc]

        return [Disj([c1, *c2])], counter

    else:
        raise RuntimeError('Method not yet implemented')
Example #10
def no_broadcast_dim_with_index(d1: List[DVar], d2: List[DVar], d3: List[DVar],
                                d4: List[DVar], i: int):
    """
    Args:
        d1: input 1
        d2: input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: the index of the dimension being constrained

    Returns: Constraints for when no broadcasting occurs
    """
    return Conj([
        Disj([
            Conj([
                BinConstraintD(d1[i], 1, op_eq),
                BinConstraintD(d2[i], 1, op_eq)
            ]),
            Conj([
                BinConstraintD(d1[i], 1, op_neq),
                BinConstraintD(d2[i], 1, op_neq)
            ])
        ]),
        BinConstraintD(d1[i], d3[i], op_eq),
        BinConstraintD(d2[i], d4[i], op_eq)
    ])
Example #11
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    assert isinstance(n.args[0], Node)

    # generate the new variable
    flattened, counter = gen_tvar(counter)
    symbols[n] = flattened

    input = symbols[n.args[0]]

    # set the default start and end dims
    start_dim = 1
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    c1 = BinConstraintT(input, Dyn, op_eq)
    c2 = BinConstraintT(flattened, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    const = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
        const.append(c)

    return [Disj([both_dyn, *const])], counter
Example #12
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar],
                                                       d2: List[DVar],
                                                       d11: List[DVar],
                                                       d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
    We look at all combinations for all dimensions in d1 and d2
    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions

    """

    size = len(d1)

    res2 = []

    for i in range(size):
        t1 = broadcast_dim(d1, d2, d11, d12, i)
        t2 = broadcast_dim(d2, d1, d12, d11, i)
        t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)

        res2.append(Disj([t1, t2, t3]))

    return Conj(res2)
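# A concrete sketch of the three per-dimension cases combined above: either
# d1 broadcasts to d2, d2 broadcasts to d1, or no broadcasting occurs and the
# dimensions must already agree (plain ints here, no constraint variables).
def broadcast_one_dim(a, b):
    if a == 1:
        return b, b  # input 1 stretches to match input 2
    if b == 1:
        return a, a  # input 2 stretches to match input 1
    assert a == b    # no broadcasting: dimensions must match
    return a, b

assert broadcast_one_dim(1, 5) == (5, 5)
assert broadcast_one_dim(4, 4) == (4, 4)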
Example #13
def generate_gub(constraint, counter):
    """
    Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
    on dimensions
    """
    c1 = Conj([
        Disj([
            BinConstraintT(constraint.rhs1, Dyn, op_eq),
            BinConstraintT(constraint.rhs2, Dyn, op_eq)
        ]),
        BinConstraintT(constraint.res, Dyn, op_eq)
    ])

    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)

    return Disj([c1, c2, c3, c4, c5]), counter
Example #14
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints that match the input to a size 3 tensor
    and switch the dimensions according to the rules
    of batch multiplication
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    inputs_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_output, Dyn, op_eq)
    ])

    input1_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(bmm_output,
                       TensorType([dims_input2[0], Dyn, dims_input2[2]]),
                       op_eq)
    ])

    input2_dyn = Conj([
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_output,
                       TensorType([dims_input1[0], dims_input1[1], Dyn]),
                       op_eq)
    ])

    # the inner dimensions of the matrix product must be consistent
    consistency_constraints = [
        BinConstraintD(dims_input1[2], dims_input2[1], op_consistency)
    ]

    batch_size, counter = gen_dvar(counter)

    inputs_are_tensors = Conj([
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(
            bmm_output,
            TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
        *consistency_constraints,
        DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])
    ])

    return [Disj([inputs_dyn, input1_dyn, input2_dyn,
                  inputs_are_tensors])], counter
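# A plain-Python sketch of the fully static case encoded above: batch sizes
# must agree (the greatest upper bound of two concrete, consistent dims is
# just their common value), the inner dimensions of the product must match,
# and the result is [batch, n, p].
def bmm_shape(s1, s2):
    assert len(s1) == 3 and len(s2) == 3
    assert s1[0] == s2[0]  # batch sizes must be consistent
    assert s1[2] == s2[1]  # inner dimensions of the matrix product
    return [s1[0], s1[1], s2[2]]

assert bmm_shape([10, 3, 4], [10, 4, 5]) == [10, 3, 5]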
Example #15
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(constraint.dims_to_flatten)

    # the boundary check involves only statically known values, so it is
    # evaluated here, at constraint-generation time
    boundary_check = 0 <= start < end <= n

    c_boundary = T() if boundary_check else F()

    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)

    all_constraints = []

    for p in all_possibilities:
        p = list(p)
        # if not every dimension is constrained to be != Dyn, the middle
        # contains a dynamic variable
        contains_dyn = not all(constraint.op == op_neq for constraint in p)
        if contains_dyn:
            mid_var = [Dyn]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:  # result would exceed MAX_TENSOR_RANK
                all_constraints.append(F())
            else:
                all_constraints.append(
                    Conj([
                        BinConstraintT(flattened,
                                       TensorType(lhs + mid_var + rhs), op_eq)
                    ] + p))
        else:
            new_var, counter = gen_dvar(counter)
            mid_eq_prod = Conj([
                BinConstraintD(new_var, Prod(mid), op_eq),
                BinConstraintD(new_var, Dyn, op_neq)
            ])
            mid_var = [new_var]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:  # result would exceed MAX_TENSOR_RANK
                all_constraints.append(F())
            else:
                all_constraints.append(
                    Conj([
                        BinConstraintT(flattened,
                                       TensorType(lhs + mid_var +
                                                  rhs), op_eq), mid_eq_prod
                    ] + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
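# A concrete sketch of the computation these constraints encode when every
# dimension is known: dims[start:end] collapse into their product.
from math import prod

def flatten_shape(shape, start, end):
    assert 0 <= start < end <= len(shape)  # mirrors boundary_check above
    return shape[:start] + [prod(shape[start:end])] + shape[end:]

assert flatten_shape([2, 3, 4, 5], 1, 3) == [2, 12, 5]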
Example #16
def generate_disj(constraint, counter):
    """
    Transform disjunctions
    """
    new = []
    for c in constraint.disjuncts:
        new_c, counter = transform_constraint(c, counter)
        new.append(new_c)
    return Disj(new), counter
Example #17
def apply_padding(e1_var: TVar, e11: BinConstraintT, e2: BinConstraintT,
                  e12: BinConstraintT, d2: List[DVar], d11: List[DVar],
                  d12: List[DVar], counter: int):
    """
    We are considering the possibility where one input has less dimensions than
    another input, so we apply padding to the broadcasted results

    Args:
        e1_var: Variable representing the first input where padding will be
        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
        e2:  constraint of the form e2 = Tensortype[d1, ..., dn]
        e12: constraint of the form e11 = Tensortype[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns: A new constraint whose goal is to apply padding to the broadcasted result

    """

    res = []

    # pad the shorter input with None so we can pass it to the broadcasting helper function
    for i in range(1, len(d2)):

        d1, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)

        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        simulate_padding = [None] * (len(d2) - i)

        assert len(simulate_padding + d1) == len(d2)

        broadcast_padding = []

        # for every padding size, we also consider broadcasting
        for j in range((len(d2) - i)):
            broadcast_padding.append(
                broadcast_dim(simulate_padding, d2, d11, d12, j, True))

        # we consider the possibilities for broadcasting for every dimension. Since we already
        # padded d1, we do not consider it while broadcasting
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(
            d1, d2[(len(d2) - i):], d11[(len(d2) - i):], d12[(len(d2) - i):])
        # combine all constraints into a conjunction
        c = Conj([
            e1, e11, e2, e12, *broadcast_padding,
            all_broadcasting_possibilities, *nat_constraints
        ])
        res.append(c)

    return Disj(res), counter
Example #18
def generate_broadcasting(constraint, counter):
    """
    Transform broadcasting constraints
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    e1_dyn = BinConstraintT(e1, Dyn, op_eq)
    e2_dyn = BinConstraintT(e2, Dyn, op_eq)

    # Introduce dimensions
    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)

    # dyn possibility
    e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
    e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])

    # tensor possibility
    # generate dimensions to create tensors of size 1
    final_tensor_1_constraint, _, _, nat_dims_1, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)

    # generate dimensions to create tensors of size 2
    final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
        final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)

    # generate dimensions to create tensors of size 3
    final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
        final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)

    # generate dimensions to create tensors of size 4
    final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
        final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)

    final_result = Disj([
        e1_dyn_constraint, e2_dyn_constraint, final_tensor_1_constraint,
        final_tensor_2_constraint_no_padding,
        final_tensor_2_constraint_padding_arg1,
        final_tensor_2_constraint_padding_arg2,
        final_tensor_3_constraint_no_padding,
        final_tensor_3_constraint_padding_arg1,
        final_tensor_3_constraint_padding_arg2,
        final_tensor_4_constraint_no_padding,
        final_tensor_4_constraint_padding_arg1,
        final_tensor_4_constraint_padding_arg2
    ])

    return Conj(
        [final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3,
         *nat_dims_4]), counter
Example #19
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    assert isinstance(n.args[0], Node)
    my_dim, counter = gen_dvar(counter)
    symbols[n] = my_dim
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintD(my_dim, Dyn, op_eq)

    c1 = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)

        c_tensor_i = Conj([
            BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
            BinConstraintD(my_dim, i, op_eq)
        ])
        c1.append(c_tensor_i)

    return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
Example #20
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims_rhs)

        c_tensor_i = Conj([
            BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
            BinConstraintT(output, TensorType(new_dims_rhs), op_eq)
        ] + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
                          nat_constraints)
        c2.append(c_tensor_i)
    return [Disj([c1, Disj(c2)])], counter
Example #21
def size_inference_rule(n: Node, symbols, constraints, counter):
    """
    The constraint is just lhs = rhs.
    Ex: size = input_ids.size()
    """

    if len(n.args) == 1:
        # generate the new variable
        size, counter = gen_tvar(counter)
        symbols[n] = size
        input = symbols[n.args[0]]
        c = BinConstraintT(input, size, op_eq)
        return [c], counter

    elif len(n.args) == 2:
        # TODO: review this rule; should input = dyn; output = dyn be included here?
        if isinstance(n.args[1], int):
            # generate the new variable
            size_index, counter = gen_dvar(counter)
            symbols[n] = size_index
            input = symbols[n.args[0]]
            c2 = [
                GetItem(i + 1, n.args[1], size_index, input)
                for i in range(MAX_TENSOR_RANK)
            ]
            c3 = BinConstraintD(0, size_index, op_leq)

            input_dyn = BinConstraintT(input, Dyn, op_eq)
            output_dyn = BinConstraintD(size_index, Dyn, op_eq)
            c1 = Conj([input_dyn, output_dyn])

            return [Disj([c1, Conj([Disj(c2), c3])])], counter

        else:
            raise NotImplementedError

    else:
        raise NotImplementedError
Example #22
def linear_constraints(n: Node, in_features, out_features, symbols, counter):
    linear_output, counter = gen_tvar(counter)
    symbols[n] = linear_output
    linear_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([
            BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
            BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)
        ] + add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features,
                                   out_features) + nat_constraints)
        c2.append(c_tensor_i)
    return [Disj([c1, Disj(c2)])], counter
Example #23
def generate_d_gub(constraint, counter):
    """
    Transform greatest upper bound for dimensions into equality constraints
    """
    c1 = Conj([
        BinConstraintD(constraint.rhs1, Dyn, op_eq),
        BinConstraintD(constraint.res, constraint.rhs2, op_eq)
    ])
    c2 = Conj([
        BinConstraintD(constraint.rhs2, Dyn, op_eq),
        BinConstraintD(constraint.res, constraint.rhs1, op_eq)
    ])
    c3 = Conj([
        BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq),
        BinConstraintD(constraint.res, constraint.rhs1, op_eq)
    ])
    return Disj([c1, c2, c3]), counter
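# The three cases above written out concretely, with None standing in for Dyn:
# Dyn carries no information, so the greatest upper bound with Dyn is the
# other dimension; two concrete dimensions have an upper bound only when equal.
def d_gub(a, b):
    if a is None:
        return b
    if b is None:
        return a
    assert a == b  # no upper bound exists for unequal concrete dims
    return a

assert d_gub(None, 7) == 7
assert d_gub(3, 3) == 3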
Example #24
def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Consider all possibilities what the input dimensions could be (number or dynamic)
    Then generate the appropriate constraints using multiplication or mod depending on the possibility
    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
    for the input. Target is fixed because at most one dimension could be dyn.
    We have different cases for this.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns: A disjuncition of transformed reshape constraints

    """
    all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)

    all_constraints = []

    for p in all_possibilities:
        to_multiply = []

        p = list(p)

        for constraint in p:
            assert isinstance(constraint, BinConstraintD)
            if constraint.op == op_neq:
                to_multiply.append(constraint.lhs)

        if not to_multiply:
            all_constraints.append(Conj(p))

        elif len(to_multiply) < len(list_of_dims):
            all_constraints.append(
                Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]))
        else:
            all_constraints.append(
                Conj(
                    p +
                    [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]))

    return Disj(all_constraints)
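# A concrete sketch of the divisibility reasoning above, with all sizes as
# plain ints: if some input dimensions are dynamic, the product of the known
# ones must divide the target's element count; if all are known, the products
# must be equal.
from math import prod

def reshape_possible(known_dims, has_dyn, target_numel):
    if has_dyn:
        return target_numel % prod(known_dims) == 0
    return prod(known_dims) == target_numel

assert reshape_possible([2, 3], True, 24)       # Dyn can absorb the factor 4
assert not reshape_possible([2, 3], False, 24)  # 2 * 3 != 24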
Example #25
def generate_binconstraint_d(constraint, counter):
    """
    Transform binary constraints for dimensions
    """
    if constraint.op == op_precision:
        if isinstance(constraint.lhs, int):
            return BinConstraintD(constraint.lhs, constraint.rhs,
                                  op_eq), counter
        elif constraint.lhs == Dyn:
            return T(), counter

    elif constraint.op == op_consistency:
        return Disj([
            BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
            BinConstraintD(constraint.rhs, Dyn, op_eq),
            BinConstraintD(constraint.lhs, Dyn, op_eq)
        ]), counter

    else:
        return constraint, counter
Example #26
def calc_last_two_dims(constraint, d: List[DVar]):
    """
    Generates constraints for the last two dimensions of a convolution or a maxpool output
    Args:
        constraint: CalcConv or CalcMaxPool
        d: The list of output dimensions

    Returns: Constraints for calculating the last two dimensions of the output

    """

    assert isinstance(constraint, (CalcConv, CalcMaxPool))

    b3 = constraint.matching_constraint[2]
    b4 = constraint.matching_constraint[3]

    b3_dyn = Conj(
        [BinConstraintD(d[2], Dyn, op_eq),
         BinConstraintD(b3, Dyn, op_eq)])
    b4_dyn = Conj(
        [BinConstraintD(d[3], Dyn, op_eq),
         BinConstraintD(b4, Dyn, op_eq)])

    d3_not_dyn = Conj(
        [BinConstraintD(d[2], Dyn, op_neq),
         BinConstraintD(b3, Dyn, op_neq)])
    d4_not_dyn = Conj(
        [BinConstraintD(d[3], Dyn, op_neq),
         BinConstraintD(b4, Dyn, op_neq)])

    # transform parameters into tuples in case they are not already
    padding = (constraint.padding, constraint.padding) \
        if isinstance(constraint.padding, int) else constraint.padding
    kernel = (constraint.kernel, constraint.kernel) \
        if isinstance(constraint.kernel, int) else constraint.kernel
    stride = (constraint.stride, constraint.stride) \
        if isinstance(constraint.stride, int) else constraint.stride
    dilation = (constraint.dilation, constraint.dilation) \
        if isinstance(constraint.dilation, int) else constraint.dilation

    # output-size formula: (b3 + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub),
                        op_mul)
    f3 = BinConstraintD(
        BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0],
        op_div)
    f4 = BinConstraintD(f3, 1, op_add)

    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])

    # the same formula applied to the last dimension
    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub),
                         op_mul)
    f33 = BinConstraintD(
        BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1],
        op_div)
    f44 = BinConstraintD(f33, 1, op_add)

    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])

    return c4, c5
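# The arithmetic built up in f1..f4 (and f11..f44), written as ordinary
# integer math; this is the standard convolution/maxpool output-size formula.
def conv_output_dim(in_size, padding, dilation, kernel, stride):
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# e.g. a 32-wide input with a 3-wide kernel, stride 1, padding 1, dilation 1
# keeps its size
assert conv_output_dim(32, 1, 1, 3, 1) == 32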
Example #27
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints
    """
    d, counter = gen_tensor_dims(4, counter)

    d1 = d[0]
    d2 = d[1]
    d3 = d[2]
    d4 = d[3]

    target = constraint.target.__args__

    is_fully_static = all([d != Dyn for d in target])

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]),
                                op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]),
                                op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([
            d1_eq_dyn,
            (Conj([d1_neq_dyn,
                   BinConstraintD(d1, Prod(target), op_eq)]))
        ])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj(
            [c2_tensor2,
             gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj(
            [c2_tensor3,
             gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([
            c2_tensor4,
            gen_all_reshape_possibilities([d1, d2, d3, d4], target)
        ])

        return Conj([
            Disj([
                c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4
            ]), nat_d1, nat_d2, nat_d3, nat_d4
        ]), counter

    # then there must be exactly one occurrence of dyn
    else:
        # keep only the static dimensions of the target
        new_target = [t for t in target if t != Dyn]

        # tensor 1
        c3_tensor1 = Disj([
            d1_eq_dyn,
            (Conj([d1_neq_dyn,
                   is_dim_div_by_target(new_target, d1)]))
        ])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([
            d1_neq_dyn, d2_neq_dyn,
            is_dim_div_by_target(new_target, Prod([d1, d2]))
        ])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([
            d1_neq_dyn, d2_neq_dyn, d3_neq_dyn,
            is_dim_div_by_target(new_target, Prod([d1, d2, d3]))
        ])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([
            d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn,
            is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))
        ])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

        return Conj([
            Disj([
                c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4
            ]), nat_d1, nat_d2, nat_d3, nat_d4
        ]), counter
Example #28
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors
    """

    # precision constraints
    if constraint.op == op_precision:
        if constraint.lhs == Dyn:
            return T(), counter
        elif isinstance(constraint.lhs, TensorType):
            is_fully_static = all([d != Dyn for d in constraint.lhs.__args__])
            if is_fully_static:
                return BinConstraintT(constraint.lhs, constraint.rhs,
                                      op_eq), counter
            else:
                new_dims = []

                for _ in range(len(constraint.lhs.__args__)):
                    dim, counter = gen_dvar(counter)
                    new_dims.append(dim)

                new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
                                       new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
                                      [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
                                      [BinConstraintD(1, new_dim, op_leq) for
                                       new_dim in new_dims]
                return Conj(new_dim_constraints), counter

    # matching
    elif constraint.op == op_matching:
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        conj = [
            BinConstraintT(constraint.lhs, Dyn, op_eq),
            BinConstraintD(d1, Dyn, op_eq),
            BinConstraintD(d2, Dyn, op_eq),
            BinConstraintD(d3, Dyn, op_eq),
            BinConstraintD(d4, Dyn, op_eq)
        ]
        return Disj([
            Conj(conj),
            BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)
        ]), counter

    elif constraint.op == op_consistency:
        c_dyn = Disj([
            BinConstraintT(constraint.lhs, Dyn, op_eq),
            BinConstraintT(constraint.rhs, Dyn, op_eq)
        ])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4
         ], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3,
                     c_tensor_4]), counter

    elif constraint.op == op_leq:
        assert isinstance(constraint.rhs, int)
        disj = []
        for i in range(1, constraint.rhs + 1):
            dims = []
            for j in range(1, i + 1):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims),
                                       op_eq))
        return Disj(disj), counter
    else:
        return constraint, counter
Example #29
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates to inconsistent in gradual types.
    To prove inequality, we should prove that
    tensors are either different sizes or
    disagree on at least one dimension

    This is a WIP (it works when the condition
    is false; we are working on making this operation work
    when the condition is true as well).
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    # implementing for size 3 and 4
    if len(n.args[1]) == 3:

        assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
        assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
        assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)

        lhs = symbols[n.args[0]]

        b, counter = gen_tensor_dims(3, counter)
        input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)

        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b[0], op_neq)
        neq_2 = BinConstraintD(d2, b[1], op_neq)
        neq_3 = BinConstraintD(d3, b[2], op_neq)

        # dimensions inconsistent
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])

        # we are covering size 3 and 4 only for now
        ne_constraint = Conj([input_is_size3, dims_inconsistent])

        my_ne, counter = gen_bvar(counter)
        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    elif len(n.args[1]) == 4:

        assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int)
        assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int)
        assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int)
        assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int)

        lhs = symbols[n.args[0]]

        b1, counter = gen_dvar(counter)
        b2, counter = gen_dvar(counter)
        b3, counter = gen_dvar(counter)
        b4, counter = gen_dvar(counter)

        input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)

        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
        d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b1, op_neq)
        neq_2 = BinConstraintD(d2, b2, op_neq)
        neq_3 = BinConstraintD(d3, b3, op_neq)
        neq_4 = BinConstraintD(d4, b4, op_neq)

        # dimensions inconsistent
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
        dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq), neq_4])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])

        ne_constraint = Conj([input_is_size4, dims_inconsistent])

        my_ne, counter = gen_bvar(counter)

        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    else:
        raise NotImplementedError('Method not yet implemented')

    return [equality_constraint], counter
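# A concrete sketch of what the generated constraints check for equal ranks,
# with None standing in for Dyn: two shapes are provably unequal only when
# some pair of corresponding dimensions are both concrete and differ.
def shapes_provably_neq(s1, s2):
    assert len(s1) == len(s2)
    return any(a is not None and b is not None and a != b
               for a, b in zip(s1, s2))

assert shapes_provably_neq([2, None, 4], [2, 3, 5])
assert not shapes_provably_neq([2, None, 4], [2, 3, 4])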