def test_symbolic_add_with_broadcast(self):
    """Broadcasting add of a 4-D tensor whose last dim is Dyn with a 3-D tensor.

    Type-checks the traced graph, runs symbolic inference and refinement,
    checks the refinement constraints, and then checks the final type of
    every node in the graph.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
            return torch.add(x, y)

    module = M()
    symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
    tc = GraphTypeChecker({}, symbolic_traced)
    tc.type_check()
    infer_symbolic_types(symbolic_traced)
    r = Refine(symbolic_traced)
    r.refine()

    assert r.constraints == [Equality(1, 1), Equality(2, 2), Equality(3, 3)]
    # note that there is no equality constraint between dyn and 4 because
    # dyn could be 4 or 1

    # Symbolic inference is run again after refinement; the symbol names in
    # the expected types below reflect the state after this second call.
    infer_symbolic_types(symbolic_traced)

    # One expected type per graph node, in graph order:
    # placeholder x, placeholder y, the add call, and the output node.
    expected_types = [
        TensorType((1, 2, 3, sympy.symbols('~0'))),
        TensorType((2, 3, 4)),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
    ]
    nodes = list(symbolic_traced.graph.nodes)
    # Compare counts first so a missing node cannot leave expected types
    # unchecked, and an extra node fails clearly rather than via StopIteration.
    assert len(nodes) == len(expected_types)
    for n, expected in zip(nodes, expected_types):
        assert n.type == expected
def test_type_check_batch_norm_symbolic(self):
    """Symbolic inference through BatchNorm2d followed by an add with the input.

    The annotation on ``out`` fixes every dimension except the third, and the
    test expects every node to end up with the same type whose third
    dimension is the symbol ``~7``.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
            out += identity
            return out

    B = BasicBlock(2, 2)
    ast_rewriter = RewritingTracer()
    graph = ast_rewriter.trace(B)
    traced = GraphModule(ast_rewriter.root, graph, "gm")
    tc = GraphTypeChecker({}, traced)
    tc.type_check()

    infer_symbolic_types(traced)

    # Every node shares the same inferred type.
    expected = TensorType((2, 2, sympy.symbols('~7'), 4))

    # Read nodes from the module the checker actually ran on, and require the
    # exact node count (x, bn1, the add, output) so no node goes unchecked.
    nodes = list(traced.graph.nodes)
    assert len(nodes) == 4
    for n in nodes:
        assert n.type == expected
def test_type_check_conv2D_types(self):
    """Symbolic inference yields floor expressions for conv output dims.

    After ``infer_symbolic_types``, dimensions 2 and 3 of every
    ``call_module`` node's type must be ``sympy.floor`` expressions.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes, stride=1):
            super(BasicBlock, self).__init__()
            norm_layer = torch.nn.BatchNorm2d
            self.conv1 = conv3x3(inplanes, planes, stride)
            # bn1 is registered but never used in forward, so it is not traced.
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
            out += identity
            return out

    B = BasicBlock(2, 2)
    ast_rewriter = RewritingTracer()
    graph = ast_rewriter.trace(B)
    traced = GraphModule(ast_rewriter.root, graph, "gm")
    tc = GraphTypeChecker({}, traced)
    tc.type_check()
    infer_symbolic_types(traced)

    module_nodes = [n for n in traced.graph.nodes if n.op == 'call_module']
    # Guard against a vacuous pass if no call_module node was traced.
    assert module_nodes
    for n in module_nodes:
        assert isinstance(n.type.__args__[2], sympy.floor)
        assert isinstance(n.type.__args__[3], sympy.floor)
def test_resnet50(self):
    """Static type checking of ResNet-50 agrees with runtime shape propagation.

    Three checks:
    1. With untyped nodes, every statically inferred type is consistent with
       the shape observed at runtime.
    2. With the placeholder typed as the runtime input shape, every inferred
       type equals the runtime shape.
    3. After symbolic inference on the untyped graph, all nodes share a
       single leading (batch) dimension.
    """
    # Shared by the runtime input and the static placeholder type so the two
    # cannot drift apart.
    input_shape = (1, 3, 224, 224)

    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(*input_shape)

    # run our nodes
    ShapeProp(gm_run).propagate(sample_input)

    gm_static = symbolic_trace(resnet50())

    for n in gm_static.graph.nodes:
        n.type = None

    g = GraphTypeChecker({}, gm_static)
    g.type_check()

    gm_static.graph.eliminate_dead_code()
    gm_run.graph.eliminate_dead_code()
    run_nodes = list(gm_run.graph.nodes)

    # here we are checking for consistency with fully dynamic nodes
    static_nodes = list(gm_static.graph.nodes)
    # zip would silently truncate on a length mismatch; check it explicitly.
    assert len(static_nodes) == len(run_nodes)
    for n1, n2 in zip(static_nodes, run_nodes):
        assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))

    # here we give the same input as to runtime
    gm_static_with_types = symbolic_trace(resnet50())

    # we initialize our placeholder
    for n in gm_static_with_types.graph.nodes:
        if n.op == 'placeholder':
            n.type = TensorType(input_shape)

    g = GraphTypeChecker({}, gm_static_with_types)
    g.type_check()

    # gm_run had dead code eliminated above; do the same here so the
    # node-by-node pairing below lines up.
    gm_static_with_types.graph.eliminate_dead_code()
    typed_nodes = list(gm_static_with_types.graph.nodes)
    assert len(typed_nodes) == len(run_nodes)
    for n1, n2 in zip(typed_nodes, run_nodes):
        assert n1.type == TensorType(n2.meta['tensor_meta'].shape)

    # apply shape inference to graph and check
    # that the batch size is equal across all layers
    infer_symbolic_types(gm_static)

    batch_sizes = set()
    gm_static.graph.eliminate_dead_code()
    for n in gm_static.graph.nodes:
        assert isinstance(n.type, TensorType)
        batch_sizes.add(n.type.__args__[0])
    assert len(batch_sizes) == 1
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
    """Symbolic inference through conv2d, maxpool2d, linear, adaptive pool and flatten.

    Checks the inferred types of the two conv layers, whose dimensions 2 and 3
    are floor expressions over fresh symbols.
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self):
            super(BasicBlock, self).__init__()
            self.conv1 = torch.nn.Conv2d(3, 6, 5)
            self.pool = torch.nn.MaxPool2d(2, 2)
            self.conv2 = torch.nn.Conv2d(6, 16, 5)
            self.fc1 = torch.nn.Linear(5, 120)
            self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

        def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
            out = self.conv1(x)
            out = self.pool(out)
            out = self.conv2(out)
            out = self.pool(out)
            out = self.fc1(out)
            out = self.pool2(out)
            out = torch.flatten(out, 1)
            return out

    B = BasicBlock()
    traced = symbolic_trace(B)
    tc = GraphTypeChecker({}, traced)
    tc.type_check()
    infer_symbolic_types(traced)

    # Expected type per traced module target.
    expected_types = {
        'conv1': TensorType((4, 6,
                             sympy.floor(sympy.symbols('~0') - 4),
                             sympy.floor(sympy.symbols('~1') - 4))),
        'conv2': TensorType((4, 16,
                             sympy.floor(sympy.symbols('~4') - 4),
                             sympy.floor(sympy.symbols('~5') - 4))),
    }
    seen = set()
    for n in traced.graph.nodes:
        if n.target in expected_types:
            assert n.type == expected_types[n.target]
            seen.add(n.target)
    # Both conv layers must actually be present; otherwise the loop passes vacuously.
    assert seen == set(expected_types)
def test_symbolic_add_with_broadcast_2(self):
    """Broadcasting add of a (1, 2) tensor with a (Dyn, 2) tensor.

    After inference and refinement, y, the add result and the output node are
    expected to share the symbolic leading dimension ``~1``.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
            return torch.add(x, y)

    module = M()
    symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
    tc = GraphTypeChecker({}, symbolic_traced)
    tc.type_check()
    infer_symbolic_types(symbolic_traced)
    r = Refine(symbolic_traced)
    r.refine()

    # One expected type per graph node, in graph order:
    # placeholder x, placeholder y, the add call, and the output node.
    expected_types = [
        TensorType((1, 2)),
        TensorType((sympy.symbols('~1'), 2)),
        TensorType((sympy.symbols('~1'), 2)),
        TensorType((sympy.symbols('~1'), 2)),
    ]
    nodes = list(symbolic_traced.graph.nodes)
    # Compare counts first so a missing node cannot leave expected types
    # unchecked, and an extra node fails clearly rather than via StopIteration.
    assert len(nodes) == len(expected_types)
    for n, expected in zip(nodes, expected_types):
        assert n.type == expected