Ejemplo n.º 1
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Generate constraints for an embedding node whose output shape is the
    input shape with ``embedding_dim`` appended as a trailing dimension.

    Args:
        n: the embedding node; ``n.args[0]`` is its input.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        embedding_dim: size of the trailing dimension added to the output.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the disjunction of the "both Dyn" case and
        every static-rank case, together with the updated counter.
    """
    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    # the output rank is i + 1, so stopping at MAX_TENSOR_RANK - 1 keeps the
    # output rank within MAX_TENSOR_RANK
    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([
            BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
            BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq),
        ] + nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
def broadcast_types(t1, t2):
    """
    Broadcast a pair of types against each other.

    Dyn and type variables (Var) are returned unchanged. For two TensorTypes,
    the shorter shape is left-padded with 1s to equal rank. Then every size-1
    dimension takes the size of the matching dimension in the other tensor.

    Args:
        t1: first type.
        t2: second type.

    Returns:
        A tuple ``(t1', t2')`` of the broadcast types.

    Raises:
        TypeError: if the types are neither Dyn/Var nor both TensorTypes.
    """
    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
        return t1, t2

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        s1 = len(t1.__args__)
        s2 = len(t2.__args__)

        # here, we make our tensors the same length by prepending 1s
        # (a negative repeat count yields an empty list)
        new_t1 = [1] * (s2 - s1) + list(t1.__args__)
        new_t2 = [1] * (s1 - s2) + list(t2.__args__)

        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
            if x == 1:
                new_t1[i] = y
            elif y == 1:
                new_t2[i] = x

        return TensorType(tuple(new_t1)), TensorType(tuple(new_t2))

    # previously unreachable (placed after the return), so unsupported types
    # silently produced None
    raise TypeError(f'Cannot broadcast types {t1} and {t2}')
Ejemplo n.º 3
    def test_type_check_conv2D(self):
        """Type-check a block whose conv output carries a partially static annotation."""
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes, stride=1, norm_layer=None):
                super(BasicBlock, self).__init__()
                if norm_layer is None:
                    norm_layer = torch.nn.BatchNorm2d
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        # run type inference over the traced graph before inspecting n.type
        tc.type_check()
        for n in graph.nodes:
            if n.op == 'placeholder':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'call_function':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'output':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'call_module':
                assert n.type == TensorType((2, 2, Dyn, 4))
Ejemplo n.º 4
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output

    If the first argument is a node bound to a tensor variable, the output
    equals it. Otherwise every argument must be bound to a dimension variable,
    and the output is the tensor of those dimensions. If the first argument
    is a tuple, its entries give the output size (at most 4).

    Args:
        n: the node being typed.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A list of constraints and the updated counter.

    Raises:
        NotImplementedError: if the first argument is neither a Node nor a tuple.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output

    if isinstance(n.args[0], Node):
        input_var = symbols[n.args[0]]
        if isinstance(input_var, TVar):
            return [BinConstraintT(input_var, output, op_eq)], counter

        # then we have dimension variables
        for arg in n.args:
            assert isinstance(symbols[arg], DVar)
        my_size = [symbols[arg] for arg in n.args]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    elif isinstance(n.args[0], tuple):
        # then the tuple is the size
        assert len(n.args[0]) <= 4
        my_size = [symbols[arg] for arg in n.args[0]]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    raise NotImplementedError('Method not yet implemented')
Ejemplo n.º 5
    def test_resnet50(self):
        """Check static type inference on resnet50 against shapes observed at runtime."""
        from torch.fx.passes.shape_prop import ShapeProp

        gm_run = symbolic_trace(resnet50())
        sample_input = torch.randn(1, 3, 224, 224)

        # run our nodes so each one records its runtime 'tensor_meta'
        ShapeProp(gm_run).propagate(sample_input)

        gm_static = symbolic_trace(resnet50())

        for n in gm_static.graph.nodes:
            n.type = None

        g = GraphTypeChecker({}, gm_static)
        g.type_check()
        # here we are checking for consistency with fully dynamic nodes
        for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
            assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))

        # here we give the same input as to runtime
        gm_static_with_types = symbolic_trace(resnet50())

        # we initialize our placeholder
        for n in gm_static_with_types.graph.nodes:
            if n.op == 'placeholder':
                n.type = TensorType((1, 3, 224, 224))

        g = GraphTypeChecker({}, gm_static_with_types)
        g.type_check()
        for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
            assert n1.type == TensorType(n2.meta['tensor_meta'].shape)
Ejemplo n.º 6
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output sizes should be the same except for the last dimension.
    If the input is Dyn, then so should the output.

    Args:
        n: the linear node; ``n.args[0]`` must be a Node.
        module_instance: the linear module, forwarded to ``add_linear_constraints``.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the disjunction of the "both Dyn" case and
        one case per rank 1..MAX_TENSOR_RANK, together with the updated counter.
    """
    assert isinstance(n.args[0], Node)
    linear_output, counter = gen_tvar(counter)
    symbols[n] = linear_output
    linear_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
                          add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, module_instance) +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
Ejemplo n.º 7
def broadcast_types(t1, t2):
    """
    Broadcast two types for an in-place operation.

    Dyn is returned unchanged. Two TensorTypes must both have non-zero rank,
    with ranks differing by at most one. The shorter shape is extended on the
    left with the leading dimension of the longer one. Then every size-1
    dimension takes the size of the matching dimension in the other tensor.
    At most one of the two shapes may change.

    Args:
        t1: first type.
        t2: second type.

    Returns:
        A tuple ``(t1', t2')`` of the broadcast types.

    Raises:
        TypeError: if the ranks are incompatible, if both shapes would change,
            or if the types are neither Dyn nor both TensorTypes.
    """
    if t1 == Dyn or t2 == Dyn:
        return t1, t2

    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
        s1 = len(t1.__args__)
        s2 = len(t2.__args__)

        new_t1 = list(t1.__args__)
        new_t2 = list(t2.__args__)

        if abs(s1 - s2) > 1 or s1 == 0 or s2 == 0:
            raise TypeError(f'Cannot broadcast the tensors {t1} and {t2}')

        if s1 > s2:
            new_t2.insert(0, t1.__args__[0])
        elif s2 > s1:
            new_t1.insert(0, t2.__args__[0])

        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
            if x == 1:
                new_t1[i] = y
            elif y == 1:
                new_t2[i] = x

        if tuple(new_t1) != t1.__args__ and tuple(new_t2) != t2.__args__:
            raise TypeError('In-place operations cannot change shape')

        return TensorType(tuple(new_t1)), TensorType(tuple(new_t2))

    # previously unreachable (placed after the return), so unsupported types
    # silently produced None
    raise TypeError(f'Cannot broadcast types {t1} and {t2}')
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Expand a tensor consistency constraint into per-rank dimension constraints.

    For every rank i in 1..MAX_TENSOR_RANK, ``constraint.lhs`` and
    ``constraint.rhs`` are equated with fresh rank-i tensors. Their dimensions
    are required to be pairwise consistent, plus the constraints produced by
    ``gen_nat_constraints``.

    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns:
        A list with one conjunction per rank (equality and consistency
        constraints on dimensions), and the updated counter.
    """
    all_constraints = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([
            BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
            BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)
        ] + [
            BinConstraintD(d1, d2, op_consistency)
            for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)
        ] + nat_constraints)

        # previously never appended, so the function always returned []
        all_constraints.append(c_tensor_i)

    return all_constraints, counter
Ejemplo n.º 9
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension.

    The output is the input shape with ``module_instance.embedding_dim``
    appended. If the input is Dyn, so is the output.

    Args:
        n: the embedding node; ``n.args[0]`` must be a Node.
        module_instance: module providing ``embedding_dim``.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the disjunction of the Dyn case and each
        static-rank case, together with the updated counter.
    """
    assert isinstance(n.args[0], Node)

    embedding_dim = module_instance.embedding_dim  # number

    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    # the output rank is i + 1, so stopping at MAX_TENSOR_RANK - 1 keeps the
    # output rank within MAX_TENSOR_RANK
    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
                           BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
Ejemplo n.º 10
    def test_symbolic_add_with_broadcast(self):
        """Refining a broadcast add yields equalities only for the static dimensions."""
        from torch.fx.experimental.unify_refinements import infer_symbolic_types

        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2, 3, Dyn)),
                        y: TensorType((2, 3, 4))):
                return torch.add(x, y)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        tc.type_check()
        infer_symbolic_types(symbolic_traced)
        r = Refine(symbolic_traced)
        r.refine()

        assert r.constraints == [
            Equality(1, 1), Equality(2, 2),
            Equality(3, 3)
        ]
        # note that there is no equality constraint between dyn and 4 because
        # dyn could be 4 or 1

        infer_symbolic_types(symbolic_traced)

        expected_ph_types = [
            TensorType((1, 2, 3, sympy.symbols('~0'))),
            TensorType((2, 3, 4)),
            TensorType((1, 2, 3, sympy.symbols('~1'))),
            TensorType((1, 2, 3, sympy.symbols('~1')))
        ]
        expected_iter = iter(expected_ph_types)

        for n in symbolic_traced.graph.nodes:
            assert n.type == next(expected_iter)
Ejemplo n.º 11
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    We should verify that the index is valid.

    The dimension comes from ``n.args[1]`` or, if absent, ``n.kwargs["dim"]``.
    For each rank it is checked with ``range_check``.

    Args:
        n: the cumsum node; ``n.args[0]`` must be a Node.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the Dyn-or-tensor disjunction and the
        updated counter.
    """
    assert isinstance(n.args[0], Node)
    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(arg_1, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    input_var = symbols[n.args[0]]

    input_dyn = BinConstraintT(input_var, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)
    c1 = Conj([input_dyn, output_dyn])
    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims)

        c_tensor_i = Conj([BinConstraintT(input_var, TensorType(new_dims), op_eq),
                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
                          [range_check(arg_1, i)] + nat_constraints)
        # previously never appended, leaving only the Dyn case
        c2.append(c_tensor_i)

    dyn_or_tensor = Disj([c1, Disj(c2)])
    return [dyn_or_tensor], counter
Ejemplo n.º 12
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the exact constraints as we do for tensor additions, but we
    constrain the rank of this expression to be equal to len(n.args[1:]) so
    that only those cases get considered for the output.

    Args:
        n: the expand node; ``n.args[0]`` is the tensor node and ``n.args[1:]``
            are the target sizes (ints, or nodes bound to dimension variables).
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        The list of constraints and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    # define the output for expand
    expand, counter = gen_tvar(counter)
    symbols[n] = expand

    # since we do not have two nodes here, we will construct an argument variable
    e1 = symbols[n.args[0]]
    e2, counter = gen_tvar(counter)

    e2_nat_constraints = []
    for arg in n.args[1:]:
        assert isinstance(arg, (Node, int))
        if isinstance(arg, Node):
            assert isinstance(symbols[arg], DVar)
            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))

    target_size = [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]
    e2_constraint = BinConstraintT(e2, TensorType(target_size), op_eq)

    result, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)

    # constrain the output size
    dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
    nat_constraints = gen_nat_constraints(dims)
    result += [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints,
               e2_constraint, *e2_nat_constraints]

    return result, counter
Ejemplo n.º 13
def create_equality_constraints_for_broadcasting(e1: TVar, e2: TVar, e11: TVar,
                                                 e12: TVar, d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Create equality constraints for when no broadcasting occurs.

    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e12

    Returns:
        Four equality constraints, in the order e1, e11, e2, e12, each equating
        the variable with the tensor of its dimension variables.
    """
    return [
        BinConstraintT(e1, TensorType(d1), op_eq),
        BinConstraintT(e11, TensorType(d11), op_eq),
        BinConstraintT(e2, TensorType(d2), op_eq),
        BinConstraintT(e12, TensorType(d12), op_eq),
    ]
Ejemplo n.º 14
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    Input should be consistent with the normalized_shape.

    Args:
        n: the layer-norm node; ``n.args[0]`` must be a Node.
        module_instance: module providing ``normalized_shape``.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the disjunction of the "both Dyn" case and
        one case per rank 1..MAX_TENSOR_RANK, together with the updated counter.
    """
    assert isinstance(n.args[0], Node)
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input_var = symbols[n.args[0]]

    input_dyn = BinConstraintT(input_var, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims_rhs)

        c_tensor_i = Conj([BinConstraintT(input_var, TensorType(new_dims_rhs), op_eq),
                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
                          add_layer_norm_constraints(new_dims_rhs, list(module_instance.normalized_shape)) +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
Ejemplo n.º 15
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, then the output will be a tensor.
    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with one of:
     - slice with default arguments
     - None

    None inserts a size-1 dimension, so each occurrence of None increases the
    rank by 1. A slice with default arguments does not change the rank.

    Args:
        constraint: constraint exposing ``tensor_size``, ``index_tuple``,
            ``input_var`` and ``res``.
        counter: fresh-variable counter.

    Returns:
        ``F()`` if the result rank would exceed 4, otherwise a conjunction tying
        the input and result to fresh dimensions, plus the updated counter.

    Raises:
        NotImplementedError: for index entries other than None or ``slice(None, None, None)``.
    """
    assert isinstance(constraint.index_tuple, tuple)

    # generate a result tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    # generate a place-holder list of the right rank
    # where "slice" does not contribute to the rank and "None" does
    none_c = constraint.index_tuple.count(None)
    resulting_tensor_dims = (none_c + len(dims)) * [None]

    for i, index in enumerate(constraint.index_tuple):
        # append 1 to the right location of the resulting tensor
        if index is None:
            resulting_tensor_dims[i] = 1
        elif index == slice(None, None, None):
            # a full slice keeps its input dimension; it is filled in below
            pass
        else:
            raise NotImplementedError('Method not yet implemented')

    # append the remaining dimensions to the right location
    dim_index = 0
    for i in range(len(resulting_tensor_dims)):
        if resulting_tensor_dims[i] is None:
            resulting_tensor_dims[i] = dims[dim_index]
            dim_index += 1

    # check if the index is valid
    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)

    # check if the resulting tensor is within bounds
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]
    return Conj(constraints), counter
Ejemplo n.º 16
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints.

    The dimensions in ``[start, end)`` of ``constraint.dims_to_flatten`` collapse
    into one. For each Dyn/non-Dyn assignment produced by
    ``generate_all_int_dyn_dim_possibilities``:

    * if any flattened dim may be Dyn, the collapsed dim is Dyn;
    * otherwise it is a fresh non-Dyn variable equal to the product of the dims.

    Results of rank above 4 become ``F()``. The whole disjunction is conjoined
    with a statically evaluated bounds check on ``start`` and ``end``.

    Args:
        constraint: constraint exposing ``start``, ``end``, ``dims_to_flatten``
            and ``flattened``.
        counter: fresh-variable counter.

    Returns:
        The transformed constraint and the updated counter.
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(constraint.dims_to_flatten)

    # this will be evaluated right here
    boundary_check = (0 <= start and start < end and end <= n)

    c_boundary = T() if boundary_check else F()

    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)

    all_constraints = []

    for p in all_possibilities:
        p = list(p)
        # this tells us there is a dynamic variable
        contains_dyn = not all(c.op == op_neq for c in p)
        if contains_dyn:
            mid_var = [Dyn]
            result_dims = lhs + mid_var + rhs
            if len(result_dims) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([
                    BinConstraintT(flattened, TensorType(result_dims), op_eq)
                ] + p))
        else:
            new_var, counter = gen_dvar(counter)
            mid_eq_prod = Conj([
                BinConstraintD(new_var, Prod(mid), op_eq),
                BinConstraintD(new_var, Dyn, op_neq)
            ])
            mid_var = [new_var]
            result_dims = lhs + mid_var + rhs
            if len(result_dims) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([
                    BinConstraintT(flattened, TensorType(result_dims), op_eq),
                    mid_eq_prod
                ] + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
Ejemplo n.º 17
 def test_precision(self):
     """
     Test the precision relation (is_more_precise).
     """
     self.assertTrue(is_more_precise(TensorType((1, 2, 3)), TensorType((1, Dyn, 3))))
     self.assertTrue(is_more_precise(int, Dyn))
     self.assertTrue(is_more_precise(int, int))
     self.assertFalse(is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5))))
     self.assertFalse(is_more_precise(TensorType((1, 2, 3)), int))
Ejemplo n.º 18
    def test_flatten_fully_static(self):
        """Inferred flatten output types must be consistent with the runtime output size."""
        # five cases; the first is unannotated (Dyn)
        annotation_list = [
            Dyn,
            TensorType((2, 5, 6, 9)),
            TensorType((10, 15, 13, 14)),
            TensorType((10, Dyn, 13, 14)),
            TensorType((Dyn, Dyn, Dyn, 10))
        ]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (2, 2, 10, 10)]

        start_dim = [1, 2, 1, 2, 0]
        end_dim = [1, 3, 3, 3, -2]

        cases = zip(annotation_list, input_list, start_dim, end_dim)
        for annotation, input_size, start, end in cases:

            class BasicBlock(torch.nn.Module):
                def __init__(self, start, end):
                    super(BasicBlock, self).__init__()
                    self.start = start
                    self.end = end

                def forward(self, x):
                    out = torch.flatten(x, self.start, self.end)
                    return out

            B = BasicBlock(start, end)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = annotation

            b = B.forward(torch.rand(input_size))
            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))
Ejemplo n.º 19
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``index_select(input, dim, index)``.

    ``n.args`` is (input node, int dim, index node). The index is constrained to
    be either a rank-1 tensor or Dyn. The output is described by IndexSelect
    constraints, one per possible input rank 1..MAX_TENSOR_RANK. These use the
    vector's length, or Dyn when the index is Dyn, at position ``dim``.

    Args:
        n: the index_select node.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        A one-element list holding the vector-or-Dyn disjunction and the
        updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    dims, counter = gen_tensor_dims(1, counter)

    input_var = symbols[n.args[0]]
    index_var = symbols[n.args[2]]
    dim = n.args[1]

    # equality constraint
    is_size_1 = BinConstraintT(index_var, TensorType(dims), op_eq)
    is_dyn = BinConstraintT(index_var, Dyn, op_eq)

    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, input_var, dims[0], dim, index_select)
                                for i in range(MAX_TENSOR_RANK)])])
    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, input_var, Dyn, dim, index_select)
                             for i in range(MAX_TENSOR_RANK)])])

    return [Disj([c2, c3])], counter
Ejemplo n.º 20
def transpose_inference_rule(n: Node):
    """
    We check that dimensions for the transpose operations
    are within range of the tensor type of the node.

    For ``torch.transpose`` nodes, the node type becomes the greatest upper
    bound of its current type and the input type with the two dimensions
    swapped. A Dyn input gives a Dyn output. Other targets are not handled
    and return None.

    Args:
        n: the node to type.

    Returns:
        The new node type, or None if ``n.target`` is not ``torch.transpose``.

    Raises:
        TypeError: if a dimension is out of range or the input type is neither
            Dyn nor a TensorType.
    """
    if n.target == torch.transpose:
        assert isinstance(n.args[0], Node)
        t = n.args[0].type

        assert isinstance(n.args[1], int)
        assert isinstance(n.args[2], int)
        dim1, dim2 = n.args[1], n.args[2]

        if t == Dyn:
            n.type = Dyn
            return n.type

        elif isinstance(t, TensorType):
            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
                new_type = list(t.__args__)
                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
                final = TensorType(new_type)
                n.type = get_greatest_upper_bound(n.type, final)
                return n.type
            else:
                raise TypeError(
                    f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
        else:
            raise TypeError(
                f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
Ejemplo n.º 21
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to reshape but with an extra condition on the strides.

    The target shape is ``n.args[1:]`` (ints or nodes bound to dimension
    variables). Every target dimension must be non-Dyn, and a ``-1`` entry is
    replaced by a fresh dimension variable. The view's type equals the target
    shape, and the source must be reshapeable to it.

    Args:
        n: the view node; ``n.args[0]`` must be a Node.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        The list of constraints and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]
    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]  # target shape
    t2_type = []
    num_constraints = []

    for t in t2:
        if t == -1:
            # -1 asks for the size to be inferred: use a fresh dimension variable
            var, counter = gen_dvar(counter)
            t2_type.append(var)
            num_constraints.append(BinConstraintD(var, Dyn, op_neq))
        else:
            num_constraints.append(BinConstraintD(t, Dyn, op_neq))
            t2_type.append(t)

    t2_type = TensorType(t2_type)  # type: ignore[assignment]

    c1 = BinConstraintT(my_view, t2_type, op_eq)
    c2 = CanReshape(src_var, t2_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view

    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
Ejemplo n.º 22
def transform_get_item(constraint, counter):
    """
    Generate an equality of the form t = [a1, ..., an], then generate
    constraints that check whether the given index is valid for this
    particular tensor size. If the index is valid, generate a constraint
    to get the item.

    The Dyn input case is assumed to have been handled before this
    transformation.

    Args:
        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking

    Returns:
        Simplified constraints for GetItem and the updated counter.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    is_valid_index = valid_index(constraint.index, dims)

    all_constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]

    # if the index is valid, we generate a constraint for getting an item
    # otherwise this clause will have been UNSAT due to the wrong index
    if is_valid_index == T():
        all_constraints.append(
            BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(all_constraints), counter
Ejemplo n.º 23
def transpose_inference_rule(n: Node):
    """
    Type a ``torch.transpose`` node by swapping two dimensions of its input type.

    A Dyn input gives a Dyn output. For a TensorType input, both dimensions
    must be in range, and the node type is set to the swapped shape. Other
    targets are not handled and return None.

    Args:
        n: the node to type.

    Returns:
        The new node type, or None if ``n.target`` is not ``torch.transpose``.

    Raises:
        TypeError: if a dimension is out of range or the input type is neither
            Dyn nor a TensorType.
    """
    if n.target == torch.transpose:
        assert isinstance(n.args[0], Node)
        t = n.args[0].type

        assert isinstance(n.args[1], int)
        assert isinstance(n.args[2], int)
        dim1, dim2 = n.args[1], n.args[2]

        if t == Dyn:
            n.type = Dyn
            return n.type

        elif isinstance(t, TensorType):
            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
                new_type = list(t.__args__)
                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
                final = TensorType(new_type)
                n.type = final
                return n.type
            else:
                raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
        else:
            raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
Ejemplo n.º 24
def generate_calc_conv(constraint, counter):
    """
    Transform a convolution constraint into constraints over a fresh rank-4 result.

    Args:
        constraint: constraint exposing ``conv_result``, ``c_out`` and
            ``matching_constraint``; it is also passed to ``calc_last_two_dims``.
        counter: fresh-variable counter.

    Returns:
        A conjunction of the result-shape constraints and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)
    conv_result = TensorType([d[0], d[1], d[2], d[3]])

    # the convolution result is a rank-4 tensor
    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)

    # the second dimension of the output is equal to the output channels
    c2 = Conj([
        BinConstraintD(d[1], constraint.c_out, op_eq),
        BinConstraintD(d[1], Dyn, op_neq)
    ])

    # the input corresponds to the output in the first dimension of the convolution
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    c4, c5 = calc_last_two_dims(constraint, d)

    # every output dimension is non-negative
    leq_constraints = Conj([
        BinConstraintD(0, d[0], op_leq),
        BinConstraintD(0, d[1], op_leq),
        BinConstraintD(0, d[2], op_leq),
        BinConstraintD(0, d[3], op_leq)
    ])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
Ejemplo n.º 25
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a Conv2d module call.

    The input must match a rank-4 tensor (d1, d2, d3, d4) whose channel
    dimension d2 is consistent with ``module_instance.in_channels``. The output
    is described by a CalcConv constraint built from the module's
    out_channels, kernel_size, padding, stride and dilation.

    Args:
        n: the conv node; ``n.args[0]`` must be a Node.
        module_instance: the Conv2d module.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        The list of constraints and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    my_conv, counter = gen_tvar(counter)
    symbols[n] = my_conv
    input_var = symbols[n.args[0]]

    # dim vars: a 2-D convolution input always has exactly four dimensions
    [d1, d2, d3, d4], counter = gen_tensor_dims(4, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)

    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)

    # generate_calc_conv reads c_out and the hyper-parameters from this
    # constraint, so they must all be passed through
    c3 = CalcConv(my_conv, input_var,
                  module_instance.out_channels,
                  module_instance.kernel_size,
                  module_instance.padding,
                  module_instance.stride,
                  module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [c1, c2, c3, *nat_constraints], counter
Ejemplo n.º 26
def conv2d_inference_rule(n: Node, module_instance):
    """
    Given a Conv2D instance and a node check the following conditions:
    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, H, W)
    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
    - x_2 is consistent with the module's in_channels
    - let o = (x_1, out_channels, H_out, W_out)
    then the output is the greatest upper bound of o and the existing node type t'.

    Raises:
        TypeError: if x_2 is not consistent with ``module_instance.in_channels``.
    """
    assert isinstance(n.args[0], Node)
    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
    arg_type = n.args[0].type
    curr_node_type = expand_to_tensor_dim(n.type, 4)

    if not is_consistent(arg_type.__args__[1], module_instance.in_channels):
        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')

    w_in = arg_type.__args__[3]
    h_in = arg_type.__args__[2]
    h_out = calculate_out_dimension(h_in, module_instance, 0)
    w_out = calculate_out_dimension(w_in, module_instance, 1)
    new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
    gub = get_greatest_upper_bound(new_type, curr_node_type)
    n.type = gub
    return n.type
Ejemplo n.º 27
    def test_type_check_add_with_broadcast(self):
        """Broadcasting in torch.add pads the lower-rank argument's type to rank 4."""
        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
                return torch.add(x, y)
        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        # run type inference before inspecting node types
        tc.type_check()
        expected_ph_types = [TensorType((1, 2, 3, Dyn)),
                             TensorType((1, 2, 3, 4)),
                             TensorType((1, 2, 3, Dyn)),
                             TensorType((1, 2, 3, Dyn))]
        expected_iter = iter(expected_ph_types)

        for n in symbolic_traced.graph.nodes:
            assert n.type == next(expected_iter)
Ejemplo n.º 28
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``arange(end)`` with implicit start 0 and step 1.

    The result is a rank-1 tensor of length d1. If any of end/start/step is
    Dyn, d1 is Dyn. Otherwise d1 is non-Dyn and equals (end - start) / step.

    Args:
        n: the arange node; only a single positional argument is supported.
        symbols: mapping from nodes to constraint variables (updated with ``n``).
        constraints: unused; kept for the shared inference-rule signature.
        counter: fresh-variable counter.

    Returns:
        The list of constraints and the updated counter.

    Raises:
        NotImplementedError: when arange has other than one positional argument.
    """
    start = 0
    step = 1

    # previously the raise ran unconditionally, so even the supported
    # single-argument form always failed
    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')
    end = symbols[n.args[0]]

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # either the a parameter is a number or it is Dyn
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
Ejemplo n.º 29
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier to a type.

    A type variable is replaced by its mapping entry when present. Each
    dimension of a TensorType that is a key of ``mapping`` is replaced
    (one level, no recursion). Lists and tuples are processed element-wise and
    recursively, keeping their container type. Anything else is returned
    unchanged.

    Args:
        mapping: substitution from variables to types.
        t: the type to substitute into.

    Returns:
        The substituted type.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping else t

    elif isinstance(t, TensorType):
        return TensorType(tuple(mapping[typ] if typ in mapping else typ
                                for typ in t.__args__))

    elif isinstance(t, list):
        return [substitute_solution_one_type(mapping, typ) for typ in t]

    elif isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, typ) for typ in t)

    else:
        return t
Ejemplo n.º 30
    def test_type_check_transpose_true(self):
        """Transposing dims 0 and 1 of a static tensor swaps those sizes."""
        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2, 3, 5))):
                return torch.transpose(x, 0, 1)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        # run type inference before inspecting node types
        tc.type_check()

        for n in symbolic_traced.graph.nodes:
            if n.op == 'call_function':
                assert n.type == TensorType([2, 1, 3, 5])
            if n.op == 'output':
                assert n.type == TensorType([2, 1, 3, 5])
            # the input node's op is 'placeholder'; the old `n.op == 'x'` check never matched
            if n.op == 'placeholder':
                assert n.type == TensorType((1, 2, 3, 5))