def test_type_check_conv2D_types(self):
    """After symbolic inference, conv2d's unknown spatial dims should be floor expressions.

    The conv output is annotated ``(2, 2, Dyn, 4)`` while the input is fully
    dynamic. Once the graph is type-checked and ``infer_symbolic_types`` has
    run, entries 2 and 3 of every ``call_module`` node's type must be
    ``sympy.floor`` instances.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes, stride=1):
            super(BasicBlock, self).__init__()
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = torch.nn.BatchNorm2d(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
            out += identity
            return out

    rewriter = RewritingTracer()
    rewritten_graph = rewriter.trace(BasicBlock(2, 2))
    gm = GraphModule(rewriter.root, rewritten_graph, "gm")

    GraphTypeChecker({}, gm).type_check()
    infer_symbolic_types(gm)

    module_nodes = (node for node in gm.graph.nodes if node.op == 'call_module')
    for node in module_nodes:
        # Height (index 2) and width (index 3) should now be symbolic floors.
        for dim_index in (2, 3):
            assert isinstance(node.type.__args__[dim_index], sympy.floor)
def test_type_check_conv2D(self):
    """Type-check a conv block whose input is dynamic and whose conv output is annotated.

    Only the annotated ``call_module`` node should receive ``(2, 2, Dyn, 4)``.
    The placeholder, the ``call_function`` add node and the output should all
    remain fully dynamic 4-D tensors.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes, stride=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d if norm_layer is None else norm_layer
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
            out += identity
            return out

    tracer = RewritingTracer()
    graph = tracer.trace(BasicBlock(2, 2))
    traced = GraphModule(tracer.root, graph, "gm")
    GraphTypeChecker({}, traced).type_check()

    fully_dynamic = TensorType((Dyn, Dyn, Dyn, Dyn))
    expected_by_op = {
        'placeholder': fully_dynamic,
        'call_function': fully_dynamic,
        'output': fully_dynamic,
        'call_module': TensorType((2, 2, Dyn, 4)),
    }
    for node in graph.nodes:
        if node.op in expected_by_op:
            assert node.type == expected_by_op[node.op]
def test_type_check_batch_norm_symbolic(self):
    """Symbolic inference should give every node around BatchNorm2d the same type.

    The input is fully dynamic and the batch-norm output is annotated
    ``(2, 2, Dyn, 4)``. After gradual type checking and
    ``infer_symbolic_types``, each of the graph's four nodes is expected to
    have type ``(2, 2, ~7, 4)``. Those nodes are the placeholder, ``bn1``,
    the add and the output.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
            out += identity
            return out

    B = BasicBlock(2, 2)
    ast_rewriter = RewritingTracer()
    graph = ast_rewriter.trace(B)
    traced = GraphModule(ast_rewriter.root, graph, "gm")
    tc = GraphTypeChecker({}, traced)
    tc.type_check()

    infer_symbolic_types(traced)

    expected_type = TensorType[(2, 2, sympy.symbols('~7'), 4)]
    expected_types = [expected_type] * 4
    # Compare whole lists so a node-count mismatch is reported as an
    # assertion failure. It must not surface as StopIteration (extra nodes)
    # or leave trailing expected types unchecked (missing nodes).
    assert [n.type for n in graph.nodes] == expected_types
def test_conv_reshape_add(self): class BasicBlock(torch.nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation) def forward(self, x: Dyn, y: Dyn): return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) B = BasicBlock(2, 2, 2, 3, 2, 2, 2) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") generator = ConstraintGenerator(traced) new_constraints, counter = generator.generate_constraints(0) assert len(new_constraints.conjucts) == 16
def test_flatten_fully_static(self): annotation_list = [ Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)), TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10)) ] input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (2, 2, 10, 10)] intermediate_list = [ Dyn, (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (2, 2, 10, 10) ] start_dim = [1, 2, 1, 2, 0] end_dim = [1, 3, 3, 3, -2] for i in range(5): annotation = annotation_list[i] input = input_list[i] # intermediate_type = intermediate_list[i] class BasicBlock(torch.nn.Module): def __init__(self, start, end): super(BasicBlock, self).__init__() self.start = start self.end = end def forward(self, x): out = torch.flatten(x, self.start, self.end) return out B = BasicBlock(start_dim[i], end_dim[i]) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = annotation b = B.forward(torch.rand(input)) tc = GraphTypeChecker({}, traced) tc.type_check() for n in graph.nodes: if n.op == 'output': assert is_consistent(n.type, TensorType(b.size()))
def test_add_reshape(self): class BasicBlock(torch.nn.Module): def __init__(self): super(BasicBlock, self).__init__() def forward(self, x: Dyn, y: Dyn): return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(BasicBlock()) traced = GraphModule(ast_rewriter.root, graph, "gm") generator = ConstraintGenerator(traced) new_constraints, counter = generator.generate_constraints(0) assert len(new_constraints.conjucts) == 11
def test_type_check_conv2D_maxpool2d_flatten(self): class BasicBlock(torch.nn.Module): def __init__(self): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(3, 6, 5) self.pool = torch.nn.MaxPool2d(2, 2) self.conv2 = torch.nn.Conv2d(6, 16, 5) self.fc1 = torch.nn.Linear(5, 120) self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) def forward(self, x: TensorType((4, 3, 32, 32))): out = self.conv1(x) out = self.pool(out) out = self.conv2(out) out = self.pool(out) out = self.fc1(out) out = self.pool2(out) out = torch.flatten(out, 1) return out B = BasicBlock() ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") tc = GraphTypeChecker({}, traced) tc.type_check() expected_ph_types = [ TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)), TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)), TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)), TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672)) ] expected_iter = iter(expected_ph_types) traced.graph.eliminate_dead_code() for n in traced.graph.nodes: assert n.type == next(expected_iter)
def test_subgraph_rewriter_annotations_int(self): class M1(torch.nn.Module): def forward(self, x): y: int = x return torch.add(x, y) class M2(torch.nn.Module): def forward(self, x): y = annotate(x, int) return torch.add(x, y) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(M1()) module = M2() symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): if n.op == 'placeholder': assert n.type == int assert m.type == int
def test_type_typechecl_maxpool2d_3dinput(self): class BasicBlock(torch.nn.Module): def __init__(self): super(BasicBlock, self).__init__() self.pool = torch.nn.MaxPool2d(5, 8) def forward(self, x: TensorType((64, 8, 8))): out = self.pool(x) return out B = BasicBlock() ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.target == 'output': assert n.type == TensorType((64, 1, 1))
def test_type_check_batch_norm_2D_false(self):
    """BatchNorm2d fed a 3-D input annotated ``(2, 2, 5)`` must fail type checking.

    The batch-norm result is additionally annotated as 4-D ``(2, 2, Dyn, 4)``.
    The type checker is expected to raise ``TypeError``.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d
            self.bn1 = norm_layer(planes)

        def forward(self, x: TensorType((2, 2, 5))):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
            out += identity
            return out

    rewriter = RewritingTracer()
    rewritten_graph = rewriter.trace(BasicBlock(2, 2))
    checker = GraphTypeChecker({}, GraphModule(rewriter.root, rewritten_graph, "gm"))

    with self.assertRaises(TypeError):
        checker.type_check()
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
    """Symbolic inference through a conv/pool/linear/flatten stack with dynamic H and W.

    The input is annotated ``(4, 3, Dyn, Dyn)``. After gradual type checking
    and ``infer_symbolic_types``, both convolution outputs are expected to
    have their last two dimensions expressed as ``floor(~k - 4)`` over the
    symbols listed below.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self):
            super(BasicBlock, self).__init__()

            self.conv1 = torch.nn.Conv2d(3, 6, 5)
            self.pool = torch.nn.MaxPool2d(2, 2)
            self.conv2 = torch.nn.Conv2d(6, 16, 5)
            self.fc1 = torch.nn.Linear(5, 120)
            self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

        def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
            out = self.conv1(x)
            out = self.pool(out)
            out = self.conv2(out)
            out = self.pool(out)
            out = self.fc1(out)
            out = self.pool2(out)
            out = torch.flatten(out, 1)
            return out

    B = BasicBlock()
    traced = symbolic_trace(B)

    tc = GraphTypeChecker({}, traced)
    tc.type_check()

    infer_symbolic_types(traced)

    expected_types = {
        'conv1': TensorType(
            (4, 6, sympy.floor(sympy.symbols('~0') - 4), sympy.floor(sympy.symbols('~1') - 4))),
        'conv2': TensorType(
            (4, 16, sympy.floor(sympy.symbols('~4') - 4), sympy.floor(sympy.symbols('~5') - 4))),
    }
    checked = set()
    for n in traced.graph.nodes:
        if n.target in expected_types:
            assert n.type == expected_types[n.target]
            checked.add(n.target)
    # Both conv nodes must actually be found. Otherwise the loop above would
    # assert nothing and the test would pass vacuously.
    assert checked == set(expected_types)
def test_type_check_conv2D_2(self):
    """Type-check a statically annotated conv block, then a mismatched one.

    With ``BasicBlock(2, 2)`` and input ``(5, 2, 3, 4)``, the following must
    carry that same static type:
    - the placeholder, the conv module node and the add node;
    - the output, whose dims must also equal the real forward output's shape.

    With ``BasicBlock(1, 2)`` the type checker must raise ``TypeError``. This
    is presumably because ``conv3x3(1, 2)`` expects one input channel while the
    annotation has two; confirm against ``conv3x3``.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes, stride=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d if norm_layer is None else norm_layer
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)

        def forward(self, x: TensorType((5, 2, 3, 4))):
            identity = x
            out = self.conv1(x)
            out += identity
            return out

    def build_checker(block):
        # Trace with the AST-rewriting tracer and wrap the graph for checking.
        tracer = RewritingTracer()
        graph = tracer.trace(block)
        return graph, GraphTypeChecker({}, GraphModule(tracer.root, graph, "gm"))

    block = BasicBlock(2, 2)
    runtime_out = block.forward(torch.rand(5, 2, 3, 4))

    graph, checker = build_checker(block)
    checker.type_check()

    static_type = TensorType((5, 2, 3, 4))
    for node in graph.nodes:
        if node.op == 'output':
            assert torch.Size(node.type.__args__) == runtime_out.shape
        elif node.op in ('placeholder', 'call_function', 'call_module'):
            assert node.type == static_type

    _, bad_checker = build_checker(BasicBlock(1, 2))
    with self.assertRaises(TypeError):
        bad_checker.type_check()
def test_type_check_batch_norm_2D_broadcast(self):
    """Type-check BatchNorm2d with a dynamic input, then with a different plane count.

    With ``BasicBlock(2, 2)``, the annotated ``call_module`` node keeps
    ``(2, 2, Dyn, 4)`` and every other checked node stays fully dynamic.
    With ``BasicBlock(1, 1)``, type checking must raise ``TypeError``. This
    is presumably because a one-channel BatchNorm2d conflicts with the
    two-channel annotation.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
            out += identity
            return out

    def build_checker(block):
        # Trace with the AST-rewriting tracer and wrap the graph for checking.
        tracer = RewritingTracer()
        graph = tracer.trace(block)
        return graph, GraphTypeChecker({}, GraphModule(tracer.root, graph, "gm"))

    graph, checker = build_checker(BasicBlock(2, 2))
    checker.type_check()

    fully_dynamic = TensorType((Dyn, Dyn, Dyn, Dyn))
    expected_by_op = {
        'placeholder': fully_dynamic,
        'call_function': fully_dynamic,
        'output': fully_dynamic,
        'call_module': TensorType((2, 2, Dyn, 4)),
    }
    for node in graph.nodes:
        if node.op in expected_by_op:
            assert node.type == expected_by_op[node.op]

    _, bad_checker = build_checker(BasicBlock(1, 1))
    with self.assertRaises(TypeError):
        bad_checker.type_check()
def test_type_check_conv2D_2_fully_static(self): annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)] input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (1, 2, 2, 3)] intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, Dyn, Dyn, Dyn)] in_planes_list = [2, 5, 15, 15, 2] stride_list = [1, 2, 3, 2, 2] out_planes_list = [2, 5, 15, 15, 2] groups_list = [1, 5, 5, 5, 2] dilation_list = [1, 2, 3, 3, 3] padding_list = [1, 2, 3, 3, 3] kernel_size_list = [1, 2, 3, 3, 3] output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)] for i in range(5): annotation = annotation_list[i] input = input_list[i] in_planes = in_planes_list[i] stride = stride_list[i] out_planes = out_planes_list[i] groups = groups_list[i] dilation = dilation_list[i] padding = padding_list[i] kernel_size = kernel_size_list[i] intermediate_type = intermediate_types[i] class BasicBlock(torch.nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation) def forward(self, x): out = self.conv1(x) return out B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) b = B.forward(torch.rand(input)) tc = GraphTypeChecker({}, traced) tc.type_check() for n in graph.nodes: if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) # test with intermediate annotations class BasicBlock(torch.nn.Module): def __init__(self, in_planes, out_planes, kernel_size, 
stride, padding, groups, dilation): super(BasicBlock, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False, dilation=dilation) def forward(self, x): out = self.conv1(x) return out B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # populate our intermediate notes for n in traced.graph.nodes: if n.op == 'call_module': n.type = TensorType(intermediate_type) tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.op == 'output': assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size()))
def test_typecheck_basicblock(self):
    """Type-check a two-conv residual block and compare with a real forward pass.

    The input is annotated ``(2, 2, 4, 5)``. After type checking, the output
    node must carry a ``TensorType`` whose dimensions equal the size of the
    tensor produced by running the block on a random input of that shape.
    """
    class BasicBlock(torch.nn.Module):
        expansion = 1

        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d if norm_layer is None else norm_layer
            if not (groups == 1 and base_width == 64):
                raise ValueError(
                    'BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError(
                    "Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            self.relu = torch.nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x: TensorType((2, 2, 4, 5))):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)

            return out

    block = BasicBlock(2, 2)
    tracer = RewritingTracer()
    graph = tracer.trace(block)
    gm = GraphModule(tracer.root, graph, "gm")
    GraphTypeChecker({}, gm).type_check()

    for node in gm.graph.nodes:
        if node.target != 'output':
            continue
        assert isinstance(node.type, TensorType)
        inferred_size = torch.Size(node.type.__args__)
        runtime_size = block.forward(torch.rand(2, 2, 4, 5)).size()
        assert inferred_size == runtime_size
def symbolic_trace_with_rewrite(
        root: Union[torch.nn.Module, Callable]) -> GraphModule:
    """Trace ``root`` with ``RewritingTracer`` and wrap the graph in a ``GraphModule``.

    Args:
        root: The ``torch.nn.Module`` or plain callable to trace.

    Returns:
        A ``GraphModule`` built from the traced graph. It is owned by ``root``
        when ``root`` is an ``nn.Module``, and by a fresh empty
        ``torch.nn.Module`` otherwise.
    """
    owner = root if isinstance(root, torch.nn.Module) else torch.nn.Module()
    traced_graph = RewritingTracer().trace(root)
    return GraphModule(owner, traced_graph)
def test_type_maxpool2d_fully_static(self): annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)] input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (2, 2, 10, 10)] intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2), (10, 15, 2, 3), (2, Dyn, Dyn, Dyn)] stride_list = [1, 2, 3, 2, 1] dilation_list = [1, 2, 3, 3, 2] padding_list = [1, 2, 3, 3, 1] kernel_size_list = [2, 4, 6, 6, 3] output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)] for i in range(5): annotation = annotation_list[i] input = input_list[i] stride = stride_list[i] dilation = dilation_list[i] padding = padding_list[i] kernel_size = kernel_size_list[i] intermediate_type = intermediate_types[i] class BasicBlock(torch.nn.Module): def __init__(self, kernel_size, stride, padding, dilation): super(BasicBlock, self).__init__() self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=False, ceil_mode=False) def forward(self, x): out = self.pool(x) return out B = BasicBlock(kernel_size, stride, padding, dilation) ast_rewriter = RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) b = B.forward(torch.rand(input)) tc = GraphTypeChecker({}, traced) tc.type_check() for n in graph.nodes: if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) # test with intermediate annotations class BasicBlock(torch.nn.Module): def __init__(self, kernel_size, stride, padding, dilation): super(BasicBlock, self).__init__() self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=False, ceil_mode=False) def forward(self, x): out = self.pool(x) return out B = BasicBlock(kernel_size, stride, padding, dilation) ast_rewriter = 
RewritingTracer() graph = ast_rewriter.trace(B) traced = GraphModule(ast_rewriter.root, graph, "gm") # annotate our argument for n in graph.nodes: if n.op == 'placeholder': n.type = TensorType(annotation) # populate our intermediate notes for n in traced.graph.nodes: if n.op == 'call_module': n.type = TensorType(intermediate_type) tc = GraphTypeChecker({}, traced) tc.type_check() for n in traced.graph.nodes: if n.op == 'output': assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size()))