def setUp(self):
  super().setUp()
  self.graph_dense = new_graph(
      ["input"], ["output"], [
          new_op(
              op_name="output",
              op_type=OpType.SOFTMAX,
              # op_kwargs={"features": 10},
              input_names=["input"])
      ])
  state_dense = Model(self.graph_dense).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_dense = SubgraphModel(self.graph_dense, None, state_dense,
                                      {"input": jnp.ones((5, 5, 5))})
  self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

  self.graph_conv = new_graph(["input"], ["output"], [
      new_op(
          op_name="output",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 10,
              "kernel_size": [3, 3]
          },
          input_names=["input"])
  ])
  state_conv = Model(self.graph_conv).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                     {"input": jnp.ones((5, 5, 5))})
  self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)

def test_abstract_sequential_synthesizer_output_features(self):
  graph, constants, _ = cnn.CifarNet()
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv",
              op_type=OpType.CONV,
              op_kwargs={
                  # Symbolic spec: twice the feature count of the input
                  # (dimension -1), as the output_features_mul assertion
                  # below checks.
                  "features": "S:-1*2",
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/relu",
              op_type=OpType.RELU,
              input_names=["conv_layer1/conv"]),
          output_names=["conv_layer1/relu"])
  ]
  subgraph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(subgraph, constants, None,
                                 {"input": jnp.zeros((5, 32, 32, 10))},
                                 subgraph_spec)
  sp = shape.ShapeProperty().infer(subgraph_model)
  syn = TestSequentialSynthesizer([(subgraph_model, [sp])], 0)
  self.assertEqual(syn.output_features_mul, 2)
  self.assertEqual(syn.output_features_div, 1)

def test_abstract_sequential_synthesizer_fail(self):
  graph, constants, _ = cnn.CifarNet()
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),
          output_names=["conv_layer1/conv"]),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv"]),
          output_names=["conv_layer1/relu"])
  ]
  subgraph = SubgraphModel(graph, constants, None,
                           {"input": jnp.zeros((5, 32, 32, 10))},
                           subgraph_spec)
  self.assertRaisesRegex(ValueError, ".*exactly one input.*",
                         TestSequentialSynthesizer, [(subgraph, [])], 0)

def test_rewire(self):
  # orig: conv, relu, pool, conv, relu, pool, flatten, dense, relu, dense
  # new:  conv, relu, pool, conv, gelu, pool, flatten, dense, relu, dense
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv/1"]),
          output_names=["conv_layer1/relu"])
  ]
  graph = replace_subgraph(self.graph, subgraph_spec)
  subgraph_model = SubgraphModel(graph, self.constants, {}, {}, subgraph_spec)
  dp = depth.DepthProperty().infer(subgraph_model)
  depth_map = dp.depth_map

  self.assertLen(depth_map, 1)
  self.assertIn("conv_layer0/avg_pool:0", depth_map)
  self.assertLen(depth_map["conv_layer0/avg_pool:0"], 2)
  self.assertIn("conv_layer1/relu:0", depth_map["conv_layer0/avg_pool:0"])
  self.assertEqual(
      depth_map["conv_layer0/avg_pool:0"]["conv_layer1/relu:0"], 1)
  self.assertIn("conv_layer1/gelu/1:0", depth_map["conv_layer0/avg_pool:0"])
  self.assertEqual(
      depth_map["conv_layer0/avg_pool:0"]["conv_layer1/gelu/1:0"], 1)

def test_rewire(self):
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv/1"]),
          output_names=["conv_layer1/relu"])
  ]
  graph = replace_subgraph(self.graph, subgraph_spec)
  state = Model(graph, self.constants).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 32, 32, 3))})
  subgraph_model = SubgraphModel(graph, self.constants, state,
                                 {"input": jnp.ones((5, 32, 32, 3))},
                                 subgraph_spec)
  sp = shape.ShapeProperty().infer(subgraph_model)

  self.assertLen(sp.input_shapes, 1)
  self.assertIn("conv_layer0/avg_pool:0", sp.input_shapes)
  self.assertLen(sp.output_shapes, 2)
  self.assertIn("conv_layer1/gelu/1:0", sp.output_shapes)
  self.assertIn("conv_layer1/relu:0", sp.output_shapes)

def test_unsatisfy(self):
  # This test removes the last dense layer, so the new graph should be
  # shallower.
  graph = new_graph(
      input_names=["input"], output_names=["fc/relu"], ops=self.graph.ops)
  subgraph_model = SubgraphModel(graph, self.constants, {}, {})
  self.assertFalse(self.dp.verify(subgraph_model))

def test_abstract(self):
  graphs = []
  conv_op = functools.partial(
      new_op,
      op_type=OpType.CONV,
      op_kwargs={
          "features": 10,
          "kernel_size": [3, 3]
      })
  dense_op = functools.partial(
      new_op, op_type=OpType.DENSE, op_kwargs={
          "features": 10,
      })
  for op_type in [
      OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
  ]:
    for op_ctr in [conv_op, dense_op]:
      graphs.append([
          op_ctr(input_names=["input"], op_name="other"),
          new_op(op_name="output", op_type=op_type, input_names=["other"])
      ])
      graphs.append([
          new_op(op_name="other", op_type=op_type, input_names=["input"]),
          op_ctr(input_names=["other"], op_name="output"),
      ])

  input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
  for graph in graphs:
    graph = new_graph(["input"], ["output"], graph)
    state = Model(graph).init(random.PRNGKey(1), input_tensor)
    # Make all the kernels positive; otherwise the ReLU might zero out the
    # entire tensor.
    state = jax.tree_util.tree_map(abs, state)
    subg_model = SubgraphModel(graph, None, state, input_tensor)
    lp_abstract = linear.LinopProperty().infer(subg_model, abstract=True)
    lp_concrete = linear.LinopProperty().infer(subg_model, abstract=False)
    pairings_concrete = lp_concrete.pairings["output"]["input"].mappings
    pairings_abstract = lp_abstract.pairings["output"]["input"].mappings

    print("concrete:", pairings_concrete)
    print("abstract:", pairings_abstract)
    self.assertTrue(((pairings_abstract - pairings_concrete) == 0).all())

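# A minimal sketch (an assumption, not this library's implementation) of the
# connectivity that a concrete linop pairing captures: nonzero entries of the
# Jacobian of the output w.r.t. the input mark which input elements can
# influence which output elements. `f` and `x0` are hypothetical here.
#
#   jac = jax.jacobian(f)(x0)   # d output / d input
#   connectivity = (jac != 0)   # boolean input/output pairing
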
def test_synthesizer_easy_one(self):
  """Replacing [conv3x3(features = 64)]."""
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
  subg[-1].output_names = self.graph.ops[5].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp])

def test_synthesizer_two(self):
  """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

def make_subgraph_models(self, subgraph_spec, graphs=None):
  """Inserts the new subgraph_spec into the subgraph models.

  The graphs argument can be used to pass in intermediate results of
  synthesis rather than completing synthesis all at once, e.g., we can call
  make_subgraph_models twice and pass the output of the first call as an
  argument to the second call. This allows us to break the synthesis down
  into multiple steps.

  Args:
    subgraph_spec: The new ops to insert.
    graphs: The graphs into which to insert the new ops.

  Returns:
    A sequence of subgraphs with subgraph_spec inserted.
  """
  new_subgraph_models = []
  if not graphs:
    graphs = [None] * len(self.subgraphs_and_props)
  for graph, (subgraph_model, _) in zip(graphs, self.subgraphs_and_props):
    graph = graph if graph else subgraph_model.graph
    constants = subgraph_model.constants
    state = subgraph_model.state
    inputs = subgraph_model.inputs
    new_graph = subgraph.replace_subgraph(graph, subgraph_spec)
    if not self.abstract:
      # Concrete synthesis initializes the state while inheriting the
      # parent params.
      new_model = Model(new_graph, constants)
      try:
        new_state = new_model.init(random.PRNGKey(0), inputs)
      except Exception as e:  # pylint: disable=broad-except
        # Catch everything else for now... this is the safest way to filter
        # out malformed subgraphs which will not initialize.
        exc_type, exc_value, exc_traceback = sys.exc_info()
        logging.info(
            "%s", "".join(
                traceback.format_exception(exc_type, exc_value,
                                           exc_traceback)))
        raise ValueError("Could not initialize malformed subgraph "
                         f"({type(e).__name__}: {e}).") from e
      new_state = flax.core.unfreeze(new_state)
      inherited, frozen = subgraph.inherit_params(new_state["params"],
                                                  state["params"])
      new_state = {**inherited, **frozen}
    else:
      new_state = None
    new_subgraph_model = SubgraphModel(new_graph, constants, new_state,
                                       inputs, subgraph_spec)
    new_subgraph_models.append(new_subgraph_model)
  return new_subgraph_models

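# A minimal usage sketch for the multi-step synthesis described in the
# docstring above, assuming a synthesizer instance `syn` and two hypothetical
# partial specs `spec_a` and `spec_b`:
#
#   models = syn.make_subgraph_models(spec_a)
#   graphs = [m.graph for m in models]
#   models = syn.make_subgraph_models(spec_b, graphs=graphs)
#
# The second call inserts spec_b into the graphs produced by the first,
# breaking the synthesis into two steps.
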
def test_satisfy(self):
  # This test removes the last dense layer, so the old graph should be
  # deeper.
  ops = self.graph.ops[:-1]
  ops[-1].name = "fc/logits"
  graph = new_graph(
      input_names=["input"], output_names=["fc/logits"], ops=ops)
  subgraph_model = SubgraphModel(graph, self.constants, {}, {})
  dp = depth.DepthProperty().infer(subgraph_model)
  self.assertTrue(dp.verify(self.subgraph_model))

def setUp(self):
  super().setUp()
  self.graph, self.constants, _ = cnn.CifarNet()
  state = Model(self.graph, self.constants).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 32, 32, 3))})
  self.subgraph_model = SubgraphModel(self.graph, self.constants, state,
                                      {"input": jnp.ones((5, 32, 32, 3))})
  self.sp = shape.ShapeProperty().infer(self.subgraph_model)

def test_synthesizer_hard(self):
  if not self.hard:
    return
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp, lp])

def test_unsatisfy(self):
  # This test removes the last dense layer, so the new graph should have a
  # different shape (and therefore not satisfy the inferred shape property).
  graph = new_graph(
      input_names=["input"], output_names=["fc/relu"], ops=self.graph.ops)
  state = Model(graph, self.constants).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 32, 32, 3))})
  subgraph_model = SubgraphModel(graph, self.constants, state,
                                 {"input": jnp.ones((5, 32, 32, 3))})
  self.assertFalse(self.sp.verify(subgraph_model))

def test_synthesizer_easy_one(self):
  """Replacing [conv3x3(features = 64)].

  Because we do not test the linear property, this is replaced by
  dense3x3(features = 64) due to the enumeration order.
  """
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
  subg[-1].output_names = self.graph.ops[5].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp])

def test_synthesizer_two(self):
  """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

  Because we do not check for the depth property, [dense(features = 64),
  avgpool2x2(strides=2x2)] works as well (which is what is synthesized due
  to the enumeration order).
  """
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
  subg[-1].output_names = self.graph.ops[7].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

def test_synthesizer_easy_two(self):
  """Replacing [conv3x3(features = 64)].

  Because we test all three properties, this is replaced by
  conv3x3(features = 64) (i.e., an identical op) due to the enumeration
  order.
  """
  subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
  subg[-1].output_names = self.graph.ops[5].input_names
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  dp = depth.DepthProperty().infer(subgraph_model)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, dp, lp])

def test_synthesizer_resnet_big(self):
  self.graph, self.constants, _ = resnetv1.ResNet18(
      num_classes=10, input_resolution="small")
  self.m = Model(self.graph, self.constants)
  self.input = {"input": jnp.ones((5, 32, 32, 3))}
  self.state = self.m.init(random.PRNGKey(0), self.input)
  self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
  self.max_size = int(10e8)

  subg_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
  subg = [subgraph.SubgraphNode(op=o) for o in subg_ops]
  subg[-1].output_names = [f"{subg[-1].op.name}:0"]
  subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                 self.input, subg)
  sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
  lp = linear.LinopProperty().infer(subgraph_model)
  self._synthesize(subgraph_model, [sp, lp])

def mutate(self, graph, contexts, abstract=True):
  """Selects a subgraph and mutates its abstract properties.

  Note that this method does not mutate the graph, only the properties of a
  subgraph. Therefore the caller is responsible for synthesizing a subgraph
  satisfying the mutated properties.

  Args:
    graph: The input graph to mutate.
    contexts: A list of (constants, state, inputs) tuples, each specifying a
      different instantiation of the graph.
    abstract: Whether to infer the subgraph properties abstractly or
      concretely.

  Returns:
    For each instantiation, a subgraph model specifying the selected
    subgraph, and a sequence of abstract properties specifying the mutated
    properties.

  Raises:
    NotImplementedError: for concrete inference of properties.
  """
  subgraph_spec = self.select_subgraph(graph)
  new_graph = replace_subgraph(graph, subgraph_spec)
  if not abstract:
    # We would need to initialize the model state if not abstract.
    raise NotImplementedError
  models_and_props = []
  to_mutate = random.randrange(len(contexts))
  for idx, (constants, _, inputs) in enumerate(contexts):
    subgraph_model = SubgraphModel(
        new_graph, constants, state=None, inputs=inputs,
        subgraph=subgraph_spec)
    # Only mutate the properties of one randomly selected instance. The
    # other instances simply need to satisfy the shape property (which is
    # never mutated). The alternative would be to make sure that all the
    # properties are mutated "in the same way" for every instance of the
    # graph, but that requires overly complex logic matching input and
    # output names.
    if idx == to_mutate:
      new_properties = [
          prop.infer(subgraph_model, abstract=abstract).mutate()
          for prop in self.properties
      ]
    else:
      new_properties = []
      for prop in self.properties:
        if type(prop) is ShapeProperty:  # pylint: disable=unidiomatic-typecheck
          new_properties.append(
              ShapeProperty().infer(subgraph_model, abstract=abstract))
    models_and_props.append((subgraph_model, new_properties))

  # Match mutated shapes: drop from every other instance any output that the
  # mutated instance no longer produces.
  for prop in models_and_props[to_mutate][1]:
    if type(prop) is not ShapeProperty:  # pylint: disable=unidiomatic-typecheck
      continue
    output_shapes = prop.output_shapes
    for _, other_props in models_and_props:
      for other_prop in other_props:
        if type(other_prop) is not ShapeProperty:  # pylint: disable=unidiomatic-typecheck
          continue
        for k in list(other_prop.output_shapes.keys()):
          if k not in output_shapes:
            del other_prop.output_shapes[k]
  return models_and_props

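# A minimal usage sketch, assuming `mutator` is an instance of this class and
# `contexts` holds (constants, state, inputs) tuples as described above:
#
#   models_and_props = mutator.mutate(graph, contexts, abstract=True)
#   for subgraph_model, props in models_and_props:
#     # The caller synthesizes a replacement subgraph satisfying `props`,
#     # e.g., by seeding a synthesizer with [(subgraph_model, props)].
#     ...
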
def test_full(self):
  graph, constants, _ = cnn.CifarNet()
  subgraph_model = SubgraphModel(graph, constants, None,
                                 {"input": jnp.ones((10, 32, 32, 3))})
  linear.LinopProperty().infer(subgraph_model)

def test_multi_input(self):
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu0", op_type=OpType.RELU, input_names=["dense0"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu1", op_type=OpType.RELU, input_names=["dense1"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu2", op_type=OpType.RELU, input_names=["dense2"]),
      new_op(
          op_name="add0", op_type=OpType.ADD, input_names=["relu0", "relu1"]),
      new_op(
          op_name="add1", op_type=OpType.ADD, input_names=["relu1", "relu2"]),
  ]
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
      SubgraphNode(
          op=new_op(
              op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
      SubgraphNode(
          op=new_op(
              op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
      SubgraphNode(
          op=new_op(
              op_name="add0",
              op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]),
      SubgraphNode(
          op=new_op(
              op_name="add1",
              op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]),
  ]
  replaced_graph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, {}, subgraph_spec)
  dp = depth.DepthProperty().infer(subgraph_model)
  depth_map = dp.depth_map

  self.assertLen(depth_map, 3)
  self.assertIn("dense0:0", depth_map)
  self.assertIn("dense1:0", depth_map)
  self.assertIn("dense2:0", depth_map)
  self.assertLen(depth_map["dense0:0"], 1)
  self.assertEqual(depth_map["dense0:0"]["add0:0"], 2)
  self.assertLen(depth_map["dense1:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add0:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add1:0"], 2)
  self.assertLen(depth_map["dense2:0"], 1)
  self.assertEqual(depth_map["dense2:0"]["add1:0"], 2)

def test_multi_input(self):
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu0", op_type=OpType.RELU, input_names=["dense0"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu1", op_type=OpType.RELU, input_names=["dense1"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu2", op_type=OpType.RELU, input_names=["dense2"]),
      new_op(
          op_name="add0", op_type=OpType.ADD, input_names=["relu0", "relu1"]),
      new_op(
          op_name="add1", op_type=OpType.ADD, input_names=["relu1", "relu2"]),
  ]
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
      SubgraphNode(
          op=new_op(
              op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
      SubgraphNode(
          op=new_op(
              op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
      SubgraphNode(
          op=new_op(
              op_name="add0",
              op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]),
      SubgraphNode(
          op=new_op(
              op_name="add1",
              op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]),
  ]
  replaced_graph = replace_subgraph(graph, subgraph_spec)
  inp = {"input": jnp.ones((10, 32, 32, 3))}
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp, subgraph_spec)
  lp = linear.LinopProperty().infer(subgraph_model)
  pairings = lp.pairings

  self.assertLen(pairings, 2)
  self.assertIn("add0:0", pairings)
  self.assertLen(pairings["add0:0"], 2)
  self.assertIn("dense0:0", pairings["add0:0"])
  self.assertIn("dense1:0", pairings["add0:0"])
  self.assertIn("add1:0", pairings)
  self.assertLen(pairings["add1:0"], 2)
  self.assertIn("dense1:0", pairings["add1:0"])
  self.assertIn("dense2:0", pairings["add1:0"])

def setUp(self):
  super().setUp()
  self.graph, self.constants, _ = cnn.CifarNet()
  self.subgraph_model = SubgraphModel(self.graph, self.constants, {}, {})
  self.dp = depth.DepthProperty().infer(self.subgraph_model)