def setUp(self):
    """Builds a CifarNet fixture plus a graph with a conv/gelu subgraph swapped in."""
    super().setUp()
    # Baseline CifarNet model and its initialized state.
    self.graph, self.constants, _ = cnn.CifarNet()
    self.model = Model(self.graph, self.constants)
    self.state = self.model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    # A two-node sequential subgraph (1x1 conv followed by GELU) that taps
    # conv_layer0/avg_pool and feeds conv_layer1/relu.
    conv_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={
                "features": 64,
                "kernel_size": [1, 1]
            },
            input_names=["conv_layer0/avg_pool"]))
    gelu_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv/1"]),
        output_names=["conv_layer1/relu"])
    self.subgraph = [conv_node, gelu_node]
    # Graph with the substitution applied, plus its model and state.
    self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
    self.new_model = Model(self.new_graph, self.constants)
    self.new_state = self.new_model.init(
        random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)]."""
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:5]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[5].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    shape_prop = shape.ShapeProperty().infer(sub_model, max_size=self.max_size)
    depth_prop = depth.DepthProperty().infer(sub_model)
    self._synthesize(sub_model, [shape_prop, depth_prop])
def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:7]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[7].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    shape_prop = shape.ShapeProperty().infer(sub_model, max_size=self.max_size)
    linop_prop = linear.LinopProperty().infer(sub_model)
    self._synthesize(sub_model, [shape_prop, linop_prop])
def test_synthesizer_hard(self):
    """Synthesizes under all three properties (shape, depth, linear)."""
    # Skipped unless the subclass opts in to the expensive configuration.
    if not self.hard:
        return
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:7]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[7].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    properties = [
        shape.ShapeProperty().infer(sub_model, max_size=self.max_size),
        depth.DepthProperty().infer(sub_model),
        linear.LinopProperty().infer(sub_model),
    ]
    self._synthesize(sub_model, properties)
def test_multi_input_output(self):
    """Tests a subgraph substitution on a graph with multiple inputs / output ops.

    We use a ResNet model, which has skip connections. This test checks that
    the substitution produces the expected number of ops, and also that the
    newly produced graph is still executable.
    """
    graph, constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    model = Model(graph, constants)
    state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = model.apply(state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

    # Replacement: 1x1 conv followed by GELU, spliced between a skip-branch
    # relu and a basic-block relu.
    conv_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="subgraph/conv0",
            op_type=OpType.CONV,
            op_kwargs={
                "features": 64,
                "kernel_size": [1, 1]
            },
            input_names=["resnet11/skip/relu1"]))
    gelu_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="subgraph/gelu1",
            op_type=OpType.GELU,
            input_names=["subgraph/conv0"]),
        output_names=["resnet_stride1_filtermul1_basic12/relu2"])
    new_graph = subgraph.replace_subgraph(graph, [conv_node, gelu_node])

    # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
    self.assertLen(graph.ops, len(new_graph.ops) + 1)

    # The substituted graph must still initialize and execute end to end.
    new_model = Model(new_graph, constants)
    new_state = new_model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = new_model.apply(new_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))
def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)].

    Because we do not test linear, this is replaced by
    dense3x3(features = 64) due to the enumeration order.
    """
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:5]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[5].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    properties = [
        shape.ShapeProperty().infer(sub_model, max_size=self.max_size),
        depth.DepthProperty().infer(sub_model),
    ]
    self._synthesize(sub_model, properties)
def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

    Because we do not check for the depth property,
    [dense(features = 64), avgpool2x2(strides=2x2)] works as well (which is
    what is synthesized due to the enumeration order).
    """
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:7]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[7].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    properties = [
        shape.ShapeProperty().infer(sub_model, max_size=self.max_size),
        linear.LinopProperty().infer(sub_model),
    ]
    self._synthesize(sub_model, properties)
def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)].

    Because we test all three props, this is replaced by
    conv3x3(features = 64) (i.e., an identical op) due to the enumeration
    order.
    """
    nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:5]]
    # Route the subgraph's output to wherever the next original op reads from.
    nodes[-1].output_names = self.graph.ops[5].input_names
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    properties = [
        shape.ShapeProperty().infer(sub_model, max_size=self.max_size),
        depth.DepthProperty().infer(sub_model),
        linear.LinopProperty().infer(sub_model),
    ]
    self._synthesize(sub_model, properties)
def test_synthesizer_resnet_big(self):
    """Synthesizes a multi-op (non-contiguous slice) subgraph of ResNet18."""
    # Rebuild the fixture around a small-resolution ResNet18.
    self.graph, self.constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
    self.max_size = int(10e8)

    # Take ops 3:5 and 8:12 as the region to replace.
    chosen_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
    nodes = [subgraph.SubgraphNode(op=op) for op in chosen_ops]
    # Expose the final op's own tensor as the subgraph output.
    nodes[-1].output_names = [f"{nodes[-1].op.name}:0"]
    sub_model = SubgraphModel(
        self.graph, self.constants, self.state, self.input, nodes)
    properties = [
        shape.ShapeProperty().infer(sub_model, max_size=self.max_size),
        linear.LinopProperty().infer(sub_model),
    ]
    self._synthesize(sub_model, properties)
def make_subgraph_spec(self, subg, adjust_features=False):
    """Converts a list of ops into a (sequential) subgraph spec.

    The ops are chained linearly: the first op reads from self.input_name,
    each subsequent op reads the previous op's ":0" output, and the last
    node is routed to self.output_name.

    Args:
      subg: Sequence of ops to wrap into SubgraphNodes (not mutated).
      adjust_features: If True, rewrite the "features" kwarg of the last
        DENSE / CONV node so the spec reproduces the output feature count of
        the original ops (as reported by self.get_output_features).

    Returns:
      A list of subgraph.SubgraphNode forming a sequential subgraph spec.
    """
    input_name = self.input_name
    subgraph_spec = []

    # Op names must be unique within a graph: keep the original names only
    # if every op already ends in a distinct numeric "/<n>" suffix;
    # otherwise append "/<idx>" to unique-ify them below.
    op_names_unique = True
    idxs = []
    for op in subg:
        splits = op.name.split("/")
        idx = splits[-1]
        if len(splits) == 1 or idx in idxs or not idx.isdigit():
            op_names_unique = False
            break
        idxs.append(idx)

    for idx, op in enumerate(subg):
        op = copy.deepcopy(op)  # never mutate the caller's ops
        if not op_names_unique:
            op.name = f"{op.name}/{idx}"
        op.input_names = [input_name]
        input_name = op.name + ":0"
        subgraph_spec.append(subgraph.SubgraphNode(op))
    subgraph_spec[-1].output_names = [self.output_name]

    if adjust_features:
        output_features = self.get_output_features(subg)
        if output_features:
            # Patch the last parameterized layer so the spec produces the
            # same number of features as the ops it replaces.
            for node in subgraph_spec[::-1]:
                if node.op.type in (OpType.DENSE, OpType.CONV):
                    node.op.op_kwargs["features"] = output_features
                    break
    return subgraph_spec