コード例 #1
0
ファイル: constraint_generator.py プロジェクト: Mu-L/pytorch
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Shared constraint generator for embedding lookups.

    Either the input and the output are both Dyn, or, for some input rank,
    the output has the input's dimensions followed by ``embedding_dim``.

    Args:
        n: the embedding node; ``n.args[0]`` is the index input.
        symbols: node -> constraint-variable mapping; a fresh tensor variable
            is bound to ``n``.
        embedding_dim: size of the trailing output dimension.
        counter: fresh-variable counter.

    Returns:
        A one-element constraint list and the updated counter.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    # Input ranks stop at MAX_TENSOR_RANK - 1: the output has one more
    # dimension than the input, so its rank never exceeds MAX_TENSOR_RANK.
    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        shapes = [
            BinConstraintT(in_var, TensorType(dims), op_eq),
            BinConstraintT(out_var, TensorType(dims + [embedding_dim]), op_eq),
        ]
        per_rank.append(Conj(shapes + nat))

    return [Disj([both_dyn, Disj(per_rank)])], counter
コード例 #2
0
ファイル: constraint_generator.py プロジェクト: Mu-L/pytorch
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for transposing dimensions ``n.args[1]`` and ``n.args[2]``.

    A transpose can be viewed as two index selects. Either input and output
    are both Dyn, or, for some rank between 1 and MAX_TENSOR_RANK, a
    ``Transpose`` constraint relates the input to the output.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)

    transposed, counter = gen_tvar(counter)
    symbols[n] = transposed

    source = symbols[n.args[0]]
    assert isinstance(source, TVar)

    dyn_case = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(transposed, Dyn, op_eq)])

    # otherwise the input is a tensor of some rank and the swap is applied
    dim0, dim1 = n.args[1], n.args[2]
    tensor_case = Disj([Transpose(rank, source, dim0, dim1, transposed)
                        for rank in range(1, MAX_TENSOR_RANK + 1)])

    return [Disj([dyn_case, tensor_case])], counter
コード例 #3
0
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``cumsum``: the output shape equals the input shape.

    The ``dim`` argument is taken positionally (``n.args[1]``) or from
    ``n.kwargs["dim"]``. For every candidate rank it is validated with
    ``range_check`` against that rank.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    dim = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(dim, int)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    tensor_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        same_shape = [BinConstraintT(in_var, TensorType(dims), op_eq),
                      BinConstraintT(out_var, TensorType(dims), op_eq)]
        tensor_cases.append(Conj(same_shape + [range_check(dim, rank)] + nat))

    return [Disj([both_dyn, Disj(tensor_cases)])], counter
コード例 #4
0
def linear_inference_rule(n: Node, module_instance, symbols, constraints,
                          counter):
    """
    Generate constraints for a linear module.

    Either input and output are both Dyn, or both are tensors of the same rank
    whose dimensions are related by ``add_linear_constraints`` (they should
    agree everywhere except the last dimension). All tensor dimensions are
    constrained to be natural numbers.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    tensor_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(in_dims + out_dims)
        shapes = [BinConstraintT(in_var, TensorType(in_dims), op_eq),
                  BinConstraintT(out_var, TensorType(out_dims), op_eq)]
        linear = add_linear_constraints(in_dims, out_dims, module_instance)
        tensor_cases.append(Conj(shapes + linear + nat))

    return [Disj([both_dyn, Disj(tensor_cases)])], counter
コード例 #5
0
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate the constraint ``output = input`` (or ``output = TensorType(size)``).

    Cases, by the type of ``n.args[0]``:

    * a node bound to a tensor variable: the output equals that variable;
    * a node bound to anything else: every argument must be bound to a
      dimension variable, and the output is a tensor of those dimensions;
    * a tuple of at most 4 entries: the output is a tensor whose dimensions
      are the variables bound to those entries.

    Raises:
        NotImplementedError: for any other kind of first argument.
    """
    result, counter = gen_tvar(counter)
    symbols[n] = result

    first = n.args[0]
    if isinstance(first, Node):
        source = symbols[first]
        if isinstance(source, TVar):
            return [BinConstraintT(source, result, op_eq)], counter

        # the arguments are dimension variables spelling out the size
        for arg in n.args:
            assert isinstance(symbols[arg], DVar)
        size_vars = [symbols[arg] for arg in n.args]
        return [BinConstraintT(result, TensorType(size_vars), op_eq)], counter

    if isinstance(first, tuple):
        # the tuple itself is the size
        assert len(first) <= 4
        size_vars = [symbols[entry] for entry in first]
        return [BinConstraintT(result, TensorType(size_vars), op_eq)], counter

    raise NotImplementedError('Method not yet implemented')
コード例 #6
0
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``flatten(input, start_dim=1, end_dim=-1)``.

    Optional positional ``start_dim`` / ``end_dim`` arguments override the
    defaults 1 and -1. Either input and output are both Dyn, or one of the
    per-rank constraints built by ``generate_flatten_constraints`` holds.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    flat_var, counter = gen_tvar(counter)
    symbols[n] = flat_var
    in_var = symbols[n.args[0]]

    start_dim, end_dim = 1, -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(flat_var, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        rank_constraint, counter = generate_flatten_constraints(
            start_dim, end_dim, in_var, flat_var, rank, counter)
        per_rank.append(rank_constraint)

    return [Disj([both_dyn, *per_rank])], counter
コード例 #7
0
    def generate_constraints_node(self, n: Node, counter):
        """
        Generate constraints for the given node.

        Dispatch on ``n.op``:

        - ``placeholder``: bind a fresh tensor variable ``x`` to ``n`` and emit
          ``BinConstraintT(n.type, x, op_precision)`` and
          ``BinConstraintT(x, MAX_TENSOR_RANK, op_leq)``.
        - ``call_function`` / ``call_method``: call the rule registered for
          ``n.target`` as ``rule(n, symbol_dict, constraints, counter)``.
        - ``call_module``: call the rule registered for the submodule's class
          as ``rule(n, module_instance, symbol_dict, constraints, counter)``.
        - ``get_attr`` / ``output``: no constraints.

        Returns:
            ``(constraints, counter)`` as produced by the selected rule.

        Raises:
            RuntimeError: no rule is registered for the target / module class.
            NotImplementedError: ``n.op`` is not one of the above.
        """
        if n.op == 'placeholder':
            x, counter = gen_tvar(counter)
            self.symbol_dict[n] = x
            c1 = BinConstraintT(n.type, x, op_precision)
            c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
            return [c1, c2], counter

        elif n.op in ('call_function', 'call_method'):
            # ``getattr`` goes through the same path as every other call: its
            # rule (presumably get_attr_inference_rule) takes
            # (n, symbols, constraints, counter) and must receive ``counter``
            # so that it can return the updated one.
            if n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n, self.symbol_dict,
                                                  self.constraints, counter)
            raise RuntimeError(
                f'No inference rule registered for target {n.target}!')

        elif n.op == 'call_module':
            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _INFERENCE_RULES:
                return _INFERENCE_RULES[type(module_instance)](
                    n, module_instance, self.symbol_dict, self.constraints,
                    counter)
            raise RuntimeError(
                f'No inference rule registered for class {type(module_instance)}!'
            )

        # TODO: verify that no constraint should be generated here
        elif n.op == 'get_attr':
            return [], counter

        elif n.op == 'output':
            return [], counter

        else:
            raise NotImplementedError(f"Method {n.op} not yet implemented")
コード例 #8
0
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints,
                              counter):
    """
    Generate constraints for a layer-norm module.

    Either input and output are both Dyn, or both are the same tensor of some
    rank whose dimensions are checked against
    ``module_instance.normalized_shape`` by ``add_layer_norm_constraints``.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    tensor_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        shapes = [BinConstraintT(in_var, TensorType(dims), op_eq),
                  BinConstraintT(out_var, TensorType(dims), op_eq)]
        normalized = add_layer_norm_constraints(
            dims, list(module_instance.normalized_shape))
        tensor_cases.append(Conj(shapes + normalized + nat))

    return [Disj([both_dyn, Disj(tensor_cases)])], counter
コード例 #9
0
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints,
                          counter):
    """
    Generate constraints for a 2D convolution module.

    The input is matched against a tensor of MAX_TENSOR_RANK dimension
    variables (unpacked into four, so MAX_TENSOR_RANK is assumed to be 4).
    Its second dimension must be consistent with ``in_channels``, and
    ``CalcConv`` relates the output to the input through the module's
    out_channels, kernel_size, padding, stride and dilation.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    conv_out, counter = gen_tvar(counter)
    symbols[n] = conv_out
    conv_in = symbols[n.args[0]]

    (d1, d2, d3, d4), counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    matches_input = BinConstraintT(conv_in, TensorType([d1, d2, d3, d4]), op_matching)
    channels_ok = BinConstraintD(module_instance.in_channels, d2, op_consistency)
    conv_shape = CalcConv(conv_out, conv_in, module_instance.out_channels,
                          module_instance.kernel_size, module_instance.padding,
                          module_instance.stride, module_instance.dilation,
                          [d1, d2, d3, d4])
    nat = gen_nat_constraints([d1, d2, d3, d4])

    return [matches_input, channels_ok, conv_shape, *nat], counter
コード例 #10
0
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``index_select(input, dim, index)``.

    ``n.args`` is (input node, int dim, index node). The index must be either
    a rank-1 tensor or Dyn. The output is the input with the dimension at
    position ``dim`` replaced by the index length (or by Dyn when the index
    is Dyn), expressed per input rank through ``IndexSelect``.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    selected, counter = gen_tvar(counter)
    symbols[n] = selected

    index_dims, counter = gen_tensor_dims(1, counter)

    index_var = symbols[n.args[2]]
    source = symbols[n.args[0]]
    dim = n.args[1]
    ranks = range(1, MAX_TENSOR_RANK + 1)

    # index is a vector: its length replaces the selected dimension
    vector_case = Conj([
        BinConstraintT(index_var, TensorType(index_dims), op_eq),
        Disj([IndexSelect(rank, source, index_dims[0], dim, selected)
              for rank in ranks]),
    ])
    # index is Dyn: the selected dimension becomes Dyn
    dyn_case = Conj([
        BinConstraintT(index_var, Dyn, op_eq),
        Disj([IndexSelect(rank, source, Dyn, dim, selected)
              for rank in ranks]),
    ])

    return [Disj([vector_case, dyn_case])], counter
コード例 #11
0
def embedding_inference_rule(n: Node, module_instance, symbols, constraints,
                             counter):
    """
    Generate constraints for an embedding module call.

    The output shape differs from the input shape only by a trailing
    ``module_instance.embedding_dim`` dimension (or both are Dyn). The
    constraint construction is delegated to ``gen_embedding_rules`` so the
    module rule and the shared generator cannot drift apart.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim,
                               counter)
コード例 #12
0
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``arange(end)`` with ``start = 0`` and ``step = 1``.

    ``end`` is either a graph node (its variable is looked up in ``symbols``)
    or an int literal. The output is a rank-1 tensor ``[d1]`` where:

    * ``d1`` is Dyn if any of end/start/step is Dyn;
    * otherwise none of them is Dyn and ``d1 = (end - start) / step``,
      expressed with the ``op_sub`` / ``op_div`` dimension operators.

    Returns:
        The constraint list and the updated counter.

    Raises:
        NotImplementedError: for any call other than the single-argument form.
    """
    start = 0
    step = 1

    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')

    end_arg = n.args[0]
    # An int literal (``arange(5)``) has no entry in ``symbols``; use it
    # directly, the same way ``start`` and ``step`` are used below.
    end = end_arg if isinstance(end_arg, int) else symbols[end_arg]

    # d1 == int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(
        d1,
        BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div),
        op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # either one of the parameters is Dyn, and then so is the length ...
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    # ... or all parameters are numbers and the length is computed from them
    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq),
            Disj([both_dyn, both_numbers])], counter
コード例 #13
0
def _gt_flag_constraint(lhs, rhs, counter):
    """
    Bind a fresh boolean variable to the dimension comparison ``lhs > rhs``.

    Returns a one-element constraint list and the updated counter.
    """
    gt_constraint = BinConstraintD(lhs, rhs, op_gt)
    flag, counter = gen_bvar(counter)
    return [BinConstraintD(flag, gt_constraint, op_eq)], counter


def gt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``lhs > rhs`` where operands are nodes or ints.

    Cases:

    * both operands are bound to tensor variables: the output is a fresh
      tensor variable constrained by ``gen_broadcasting_constraints``;
    * a dimension variable compared with a dimension variable or an int:
      a fresh boolean variable is bound to the comparison. This is meant for
      flow analysis only, and ``n`` itself is not bound in ``symbols`` (the
      node is not expected to be used again);
    * a tensor variable compared with an int: the earlier assumption that
      the left operand is a tensor was wrong, so it is re-bound to a fresh
      dimension variable (with a warning) and compared as above.

    Returns:
        The constraint list and the updated counter.

    Raises:
        RuntimeError: two node operands of mismatched sorts.
        NotImplementedError: any other operand combination.
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    lhs_is_node = isinstance(n.args[0], Node)
    rhs_is_node = isinstance(n.args[1], Node)

    e1 = symbols[n.args[0]] if lhs_is_node else n.args[0]
    e2 = symbols[n.args[1]] if rhs_is_node else n.args[1]

    if lhs_is_node and rhs_is_node:
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter,
                                                gt_tensor)

        if isinstance(e1, DVar) and isinstance(e2, DVar):
            return _gt_flag_constraint(e1, e2, counter)

        raise RuntimeError('Sort Mismatch')

    if lhs_is_node:
        # here the right operand is an int literal
        if isinstance(e1, DVar):
            return _gt_flag_constraint(e1, e2, counter)

        if isinstance(e1, TVar) and isinstance(e2, int):
            # we made the wrong assumption about the argument being a tensor,
            # so fix the assumption by re-binding it to a dimension variable
            warnings.warn(
                f'Made the wrong assumption for node {n}. Correctness not guaranteed.'
            )
            new_e1, counter = gen_dvar(counter)
            symbols[n.args[0]] = new_e1
            return _gt_flag_constraint(new_e1, e2, counter)

        raise NotImplementedError('Method not yet implemented')

    raise NotImplementedError('Method not yet implemented')
コード例 #14
0
ファイル: constraint_generator.py プロジェクト: Mu-L/pytorch
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``bmm(input1, input2)``.

    Batch matrix multiplication maps shapes ``(b, n, m)`` and ``(b, m, p)``
    to ``(b, n, p)``. Four cases are considered:

    * input1, input2 and the output are all Dyn;
    * input1 is Dyn: the output is ``(b2, Dyn, p)`` for input2 ``(b2, m2, p)``;
    * input2 is Dyn: the output is ``(b1, n, Dyn)`` for input1 ``(b1, n, m1)``;
    * both inputs are rank-3 tensors: the batch dimensions must be consistent,
      the contraction dimensions (``input1[2]`` and ``input2[1]``) must be
      consistent, and the output batch size is the greatest upper bound of the
      two batch dimensions.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    inputs_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_output, Dyn, op_eq)
    ])

    input1_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(bmm_output,
                       TensorType([dims_input2[0], Dyn, dims_input2[2]]),
                       op_eq)
    ])

    input2_dyn = Conj([
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_output,
                       TensorType([dims_input1[0], dims_input1[1], Dyn]),
                       op_eq)
    ])

    consistency_constraints = [
        # batch dimensions must agree (up to Dyn)
        BinConstraintD(dims_input1[0], dims_input2[0], op_consistency),
        # contraction dimensions: (b, n, m) @ (b, m, p) requires the two m's
        # to agree (up to Dyn)
        BinConstraintD(dims_input1[2], dims_input2[1], op_consistency),
    ]

    batch_size, counter = gen_dvar(counter)

    inputs_are_tensors = Conj([
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(
            bmm_output,
            TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
        *consistency_constraints,
        DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])
    ])

    return [Disj([inputs_dyn, input1_dyn, input2_dyn,
                  inputs_are_tensors])], counter
コード例 #15
0
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``tensor[index]``.

    * int index: the result is a fresh dimension variable. A Dyn input accepts
      any index and yields a Dyn dimension. Otherwise a ``GetItem`` constraint
      is expanded per input rank, and the result must be non-negative.
    * tuple index: the result is a fresh tensor variable. It is either Dyn
      together with the input, or related to the input per rank by
      ``GetItemTensor``.

    Returns:
        A one-element constraint list and the updated counter.

    Raises:
        RuntimeError: for any other kind of index.
    """
    assert isinstance(n.args[0], Node)
    index = n.args[1]

    if isinstance(index, int):
        dim_out, counter = gen_dvar(counter)
        symbols[n] = dim_out

        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        dyn_case = Conj([BinConstraintT(source, Dyn, op_eq),
                         BinConstraintD(dim_out, Dyn, op_eq)])

        # the getItem constraint is expanded according to the tensor rank
        per_rank_item = [GetItem(rank, index, dim_out, source)
                         for rank in range(1, MAX_TENSOR_RANK + 1)]

        # the output is a dimension, so it must be a natural number; this is
        # conjoined with the disjunction over ranks
        non_negative = BinConstraintD(0, dim_out, op_leq)
        return [Disj([dyn_case, Conj([Disj(per_rank_item), non_negative])])], counter

    if isinstance(index, tuple):
        tensor_out, counter = gen_tvar(counter)
        symbols[n] = tensor_out

        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        dyn_case = Conj([BinConstraintT(source, Dyn, op_eq),
                         BinConstraintT(tensor_out, Dyn, op_eq)])

        per_rank_slice = [GetItemTensor(rank, index, tensor_out, source)
                          for rank in range(1, MAX_TENSOR_RANK + 1)]

        return [Disj([dyn_case, *per_rank_slice])], counter

    raise RuntimeError('Method not yet implemented')
コード例 #16
0
def to_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate the constraint ``input = output``: the conversion keeps the shape.

    Returns:
        A one-element constraint list and the updated counter.
    """
    source_node = n.args[0]
    assert isinstance(source_node, Node)

    converted, counter = gen_tvar(counter)
    symbols[n] = converted
    return [BinConstraintT(symbols[source_node], converted, op_eq)], counter
コード例 #17
0
def full_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``full(size, ...)``.

    The output is a tensor whose dimensions are the variables bound (in
    ``symbols``) to the entries of the size argument ``n.args[0]``.

    Returns:
        A one-element constraint list and the updated counter.
    """
    result, counter = gen_tvar(counter)
    symbols[n] = result

    size_arg = n.args[0]
    assert isinstance(size_arg, Iterable)
    dim_vars = [symbols[entry] for entry in size_arg]
    constraint = BinConstraintT(result, TensorType(dim_vars), op_eq)  # type: ignore[arg-type]
    return [constraint], counter
コード例 #18
0
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a ReLU module: the output shape equals the input shape.

    ``module_instance`` is unused; it is accepted to match the signature with
    which module rules are called.

    Returns:
        A one-element constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    activated, counter = gen_tvar(counter)
    symbols[n] = activated

    source = symbols[n.args[0]]
    assert isinstance(source, TVar)
    return [BinConstraintT(source, activated, op_eq)], counter
コード例 #19
0
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``expand(input, *sizes)``.

    The requested sizes are modelled as a second tensor variable. The output
    is constrained exactly as for a broadcasting tensor addition of the input
    with that variable. In addition, the output rank is fixed to
    ``len(n.args[1:])`` so that only that rank is considered.

    Each size is an int literal or a node bound to a (non-negative) dimension
    variable.

    Returns:
        The constraint list and the updated counter. The ``constraints``
        argument itself is not modified.
    """
    assert isinstance(n.args[0], Node)

    expanded, counter = gen_tvar(counter)
    symbols[n] = expanded

    # expand has a single tensor operand, so build a stand-in tensor variable
    # for the requested sizes and broadcast against it
    source = symbols[n.args[0]]
    sizes_var, counter = gen_tvar(counter)

    sizes = n.args[1:]
    size_nat = []
    for size in sizes:
        assert isinstance(size, (Node, int))
        if isinstance(size, Node):
            assert isinstance(symbols[size], DVar)
            size_nat.append(BinConstraintD(0, symbols[size], op_leq))

    sizes_constraint = BinConstraintT(
        sizes_var,
        TensorType([s if isinstance(s, int) else symbols[s] for s in sizes]),
        op_eq)

    result, counter = gen_broadcasting_constraints(
        source, sizes_var, symbols, counter, expanded)

    # pin the output rank to the number of requested sizes
    out_dims, counter = gen_tensor_dims(len(sizes), counter)
    out_nat = gen_nat_constraints(out_dims)
    result += [
        BinConstraintT(expanded, TensorType(out_dims), op_eq), *out_nat,
        sizes_constraint, *size_nat
    ]

    return result, counter
コード例 #20
0
def add_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for addition.

    * tensor + tensor: broadcasting constraints between the two operands.
    * tensor + int: the output equals the tensor operand.
    * dimension + int: the output dimension equals the sum and is >= 0.

    Returns:
        The constraint list and the updated counter.

    Raises:
        NotImplementedError: for any other operand combination.
    """
    lhs = n.args[0]
    if not isinstance(lhs, Node):
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')

    rhs = n.args[1]

    if isinstance(rhs, Node):
        if not (isinstance(symbols[lhs], TVar) and isinstance(symbols[rhs], TVar)):
            raise NotImplementedError('Method not yet implemented')
        sum_var, counter = gen_tvar(counter)
        symbols[n] = sum_var
        return gen_broadcasting_constraints(symbols[lhs], symbols[rhs],
                                            symbols, counter, sum_var)

    if isinstance(rhs, int):
        operand = symbols[lhs]
        if isinstance(operand, TVar):
            sum_var, counter = gen_tvar(counter)
            symbols[n] = sum_var
            return [BinConstraintT(sum_var, operand, op_eq)], counter

        if isinstance(operand, DVar):
            sum_var, counter = gen_dvar(counter)
            symbols[n] = sum_var
            # regular integer addition: propagate the runtime value
            total = BinConstraintD(operand, rhs, op_add)
            return [Conj([BinConstraintD(sum_var, total, op_eq),
                          BinConstraintD(0, sum_var, op_leq)])], counter

        raise NotImplementedError('Method not yet implemented')

    # TODO generate add constraints for scalar addition
    raise NotImplementedError('Addition not yet implemented')
コード例 #21
0
def mul_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for multiplication.

    Only ``node * float`` (scalar multiplication) is supported; there the
    output shape equals the input shape.

    Returns:
        A one-element constraint list and the updated counter.

    Raises:
        NotImplementedError: for any other operand combination. In that case
            no variable is allocated and ``n`` is not bound in ``symbols``.
    """
    if not (isinstance(n.args[0], Node) and isinstance(n.args[1], float)):
        raise NotImplementedError('Case not yet implemented')

    # Allocate the output only once the case is known to be supported, so an
    # unsupported multiplication does not leave an unconstrained variable
    # bound to ``n``.
    my_mul, counter = gen_tvar(counter)
    symbols[n] = my_mul

    e1 = symbols[n.args[0]]
    return [BinConstraintT(my_mul, e1, op_eq)], counter
コード例 #22
0
def reshape_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``reshape(input, shape)``.

    ``n.args[1]`` is the target shape, with every ``-1`` entry replaced by Dyn.
    The output equals that target type, and ``CanReshape`` requires the input
    to be reshapeable into it.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    reshaped, counter = gen_tvar(counter)
    symbols[n] = reshaped

    source = symbols[n.args[0]]
    target_shape = n.args[1]
    target = TensorType([Dyn if size == -1 else size for size in target_shape])  # type: ignore[union-attr]

    return [BinConstraintT(reshaped, target, op_eq),
            CanReshape(source, target)], counter
コード例 #23
0
def ne_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``!=``.

    The output is constrained exactly like a tensor addition (broadcasting
    between the two operands). Unlike addition, both operands must be graph
    nodes; scalar operands are not supported.

    Returns:
        The constraint list and the updated counter.
    """
    ne_var, counter = gen_tvar(counter)
    symbols[n] = ne_var

    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)
    lhs_var = symbols[n.args[0]]
    rhs_var = symbols[n.args[1]]
    return gen_broadcasting_constraints(lhs_var, rhs_var, symbols, counter,
                                        ne_var)
コード例 #24
0
def add_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for tensor + tensor addition.

    Two auxiliary tensor variables (not tied to graph nodes) are related to
    the operands by ``ApplyBroadcasting``; presumably they hold the operands
    after broadcasting. They must be consistent with each other, and the
    output is their greatest upper bound.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var

    lhs_var = symbols[n.args[0]]
    rhs_var = symbols[n.args[1]]

    # auxiliary variables that do not correspond to any expression
    lhs_aux, counter = gen_tvar(counter)
    rhs_aux, counter = gen_tvar(counter)

    return [
        TGreatestUpperBound(out_var, lhs_aux, rhs_aux),
        ApplyBroadcasting(lhs_aux, rhs_aux, lhs_var, rhs_var),
        BinConstraintT(lhs_aux, rhs_aux, op_consistency),
    ], counter
コード例 #25
0
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``getattr(tensor, name)``.

    Only ``name == 'device'`` is supported: the tensor shape is preserved, so
    the output variable equals the input's.

    Returns:
        A one-element constraint list and the updated counter.

    Raises:
        NotImplementedError: for any other attribute name.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], str)
    attr_var, counter = gen_tvar(counter)
    symbols[n] = attr_var

    owner = symbols[n.args[0]]
    if n.args[1] != 'device':
        raise NotImplementedError('Not yet implemented')
    return [BinConstraintT(owner, attr_var, op_eq)], counter
コード例 #26
0
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a max-pooling module.

    The input is matched against a tensor of MAX_TENSOR_RANK dimension
    variables (unpacked into four). ``CalcMaxPool`` relates the output to the
    input through the module's kernel_size, padding, stride and dilation.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    pooled, counter = gen_tvar(counter)
    symbols[n] = pooled
    source = symbols[n.args[0]]

    (d1, d2, d3, d4), counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    matches_input = BinConstraintT(source, TensorType([d1, d2, d3, d4]), op_matching)
    pool_shape = CalcMaxPool(pooled, source,
                             module_instance.kernel_size, module_instance.padding,
                             module_instance.stride, module_instance.dilation,
                             [d1, d2, d3, d4])
    nat = gen_nat_constraints([d1, d2, d3, d4])

    return [matches_input, pool_shape, *nat], counter
コード例 #27
0
def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a 2D adaptive pooling module.

    The input is matched against a 4-dimensional tensor ``[d1, d2, d3, d4]``.
    The output keeps ``d1``/``d2`` and takes its last two dimensions from
    ``module_instance.output_size``, which may be:

    * a pair ``(H, W)``, or
    * a single int ``H``, meaning ``(H, H)``.

    A ``None`` entry keeps the corresponding input dimension (``d3`` / ``d4``).

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    avg_pool, counter = gen_tvar(counter)

    symbols[n] = avg_pool
    input_var = symbols[n.args[0]]

    # dim vars
    d1, counter = gen_dvar(counter)
    d2, counter = gen_dvar(counter)
    d3, counter = gen_dvar(counter)
    d4, counter = gen_dvar(counter)
    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    # a single int means a square output; indexing it directly would fail
    output_size = module_instance.output_size
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    out_h = d3 if output_size[0] is None else output_size[0]
    out_w = d4 if output_size[1] is None else output_size[1]

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, out_h, out_w]), op_eq)

    return [c1, c2, *nat_constraints], counter
コード例 #28
0
def type_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for converting ``n.args[0]`` to the type of ``n.args[1]``.

    Both arguments must be bound to tensor variables. The source must be
    consistent with the target, and the output equals the target
    (``output = to_arg``, not ``output = input``).

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    converted, counter = gen_tvar(counter)
    symbols[n] = converted

    source = symbols[n.args[0]]
    target = symbols[n.args[1]]
    assert isinstance(source, TVar)
    assert isinstance(target, TVar)

    consistent = BinConstraintT(source, target, op_consistency)
    takes_target = BinConstraintT(converted, target, op_eq)
    return [consistent, takes_target], counter
コード例 #29
0
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a batch-norm module.

    The input must match a 4-dimensional tensor of natural-number dimension
    variables, and the output equals the input. ``module_instance`` is not
    consulted.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    normalized, counter = gen_tvar(counter)
    symbols[n] = normalized
    source = symbols[n.args[0]]

    # four fresh dimension variables, generated in order
    dim_vars = []
    for _ in range(4):
        dim_var, counter = gen_dvar(counter)
        dim_vars.append(dim_var)

    nat = gen_nat_constraints(list(dim_vars))

    matches_input = BinConstraintT(source, TensorType(list(dim_vars)), op_matching)
    same_shape = BinConstraintT(source, normalized, op_eq)
    return [matches_input, same_shape, *nat], counter
コード例 #30
0
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``Tensor.view(*shape)``.

    Similar to reshape, but with an extra condition on the strides (not yet
    modelled). The target shape may be given as separate arguments
    (``x.view(2, -1)``) or as a single tuple/list (``x.view((2, -1))``).
    Node entries are replaced by their bound variable, and ``-1`` by Dyn.

    Returns:
        The constraint list and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]

    shape_args = n.args[1:]
    # ``x.view((2, 3))`` / ``x.view(other.shape)`` pass the whole shape as a
    # single tuple/list argument; unpack it instead of treating it as one dim
    if len(shape_args) == 1 and isinstance(shape_args[0], (tuple, list)):
        shape_args = shape_args[0]

    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in shape_args]  # target shape
    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])  # type: ignore[union-attr]
    c1 = BinConstraintT(my_view, t2_type, op_eq)
    c2 = CanReshape(src_var, t2_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view

    return [c1, c2], counter