def setUp(self):
  super().setUp()

  # Note: despite the name, this graph's single op is a softmax (the dense
  # kwargs are commented out below).
  self.graph_dense = new_graph(
      ["input"], ["output"], [
          new_op(
              op_name="output",
              op_type=OpType.SOFTMAX,
              # op_kwargs={"features": 10},
              input_names=["input"])
      ])
  state_dense = Model(self.graph_dense).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_dense = SubgraphModel(self.graph_dense, None, state_dense,
                                      {"input": jnp.ones((5, 5, 5))})
  self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

  self.graph_conv = new_graph(["input"], ["output"], [
      new_op(
          op_name="output",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 10,
              "kernel_size": [3, 3]
          },
          input_names=["input"])
  ])
  state_conv = Model(self.graph_conv).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                     {"input": jnp.ones((5, 5, 5))})
  self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)

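# self.lp_dense and self.lp_conv now hold the inferred linear-operator
# ("linop") property of each single-op graph. As test_abstract below shows,
# the inferred mixing pattern can be read out via, e.g.,
#
#   self.lp_conv.pairings["output"]["input"].mappings
#
# which, roughly speaking, records which input coordinates may affect which
# output coordinates (a 3x3 conv mixes a local spatial neighborhood; softmax
# mixes along the axis it normalizes).
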
def test_abstract(self):
  graphs = []
  conv_op = functools.partial(
      new_op,
      op_type=OpType.CONV,
      op_kwargs={
          "features": 10,
          "kernel_size": [3, 3]
      })
  dense_op = functools.partial(
      new_op, op_type=OpType.DENSE, op_kwargs={
          "features": 10,
      })
  for op_type in [
      OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
  ]:
    for op_ctr in [conv_op, dense_op]:
      graphs.append([
          op_ctr(input_names=["input"], op_name="other"),
          new_op(op_name="output", op_type=op_type, input_names=["other"])
      ])
      graphs.append([
          new_op(op_name="other", op_type=op_type, input_names=["input"]),
          op_ctr(input_names=["other"], op_name="output"),
      ])

  input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
  for graph in graphs:
    graph = new_graph(["input"], ["output"], graph)
    state = Model(graph).init(random.PRNGKey(1), input_tensor)
    # Make all the kernels positive; otherwise the ReLU might zero out the
    # entire tensor.
    state = jax.tree_util.tree_map(abs, state)

    subg_model = SubgraphModel(graph, None, state, input_tensor)
    lp_abstract = linear.LinopProperty().infer(subg_model, abstract=True)
    lp_concrete = linear.LinopProperty().infer(subg_model, abstract=False)
    pairings_concrete = lp_concrete.pairings["output"]["input"].mappings
    pairings_abstract = lp_abstract.pairings["output"]["input"].mappings

    print("concrete:", pairings_concrete)
    print("abstract:", pairings_abstract)
    # Abstract and concrete inference must agree exactly (see the sketch
    # after this test).
    self.assertTrue(((pairings_abstract - pairings_concrete) == 0).all())

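# The test above asserts that abstract inference (abstract=True) agrees with
# inference on the instantiated model (abstract=False). For intuition only,
# here is a minimal sketch of the concrete idea, using a hypothetical helper
# that is not part of the library: which inputs reach which outputs can be
# read off a Jacobian.
#
#   def mixing_mask(fn, x):
#     jac = jax.jacobian(fn)(x)  # jac[..., j] = d out[...] / d in[j]
#     return (jac != 0).astype(jnp.int32)
#
#   mixing_mask(jnp.tanh, jnp.ones(4))    # diagonal: elementwise op
#   mixing_mask(jnp.cumsum, jnp.ones(4))  # lower-triangular: causal mixing
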
def test_synthesizer_two(self):
  """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

def test_synthesizer_easy_two(self):
  """Replacing [conv3x3(features = 64)]."""
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
  subg[-1].output_names = self.graph.ops[5].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp, lp])

def test_synthesizer_hard(self):
  if not self.hard:
    return
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp, lp])

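# How the three properties are used in these synthesis tests (an informal
# summary, not a formal definition): ShapeProperty requires the synthesized
# replacement to produce outputs of the same shape, DepthProperty additionally
# constrains the depth of the replacement (compare the test_synthesizer_two
# variants for what changes when it is omitted), and LinopProperty requires
# the same input-to-output mixing pattern.
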
def test_synthesizer_two(self):
  """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

  Because we do not check for the depth property, [dense(features = 64),
  avgpool2x2(strides=2x2)] works as well (which is what is synthesized due to
  the enumeration order).
  """
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

def test_synthesizer_easy_two(self):
  """Replacing [conv3x3(features = 64)].

  Because we test all three properties, this is replaced by
  conv3x3(features = 64) (i.e., an identical op) due to the enumeration
  order.
  """
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
  subg[-1].output_names = self.graph.ops[5].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp, lp])

def test_synthesizer_resnet_big(self):
  self.graph, self.constants, _ = resnetv1.ResNet18(
      num_classes=10, input_resolution="small")
  self.m = Model(self.graph, self.constants)
  self.input = {"input": jnp.ones((5, 32, 32, 3))}
  self.state = self.m.init(random.PRNGKey(0), self.input)
  self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
  self.max_size = int(10e8)

  subg_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
  subg = [subgraph.SubgraphNode(op=o) for o in subg_ops]
  subg[-1].output_names = [f"{subg[-1].op.name}:0"]
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

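# Tensor naming convention, as also seen in test_multi_input below:
# "<op_name>:<index>", so f"{subg[-1].op.name}:0" denotes the first output
# tensor of the subgraph's final op. Unlike the other tests, which rename the
# subgraph's outputs to a downstream op's input names, this keeps the
# subgraph boundary at that op's own output tensor.
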
def test_multi_input(self):
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu0", op_type=OpType.RELU, input_names=["dense0"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu1", op_type=OpType.RELU, input_names=["dense1"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu2", op_type=OpType.RELU, input_names=["dense2"]),
      new_op(
          op_name="add0", op_type=OpType.ADD, input_names=["relu0", "relu1"]),
      new_op(
          op_name="add1", op_type=OpType.ADD, input_names=["relu1", "relu2"]),
  ]
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
      SubgraphNode(
          op=new_op(
              op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
      SubgraphNode(
          op=new_op(
              op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
      SubgraphNode(
          op=new_op(
              op_name="add0",
              op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]),
      SubgraphNode(
          op=new_op(
              op_name="add1",
              op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]),
  ]
  replaced_graph = replace_subgraph(graph, subgraph_spec)

  inp = {"input": jnp.ones((10, 32, 32, 3))}
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp, subgraph_spec)
  lp = linear.LinopProperty().infer(subgraph_model)
  pairings = lp.pairings

  # Each output is paired with exactly the two dense branches that reach it
  # (see the dataflow summary after this test).
  self.assertLen(pairings, 2)
  self.assertIn("add0:0", pairings)
  self.assertLen(pairings["add0:0"], 2)
  self.assertIn("dense0:0", pairings["add0:0"])
  self.assertIn("dense1:0", pairings["add0:0"])
  self.assertIn("add1:0", pairings)
  self.assertLen(pairings["add1:0"], 2)
  self.assertIn("dense1:0", pairings["add1:0"])
  self.assertIn("dense2:0", pairings["add1:0"])

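# Dataflow in test_multi_input above:
#
#   add0 = relu0(dense0(input)) + relu1(dense1(input))
#   add1 = relu1(dense1(input)) + relu2(dense2(input))
#
# The asserted pairings follow directly from reachability: add0 can only be
# influenced by dense0 and dense1, and add1 only by dense1 and dense2.
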
def test_full(self):
  graph, constants, _ = cnn.CifarNet()
  subgraph_model = SubgraphModel(graph, constants, None,
                                 {"input": jnp.ones((10, 32, 32, 3))})
  linear.LinopProperty().infer(subgraph_model)