def bn2d_inference_rule(n: Node, module_instance):
    """
    Type-check a BatchNorm2D call node against its module instance.

    The argument type t = (x_1, x_2, x_3, x_4) and the node's existing
    annotation t' = (x_1', x_2', x_3', x_4') are both expanded to rank-4
    tensor types first; the expanded types are written back onto the
    argument node and onto ``n``. The rule then requires that:
    - x_2 is consistent with the module's num_features
    - x_2' is consistent with the module's num_features
    - t is consistent with t'

    Returns the greatest upper bound of t and t', which also becomes the
    node's type. Raises TypeError when any of the conditions fails.
    """
    operand = n.args[0]
    assert isinstance(operand, Node)

    # normalise both annotations to rank 4 before comparing them
    operand.type = expand_to_tensor_dim(operand.type, 4)
    arg_type = operand.type
    n.type = expand_to_tensor_dim(n.type, 4)

    num_features = module_instance.num_features
    well_typed = (
        is_consistent(arg_type.__args__[1], num_features)
        and is_consistent(n.type.__args__[1], num_features)
        and is_consistent(arg_type, n.type)
    )
    if not well_typed:
        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')

    # keep whichever information is more precise, dimension by dimension,
    # so a better-annotated argument refines this node's type
    n.type = get_greatest_upper_bound(arg_type, n.type)
    return n.type
def test_consistency(self):
    """
    Exercise the consistency relation on tensor, scalar and Dyn types.
    """
    consistent_pairs = [
        (TensorType((1, 2, 3)), TensorType((1, Dyn, 3))),
        (int, Dyn),
        (int, int),
    ]
    inconsistent_pairs = [
        # differing rank
        (TensorType((1, 2, 3)), TensorType((1, 2, 3, 5))),
        # tensor versus scalar
        (TensorType((1, 2, 3)), int),
    ]

    for left, right in consistent_pairs:
        self.assertTrue(is_consistent(left, right))
    for left, right in inconsistent_pairs:
        self.assertFalse(is_consistent(left, right))
def add_inference_rule(n: Node):
    """
    Infer the type of an addition node from the types of its two operands.

    - Scalar addition: when one operand has type ``int`` and the other is a
      TensorType, the node takes the tensor's type.
    - Otherwise both operand types are broadcast against each other and the
      broadcast types are written back onto the operand nodes.

    If the broadcast types are consistent, the node is assigned the less
    precise of the two and that type is returned; otherwise a TypeError is
    raised.
    """
    lhs = n.args[0]
    assert isinstance(lhs, Node)
    rhs = n.args[1]
    assert isinstance(rhs, Node)

    t1, t2 = lhs.type, rhs.type

    # scalar addition: the tensor operand decides the result type
    if t1 == int and isinstance(t2, TensorType):
        n.type = t2
        return n.type
    if t2 == int and isinstance(t1, TensorType):
        n.type = t1
        return n.type

    # note: this overwrites the operand annotations with the broadcast types
    new_t1, new_t2 = broadcast_types(t1, t2)
    lhs.type = new_t1
    rhs.type = new_t2

    if not is_consistent(new_t1, new_t2):
        raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
                        f' Types should match ')

    # when the first type is at least as precise as the second, the second
    # (less precise) type is the one recorded on the node
    n.type = new_t2 if is_more_precise(new_t1, new_t2) else new_t1
    return n.type
def conv2d_inference_rule(n: Node, module_instance):
    """
    Type-check a Conv2D call node against its module instance.

    - the input type is expanded to a rank-4 tensor t = (x_1, x_2, H, W) and
      written back onto the argument node
    - the node's current type is expanded to a rank-4 tensor
      t' = (x_1', x_2', x_3', x_4')
    - x_2 must be consistent with the module's in_channels

    With o = (x_1, out_channels, H_out, W_out), the node type becomes the
    greatest upper bound of o and t', and is returned. A TypeError is raised
    when the channel check fails.
    """
    operand = n.args[0]
    assert isinstance(operand, Node)

    operand.type = expand_to_tensor_dim(operand.type, 4)
    arg_type = operand.type
    curr_node_type = expand_to_tensor_dim(n.type, 4)

    batch, channels, h_in, w_in = arg_type.__args__
    if not is_consistent(channels, module_instance.in_channels):
        raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')

    # spatial output sizes; index 0 selects the height parameters of the
    # module and index 1 the width parameters
    h_out = calculate_out_dimension(h_in, module_instance, 0)
    w_out = calculate_out_dimension(w_in, module_instance, 1)

    computed = TensorType((batch, module_instance.out_channels, h_out, w_out))
    n.type = get_greatest_upper_bound(computed, curr_node_type)
    return n.type
def test_resnet50(self):
    """
    Type-check a traced resnet50 and compare against shapes seen at runtime.

    A reference trace is run through ShapeProp to record each node's runtime
    shape. The checker is then applied twice: once with every annotation
    cleared (inferred types must be consistent with the runtime shapes) and
    once with the placeholder annotated with the runtime input shape
    (inferred types must equal the runtime shapes).
    """
    def runtime_type(node):
        # tensor type built from the shape ShapeProp stored on the node
        return TensorType(node.meta['tensor_meta'].shape)

    # reference graph: propagate a concrete input to record runtime shapes
    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(1, 3, 224, 224)
    ShapeProp(gm_run).propagate(sample_input)

    # first pass: no type information at all
    gm_static = symbolic_trace(resnet50())
    for node in gm_static.graph.nodes:
        node.type = None
    GraphTypeChecker({}, gm_static).type_check()

    # inferred types only have to be consistent with what happened at runtime
    for static_node, run_node in zip(gm_static.graph.nodes, gm_run.graph.nodes):
        assert is_consistent(static_node.type, runtime_type(run_node))

    # second pass: annotate the placeholder with the same shape as the
    # runtime input
    gm_static_with_types = symbolic_trace(resnet50())
    for node in gm_static_with_types.graph.nodes:
        if node.op == 'placeholder':
            node.type = TensorType((1, 3, 224, 224))
    GraphTypeChecker({}, gm_static_with_types).type_check()

    # with a fully static input the inferred types must match exactly
    for typed_node, run_node in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
        assert typed_node.type == runtime_type(run_node)
def broadcast_types(t1, t2):
    """
    Apply broadcasting to two types so that they become consistent.

    - If either type is Dyn or a type variable (Var), both types are
      returned unchanged.
    - If both are TensorTypes, the shorter one is left-padded with 1s to the
      rank of the longer one, and then every dimension equal to 1 is replaced
      by the corresponding dimension of the other type.

    Returns:
        A pair (new_t1, new_t2) of broadcast types.

    Raises:
        TypeError: if the types are neither Dyn/Var nor both TensorTypes, or
            if the broadcast tensor types are still inconsistent.
    """
    # Dyn and type variables need no broadcasting
    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
        return t1, t2

    if not (isinstance(t1, TensorType) and isinstance(t2, TensorType)):
        raise TypeError(f'Cannot broadcast types {t1} and {t2}')

    dims1 = list(t1.__args__)
    dims2 = list(t2.__args__)

    # bring both shapes to the same rank by left-padding the shorter with 1s
    rank = max(len(dims1), len(dims2))
    dims1 = [1] * (rank - len(dims1)) + dims1
    dims2 = [1] * (rank - len(dims2)) + dims2

    # a dimension of size 1 stretches to the other operand's dimension
    for i, (x, y) in enumerate(zip(dims1, dims2)):
        if x == 1:
            dims1[i] = y
        elif y == 1:
            dims2[i] = x

    new_t1 = TensorType(tuple(dims1))
    new_t2 = TensorType(tuple(dims2))

    if not is_consistent(new_t1, new_t2):
        # report the original operands as well as the broadcast result so
        # the failure can be diagnosed from the message alone
        raise TypeError(f'Cannot broadcast types {t1} and {t2}: '
                        f'{new_t1} and {new_t2} are inconsistent after broadcasting')

    return (new_t1, new_t2)
def add_inference_rule(n: Node):
    """
    Infer the type of an addition node.

    Two situations are handled:
    - scalar addition: one operand is ``int`` and the other a TensorType, in
      which case the node takes the tensor's type
    - broadcasting: both operand types are broadcast against each other

    The operand annotations themselves are never modified. When broadcasting
    changed either type, ``n.meta['broadcast']`` is set to True and the
    broadcast types are stored in ``n.meta`` under the operand names;
    otherwise ``n.meta['broadcast']`` is set to False.

    The node is assigned the less precise of the two (possibly broadcast)
    types, which is also returned. A TypeError is raised when the two types
    are inconsistent.
    """
    lhs = n.args[0]
    assert isinstance(lhs, Node)
    rhs = n.args[1]
    assert isinstance(rhs, Node)

    t1 = lhs.type
    t2 = rhs.type

    # scalar addition: the tensor operand decides the result type
    if t1 == int and isinstance(t2, TensorType):
        n.type = t2
        return n.type
    if t2 == int and isinstance(t1, TensorType):
        n.type = t1
        return n.type

    # after broadcasting, any remaining inconsistency is a genuine type
    # error rather than a difference broadcasting could have resolved
    (new_t1, new_t2) = broadcast_types(t1, t2)

    if new_t1 != t1 or new_t2 != t2:
        # record the broadcast operand types on this node only
        n.meta['broadcast'] = True
        n.meta[str(lhs)] = new_t1
        n.meta[str(rhs)] = new_t2
    else:
        n.meta['broadcast'] = False
        # nothing changed: continue with the operands' own type objects
        new_t1, new_t2 = t1, t2

    if not is_consistent(new_t1, new_t2):
        raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
                        f' Types should match ')

    # the less precise type wins because broadcasting may have happened:
    # for operands [1,2,Dyn] and [1,2,1] the node must be [1,2,Dyn]
    if is_more_precise(new_t1, new_t2):
        n.type = new_t2
    else:
        n.type = new_t1
    return n.type
def test_flatten_fully_static(self):
    """
    Type-check torch.flatten for several annotations and dimension ranges.

    For each case the placeholder is annotated, the module is run on a random
    tensor of the given shape, and the type inferred for the output node must
    be consistent with the shape of the actual result.
    """
    annotation_list = [
        Dyn,
        TensorType((2, 5, 6, 9)),
        TensorType((10, 15, 13, 14)),
        TensorType((10, Dyn, 13, 14)),
        TensorType((Dyn, Dyn, Dyn, 10)),
    ]
    input_list = [
        (1, 2, 3, 5),
        (2, 5, 6, 9),
        (10, 15, 13, 14),
        (10, 15, 13, 14),
        (2, 2, 10, 10),
    ]
    # not used by the assertions below
    intermediate_list = [
        Dyn,
        (2, 5, 6, 9),
        (10, 15, 13, 14),
        (10, 15, 13, 14),
        (2, 2, 10, 10),
    ]
    start_dim = [1, 2, 1, 2, 0]
    end_dim = [1, 3, 3, 3, -2]

    class BasicBlock(torch.nn.Module):
        def __init__(self, start, end):
            super().__init__()
            self.start = start
            self.end = end

        def forward(self, x):
            return torch.flatten(x, self.start, self.end)

    cases = zip(annotation_list, input_list, start_dim, end_dim)
    for annotation, sample_shape, start, end in cases:
        block = BasicBlock(start, end)
        tracer = RewritingTracer()
        graph = tracer.trace(block)
        traced = GraphModule(tracer.root, graph, "gm")

        # annotate our argument
        for node in graph.nodes:
            if node.op == 'placeholder':
                node.type = annotation

        # actual result, used as ground truth for the inferred output type
        actual = block.forward(torch.rand(sample_shape))

        GraphTypeChecker({}, traced).type_check()

        for node in graph.nodes:
            if node.op == 'output':
                assert is_consistent(node.type, TensorType(actual.size()))
def get_greatest_upper_bound(type1, type2):
    """
    Get the most precise type that's consistent with the given types.

    - If either type is Dyn, the other type is returned.
    - If both are TensorTypes they must be consistent (TypeError otherwise);
      the result takes, dimension by dimension, the first type's dimension
      when it is more precise than the second's and the second's otherwise.
    - For any other combination the function returns None.
    """
    if type1 == Dyn:
        return type2
    if type2 == Dyn:
        return type1

    if not (isinstance(type1, TensorType) and isinstance(type2, TensorType)):
        # no rule for this combination
        return None

    if not is_consistent(type1, type2):
        raise TypeError(f'Inconsistent types {type1}, {type2}')

    merged = []
    for dim1, dim2 in zip(type1.__args__, type2.__args__):
        merged.append(dim1 if is_more_precise(dim1, dim2) else dim2)
    return TensorType(tuple(merged))
def linear_check(tensor_type, module_instance):
    """
    Check an input tensor type against a linear module and compute the
    output type.

    The input must have rank 2 or more and its last dimension must be
    consistent with ``module_instance.in_features``. The output type is the
    input type with its last dimension replaced by
    ``module_instance.out_features``.

    Raises:
        TypeError: if the rank is below 2 or the last dimension is
            inconsistent with in_features.
    """
    dims = tensor_type.__args__

    if len(dims) < 2:
        raise TypeError(f'Type {tensor_type} must have rank 2 or more.')

    if not is_consistent(module_instance.in_features, dims[-1]):
        raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')

    # TODO: backwards propagation
    leading_dims = tuple(dims[:-1])
    return TensorType(leading_dims + (module_instance.out_features,))
def test_resnet50(self):
    """
    Type-check a traced resnet50 and compare against shapes seen at runtime.

    A reference trace is run through ShapeProp to record each node's runtime
    shape. The checker is applied once with every annotation cleared
    (inferred types must be consistent with the runtime shapes) and once with
    the placeholder annotated with the runtime input shape (inferred types
    must equal the runtime shapes). Finally symbolic types are inferred on
    the unannotated graph and every node must share one batch dimension.
    """
    def runtime_type(node):
        # tensor type built from the shape ShapeProp stored on the node
        return TensorType(node.meta['tensor_meta'].shape)

    # reference graph: propagate a concrete input to record runtime shapes
    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(1, 3, 224, 224)
    ShapeProp(gm_run).propagate(sample_input)

    # first pass: no type information at all
    gm_static = symbolic_trace(resnet50())
    for node in gm_static.graph.nodes:
        node.type = None
    GraphTypeChecker({}, gm_static).type_check()

    # drop dead nodes from both graphs before walking them pairwise
    gm_static.graph.eliminate_dead_code()
    gm_run.graph.eliminate_dead_code()

    # inferred types only have to be consistent with what happened at runtime
    for static_node, run_node in zip(gm_static.graph.nodes, gm_run.graph.nodes):
        assert is_consistent(static_node.type, runtime_type(run_node))

    # second pass: annotate the placeholder with the same shape as the
    # runtime input
    gm_static_with_types = symbolic_trace(resnet50())
    for node in gm_static_with_types.graph.nodes:
        if node.op == 'placeholder':
            node.type = TensorType((1, 3, 224, 224))
    GraphTypeChecker({}, gm_static_with_types).type_check()

    # with a fully static input the inferred types must match exactly
    for typed_node, run_node in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
        assert typed_node.type == runtime_type(run_node)

    # apply shape inference to the graph and check that the batch size is
    # the same across all layers
    infer_symbolic_types(gm_static)
    gm_static.graph.eliminate_dead_code()

    batch_sizes = set()
    for node in gm_static.graph.nodes:
        assert isinstance(node.type, TensorType)
        batch_sizes.add(node.type.__args__[0])
    assert len(batch_sizes) == 1
def add_inference_rule(n: Node):
    """
    Infer the type of an addition node from its two operand nodes.

    An ``int`` operand added to a TensorType operand yields the tensor's
    type. Otherwise the operand types are broadcast against each other.
    ``n.meta['broadcast']`` records whether broadcasting changed either
    type, and if so the broadcast types are stored in ``n.meta`` under the
    operand names. The operand annotations are left untouched.

    The node receives the less precise of the two resulting types, which is
    returned. Inconsistent types raise a TypeError.
    """
    first = n.args[0]
    assert isinstance(first, Node)
    second = n.args[1]
    assert isinstance(second, Node)

    t1, t2 = first.type, second.type

    # handle scalar addition
    tensor_side = None
    if t1 == int and isinstance(t2, TensorType):
        tensor_side = t2
    elif t2 == int and isinstance(t1, TensorType):
        tensor_side = t1
    if tensor_side is not None:
        n.type = tensor_side
        return n.type

    (new_t1, new_t2) = broadcast_types(t1, t2)

    if new_t1 != t1 or new_t2 != t2:
        n.meta['broadcast'] = True
        n.meta[str(first)] = new_t1
        n.meta[str(second)] = new_t2
        # TODO: maybe figure out that broadcasting definitely did not happen?
    else:
        n.meta['broadcast'] = False
        # no broadcasting: fall back to the operands' own type objects
        new_t1 = t1
        new_t2 = t2

    if is_consistent(new_t1, new_t2):
        # the less precise type is chosen because broadcasting may have
        # happened: operands [1,2,Dyn] and [1,2,1] must give [1,2,Dyn]
        n.type = new_t2 if is_more_precise(new_t1, new_t2) else new_t1
        return n.type

    raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
                    f' Types should match ')
def get_greatest_upper_bound(type1, type2):
    """
    Get the most precise type that's consistent with the given types.

    - If either type is Dyn, the other type is returned.
    - If both are TensorTypes, the result takes, dimension by dimension, the
      first type's dimension when it is more precise than the second's and
      the second's otherwise.

    Raises:
        TypeError: if both types are TensorTypes but are not consistent with
            each other (for example because their ranks differ).
        NotImplementedError: for any other combination of types.
    """
    if type1 == Dyn:
        return type2
    if type2 == Dyn:
        return type1

    if isinstance(type1, TensorType) and isinstance(type2, TensorType):
        # An explicit raise rather than `assert`: assertions are stripped
        # under `python -O`, after which zip() below would silently truncate
        # tensors of different rank and return a bogus type.
        if not is_consistent(type1, type2):
            raise TypeError(f'Inconsistent types {type1}, {type2}')

        gub = [
            t1 if is_more_precise(t1, t2) else t2
            for (t1, t2) in zip(type1.__args__, type2.__args__)
        ]
        return TensorType(tuple(gub))

    raise NotImplementedError(
        f'Greatest upper bound not yet implemented for these types {type1}, {type2}'
    )
def test_type_check_conv2D_2_fully_static(self):
    """
    Type-check a single Conv2d layer over five parameter combinations.

    Each combination is checked twice. First only the placeholder is
    annotated, and the inferred output type must be consistent with the shape
    obtained by running the module. Then a fresh module is traced with only
    the call_module node annotated, and the inferred output type must equal
    the matching entry of ``output_types`` while staying consistent with the
    shape from the first run.
    """
    annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                       (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
    input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                  (10, 15, 13, 14), (1, 2, 2, 3)]
    intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5),
                          (10, 15, 7, 7), (1, Dyn, Dyn, Dyn)]
    in_planes_list = [2, 5, 15, 15, 2]
    stride_list = [1, 2, 3, 2, 2]
    out_planes_list = [2, 5, 15, 15, 2]
    groups_list = [1, 5, 5, 5, 2]
    dilation_list = [1, 2, 3, 3, 3]
    padding_list = [1, 2, 3, 3, 3]
    kernel_size_list = [1, 2, 3, 3, 3]
    output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5),
                    (10, 15, 7, 7), (1, 2, Dyn, Dyn)]

    class BasicBlock(torch.nn.Module):
        def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, groups=groups, bias=False,
                                         dilation=dilation)

        def forward(self, x):
            return self.conv1(x)

    def trace(module):
        # trace the module and wrap the resulting graph in a GraphModule
        tracer = RewritingTracer()
        graph = tracer.trace(module)
        return graph, GraphModule(tracer.root, graph, "gm")

    cases = zip(annotation_list, input_list, intermediate_types, output_types,
                in_planes_list, out_planes_list, kernel_size_list, stride_list,
                padding_list, groups_list, dilation_list)

    for (annotation, sample_shape, intermediate_type, expected_output,
         in_planes, out_planes, kernel_size, stride,
         padding, groups, dilation) in cases:

        # --- pass 1: annotate only the placeholder ---
        block = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
        graph, traced = trace(block)

        for node in graph.nodes:
            if node.op == 'placeholder':
                node.type = TensorType(annotation)

        # actual result, used as ground truth for both passes
        actual = block.forward(torch.rand(sample_shape))

        GraphTypeChecker({}, traced).type_check()

        for node in graph.nodes:
            if node.op == 'output':
                assert is_consistent(node.type, TensorType(actual.size()))

        # --- pass 2: annotate only the intermediate call_module node ---
        block = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
        graph, traced = trace(block)

        for node in traced.graph.nodes:
            if node.op == 'call_module':
                node.type = TensorType(intermediate_type)

        GraphTypeChecker({}, traced).type_check()

        for node in traced.graph.nodes:
            if node.op == 'output':
                assert node.type == TensorType(expected_output)
                assert is_consistent(node.type, TensorType(actual.size()))
def test_type_maxpool2d_fully_static(self):
    """
    Type-check a single MaxPool2d layer over five parameter combinations.

    Each combination is checked twice. First only the placeholder is
    annotated, and the inferred output type must be consistent with the shape
    obtained by running the module. Then a fresh module is traced with both
    the placeholder and the call_module node annotated, and the inferred
    output type must equal the matching entry of ``output_types`` while
    staying consistent with the shape from the first run.
    """
    annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                       (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)]
    input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                  (10, 15, 13, 14), (2, 2, 10, 10)]
    intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2),
                          (10, 15, 2, 3), (2, Dyn, Dyn, Dyn)]
    stride_list = [1, 2, 3, 2, 1]
    dilation_list = [1, 2, 3, 3, 2]
    padding_list = [1, 2, 3, 3, 1]
    kernel_size_list = [2, 4, 6, 6, 3]
    output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2),
                    (10, 15, 2, 3), (2, Dyn, Dyn, 8)]

    class BasicBlock(torch.nn.Module):
        def __init__(self, kernel_size, stride, padding, dilation):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride,
                                           padding=padding, dilation=dilation,
                                           return_indices=False, ceil_mode=False)

        def forward(self, x):
            return self.pool(x)

    def trace(module):
        # trace the module and wrap the resulting graph in a GraphModule
        tracer = RewritingTracer()
        graph = tracer.trace(module)
        return graph, GraphModule(tracer.root, graph, "gm")

    def annotate_placeholders(graph, annotation):
        for node in graph.nodes:
            if node.op == 'placeholder':
                node.type = TensorType(annotation)

    cases = zip(annotation_list, input_list, intermediate_types, output_types,
                kernel_size_list, stride_list, padding_list, dilation_list)

    for (annotation, sample_shape, intermediate_type, expected_output,
         kernel_size, stride, padding, dilation) in cases:

        # --- pass 1: annotate only the placeholder ---
        block = BasicBlock(kernel_size, stride, padding, dilation)
        graph, traced = trace(block)
        annotate_placeholders(graph, annotation)

        # actual result, used as ground truth for both passes
        actual = block.forward(torch.rand(sample_shape))

        GraphTypeChecker({}, traced).type_check()

        for node in graph.nodes:
            if node.op == 'output':
                assert is_consistent(node.type, TensorType(actual.size()))

        # --- pass 2: annotate the placeholder and the intermediate node ---
        block = BasicBlock(kernel_size, stride, padding, dilation)
        graph, traced = trace(block)
        annotate_placeholders(graph, annotation)

        for node in traced.graph.nodes:
            if node.op == 'call_module':
                node.type = TensorType(intermediate_type)

        GraphTypeChecker({}, traced).type_check()

        for node in traced.graph.nodes:
            if node.op == 'output':
                assert node.type == TensorType(expected_output)
                assert is_consistent(node.type, TensorType(actual.size()))