def test_rewire(self):
  """Checks the inferred shapes after rewiring a conv/gelu subgraph."""

  def make_inputs():
    # A fresh dict/array per call, mirroring the two separate literals that
    # are handed to `init` and to `SubgraphModel`.
    return {"input": jnp.ones((5, 32, 32, 3))}

  conv_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/conv/1",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["conv_layer0/avg_pool"]))
  gelu_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/gelu/1",
          op_type=OpType.GELU,
          input_names=["conv_layer1/conv/1"]),
      output_names=["conv_layer1/relu"])
  subgraph_spec = [conv_node, gelu_node]

  graph = replace_subgraph(self.graph, subgraph_spec)
  model = Model(graph, self.constants)
  state = model.init(random.PRNGKey(0), make_inputs())
  subgraph_model = SubgraphModel(
      graph, self.constants, state, make_inputs(), subgraph_spec)

  sp = shape.ShapeProperty().infer(subgraph_model)

  # The subgraph consumes exactly one tensor...
  self.assertLen(sp.input_shapes, 1)
  self.assertIn("conv_layer0/avg_pool:0", sp.input_shapes)
  # ...and its output is visible under both the new op name and the name of
  # the output it was rewired to.
  self.assertLen(sp.output_shapes, 2)
  self.assertIn("conv_layer1/gelu/1:0", sp.output_shapes)
  self.assertIn("conv_layer1/relu:0", sp.output_shapes)
def setUp(self):
  """Builds two single-op graphs and infers the linop property of each."""
  super().setUp()

  def make_inputs():
    # A fresh dict/array per use, as in the separate literals this replaces.
    return {"input": jnp.ones((5, 5, 5))}

  # Despite the `dense` suffix, the only op in this graph is a SOFTMAX.
  softmax_op = new_op(
      op_name="output",
      op_type=OpType.SOFTMAX,
      input_names=["input"])
  self.graph_dense = new_graph(["input"], ["output"], [softmax_op])
  state_dense = Model(self.graph_dense).init(
      random.PRNGKey(0), make_inputs())
  self.subgraph_dense = SubgraphModel(
      self.graph_dense, None, state_dense, make_inputs())
  self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

  # Single 3x3 convolution with 10 output features.
  conv_op = new_op(
      op_name="output",
      op_type=OpType.CONV,
      op_kwargs={"features": 10, "kernel_size": [3, 3]},
      input_names=["input"])
  self.graph_conv = new_graph(["input"], ["output"], [conv_op])
  state_conv = Model(self.graph_conv).init(
      random.PRNGKey(0), make_inputs())
  self.subgraph_conv = SubgraphModel(
      self.graph_conv, None, state_conv, make_inputs())
  self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
def setUp(self):
  """Creates a CifarNet model plus a variant with a conv/gelu subgraph."""
  super().setUp()

  # Baseline model and its initial state.
  self.graph, self.constants, _ = cnn.CifarNet()
  self.model = Model(self.graph, self.constants)
  self.state = self.model.init(
      random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

  # Replacement subgraph: a 1x1 conv followed by a gelu that is wired to the
  # existing "conv_layer1/relu" output.
  conv_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="conv_layer1/conv/1",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["conv_layer0/avg_pool"]))
  gelu_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="conv_layer1/gelu/1",
          op_type=OpType.GELU,
          input_names=["conv_layer1/conv/1"]),
      output_names=["conv_layer1/relu"])
  self.subgraph = [conv_node, gelu_node]

  # Model obtained by substituting the subgraph into the baseline graph.
  self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
  self.new_model = Model(self.new_graph, self.constants)
  self.new_state = self.new_model.init(
      random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
def test_abstract_sequential_synthesizer_output_features(self):
  """Checks the output feature multiplier/divisor found by the synthesizer."""
  graph, constants, _ = cnn.CifarNet()

  # The conv uses the symbolic feature spec "S:-1*2".
  conv_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/conv",
          op_type=OpType.CONV,
          op_kwargs={"features": "S:-1*2", "kernel_size": [1, 1]},
          input_names=["conv_layer0/avg_pool"]))
  relu_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/relu",
          op_type=OpType.RELU,
          input_names=["conv_layer1/conv"]),
      output_names=["conv_layer1/relu"])
  subgraph_spec = [conv_node, relu_node]

  replaced_graph = replace_subgraph(graph, subgraph_spec)
  inputs = {"input": jnp.zeros((5, 32, 32, 10))}
  subgraph_model = SubgraphModel(
      replaced_graph, constants, None, inputs, subgraph_spec)
  sp = shape.ShapeProperty().infer(subgraph_model)

  syn = TestSequentialSynthesizer([(subgraph_model, [sp])], 0)

  self.assertEqual(syn.output_features_mul, 2)
  self.assertEqual(syn.output_features_div, 1)
def test_abstract_sequential_synthesizer_fail(self):
  """Checks that constructing the synthesizer raises for this subgraph."""
  graph, constants, _ = cnn.CifarNet()

  conv_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/conv/1",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["conv_layer0/avg_pool"]),
      output_names=["conv_layer1/conv"])
  gelu_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/gelu/1",
          op_type=OpType.GELU,
          input_names=["conv_layer1/conv"]),
      output_names=["conv_layer1/relu"])
  subgraph_spec = [conv_node, gelu_node]

  # Note: the model wraps the original graph, not a rewired one.
  inputs = {"input": jnp.zeros((5, 32, 32, 10))}
  subgraph_model = SubgraphModel(
      graph, constants, None, inputs, subgraph_spec)

  with self.assertRaisesRegex(ValueError, ".*exactly one input.*"):
    TestSequentialSynthesizer([(subgraph_model, [])], 0)
def test_rewire(self):
  """Checks the depth map after swapping conv/relu for conv/gelu."""
  # orig: conv, relu, pool, conv, relu, pool, flatten, dense, relu, dense
  # new:  conv, relu, pool, conv, gelu, pool, flatten, dense, relu, dense
  conv_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/conv/1",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["conv_layer0/avg_pool"]))
  gelu_node = SubgraphNode(
      op=new_op(
          op_name="conv_layer1/gelu/1",
          op_type=OpType.GELU,
          input_names=["conv_layer1/conv/1"]),
      output_names=["conv_layer1/relu"])
  subgraph_spec = [conv_node, gelu_node]

  graph = replace_subgraph(self.graph, subgraph_spec)
  subgraph_model = SubgraphModel(
      graph, self.constants, {}, {}, subgraph_spec)
  depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

  source = "conv_layer0/avg_pool:0"
  self.assertLen(depth_map, 1)
  self.assertIn(source, depth_map)

  # Both names of the subgraph output are expected at depth 1.
  depths_from_source = depth_map[source]
  self.assertLen(depths_from_source, 2)
  self.assertIn("conv_layer1/relu:0", depths_from_source)
  self.assertEqual(depths_from_source["conv_layer1/relu:0"], 1)
  self.assertIn("conv_layer1/gelu/1:0", depths_from_source)
  self.assertEqual(depths_from_source["conv_layer1/gelu/1:0"], 1)
def test_abstract(self):
  """Compares abstract and concrete linop pairings on two-op graphs.

  Every combination of {conv, dense} with {relu, softmax, layer norm,
  batch norm} is tried in both orders.
  """
  conv_op = functools.partial(
      new_op,
      op_type=OpType.CONV,
      op_kwargs={"features": 10, "kernel_size": [3, 3]})
  dense_op = functools.partial(
      new_op,
      op_type=OpType.DENSE,
      op_kwargs={"features": 10})
  other_types = [
      OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
  ]

  op_lists = []
  for op_type in other_types:
    for make_linear_op in [conv_op, dense_op]:
      # Linear op first, then the other op.
      op_lists.append([
          make_linear_op(input_names=["input"], op_name="other"),
          new_op(op_name="output", op_type=op_type, input_names=["other"]),
      ])
      # The other op first, then the linear op.
      op_lists.append([
          new_op(op_name="other", op_type=op_type, input_names=["input"]),
          make_linear_op(input_names=["other"], op_name="output"),
      ])

  input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
  for op_list in op_lists:
    graph = new_graph(["input"], ["output"], op_list)
    state = Model(graph).init(random.PRNGKey(1), input_tensor)
    # Make all the kernels positive, otherwise, the ReLU might zero out the
    # entire tensor.
    state = jax.tree_util.tree_map(abs, state)
    subg_model = SubgraphModel(graph, None, state, input_tensor)

    lp_abstract = linear.LinopProperty().infer(subg_model, abstract=True)
    lp_concrete = linear.LinopProperty().infer(subg_model, abstract=False)
    concrete_mappings = lp_concrete.pairings["output"]["input"].mappings
    abstract_mappings = lp_abstract.pairings["output"]["input"].mappings

    print("concrete:", concrete_mappings)
    print("abstract:", abstract_mappings)
    difference = abstract_mappings - concrete_mappings
    self.assertTrue((difference == 0).all())
def conv_net(in_features, out_features, num_classes, blocks=None):
  """Builds a CNN graph: 1x1 projection, a chain of blocks, and an FC head.

  Args:
    in_features: Number of features of the initial 1x1 "proj" convolution.
    out_features: Stored as the "out_features" constant, which sets the
      width of the "fc/dense" layer.
    num_classes: Stored as the "num_classes" constant, which sets the width
      of the "fc/logits" layer.
    blocks: Blocks to chain after the projection. When empty or None, one
      block is created from each entry of BLOCK_TYPES.

  Returns:
    A (graph, constants, instantiated_blocks) tuple.
  """
  if not blocks:
    blocks = [block_type() for block_type in BLOCK_TYPES]

  input_name = "input"
  proj_op = new_op(
      op_name="proj",
      op_type=OpType.CONV,
      op_kwargs={"features": in_features, "kernel_size": 1},
      input_names=[input_name])
  ops = [proj_op]
  constants = {}
  new_blocks = []

  # Chain the blocks: each one consumes the last op emitted so far.
  previous_output = ops[-1].name
  for idx, block in enumerate(blocks):
    instance = block.instantiate(
        input_names=[previous_output], instance_id=idx)
    new_blocks.append(instance)
    constants.update(instance.constants)
    ops.extend(instance.graph.ops)
    previous_output = ops[-1].name

  constants["out_features"] = out_features
  constants["num_classes"] = num_classes

  # Classifier head: flatten -> dense -> relu -> logits.
  flatten_op = new_op(
      op_name="flatten",
      op_type=OpType.FLATTEN,
      input_names=[ops[-1].name])
  dense_op = new_op(
      op_name="fc/dense",
      op_type=OpType.DENSE,
      op_kwargs={"features": "K:out_features"},
      input_names=["flatten"])
  relu_op = new_op(
      op_name="fc/relu",
      op_type=OpType.RELU,
      input_names=["fc/dense"])
  logits_op = new_op(
      op_name="fc/logits",
      op_type=OpType.DENSE,
      op_kwargs={"features": "K:num_classes"},
      input_names=["fc/relu"])
  ops.extend([flatten_op, dense_op, relu_op, logits_op])

  graph = new_graph(
      input_names=[input_name], output_names=["fc/logits"], ops=ops)
  return graph, constants, new_blocks
def synthesize(self):
  """Returns a new subgraph obtained by applying one random mutation.

  Raises:
    ValueError: if none of the available mutations is applicable.
  """
  original_spec = self.subgraphs_and_props[0][0].subgraph
  subg_ops = [copy.deepcopy(node.op) for node in original_spec]

  def delete_then_insert(ops):
    return self.insert(self.delete(ops))

  mutations = [
      self.delete,
      self.insert,
      self.mutate_field,
      delete_then_insert,
      self.swap,
  ]
  if self.use_automl_zero:
    mutations.append(lambda _: self.randomize())

  # Certain mutations may not be applicable for the selected subgraph, and
  # they will return None (e.g., if the subgraph is of size 1, we cannot
  # swap). So try the mutations in a random order until one is applicable.
  random.shuffle(mutations)
  mutated_subg_ops = None
  # Walk from the back of the shuffled list, i.e. the order in which
  # repeated `pop()` calls would hand the mutations out.
  for mutation in reversed(mutations):
    mutated_subg_ops = mutation(subg_ops)
    if mutated_subg_ops is not None:
      break
  if mutated_subg_ops is None:
    raise ValueError("Synthesis failed.")
  subg_ops = mutated_subg_ops

  # An empty result is replaced by a single identity op on the input.
  if not subg_ops:
    subg_ops.append(new_op("dummy", OpType.IDENTITY, [self.input_name]))

  # Rename every op after the current generation and its op type.
  prefix = f"gen{self.generation}/"
  for op in subg_ops:
    op.name = f"{prefix}{op.type.name.lower()}"

  new_spec = self.make_subgraph_spec(subg_ops)
  return self.make_subgraph_models(new_spec)
def test_identical(self):
  """Tests that fingerprinting the same graph twice gives equal results."""
  dense0 = new_op(
      op_name="dense0",
      op_type=OpType.DENSE,
      op_kwargs={"features": 32},
      input_names=["input"])
  dense1 = new_op(
      op_name="dense1",
      op_type=OpType.DENSE,
      op_kwargs={"features": 32},
      input_names=["input"])
  add = new_op(
      op_name="output",
      op_type=OpType.ADD,
      input_names=["dense0", "dense1"])
  graph = new_graph(["input"], ["output"], [dense0, dense1, add])

  input_dict = {"input": jnp.ones((5, 5, 5))}
  first = fingerprint.fingerprint_graph(graph, {}, input_dict)
  second = fingerprint.fingerprint_graph(graph, {}, input_dict)

  self.assertEqual(first, second)
def test_multi_input_output(self):
  """Tests subgraph substitution on a graph with multi-input/output ops.

  A ResNet model is used because it has skip connections. The test checks
  that the substitution produces the expected number of ops, and that the
  resulting graph can still be initialized and applied.
  """
  init_input = jnp.ones((1, 32, 32, 3))
  batch_input = jnp.ones((10, 32, 32, 3))

  graph, constants, _ = resnetv1.ResNet18(
      num_classes=10, input_resolution="small")
  model = Model(graph, constants)
  state = model.init(random.PRNGKey(0), init_input)
  y = model.apply(state, batch_input)
  self.assertEqual(y.shape, (10, 10))

  conv_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="subgraph/conv0",
          op_type=OpType.CONV,
          op_kwargs={"features": 64, "kernel_size": [1, 1]},
          input_names=["resnet11/skip/relu1"]))
  gelu_node = subgraph.SubgraphNode(
      op=new_op(
          op_name="subgraph/gelu1",
          op_type=OpType.GELU,
          input_names=["subgraph/conv0"]),
      output_names=["resnet_stride1_filtermul1_basic12/relu2"])
  replaced_graph = subgraph.replace_subgraph(graph, [conv_node, gelu_node])

  # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
  self.assertLen(graph.ops, len(replaced_graph.ops) + 1)

  replaced_model = Model(replaced_graph, constants)
  replaced_state = replaced_model.init(
      random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
  y = replaced_model.apply(replaced_state, jnp.ones((10, 32, 32, 3)))
  self.assertEqual(y.shape, (10, 10))
def conv_block():
  """Makes a conv -> relu -> avg_pool block.

  The convolution's feature count is given symbolically as "S:-1*2".

  Returns:
    A Block named "conv_layer" with input "input" and output "avg_pool".
  """
  conv = new_op(
      op_name="conv",
      op_type=OpType.CONV,
      op_kwargs={"features": "S:-1*2", "kernel_size": 3},
      input_names=["input"])
  relu = new_op(
      op_name="relu",
      op_type=OpType.RELU,
      input_names=["conv"])
  avg_pool = new_op(
      op_name="avg_pool",
      op_type=OpType.AVG_POOL,
      input_names=["relu"],
      input_kwargs={"window_shape": 2, "strides": 2})

  graph = new_graph(
      input_names=["input"],
      output_names=["avg_pool"],
      ops=[conv, relu, avg_pool])
  return Block(name="conv_layer", graph=graph)
def append_op(ops,
              op_name,
              op_type,
              input_names=None,
              input_kwargs=None,
              op_kwargs=None,
              num_outputs=1):
  """Builds a new op and appends it to `ops` in place.

  Args:
    ops: Mutable sequence of ops to extend.
    op_name: Name of the new op.
    op_type: Type of the new op.
    input_names: Names of the op's inputs. When None, the op consumes the
      last op currently in `ops`.
    input_kwargs: Forwarded unchanged to `new_op`.
    op_kwargs: Forwarded to `new_op`; a falsy value becomes a new empty dict.
    num_outputs: Forwarded unchanged to `new_op`.
  """
  resolved_inputs = [ops[-1].name] if input_names is None else input_names
  op = new_op(
      op_name=op_name,
      op_type=op_type,
      input_names=resolved_inputs,
      input_kwargs=input_kwargs,
      op_kwargs=op_kwargs or {},
      num_outputs=num_outputs)
  ops.append(op)
def append_op(ops,
              op_name,
              op_type,
              input_names=None,
              input_kwargs=None,
              op_kwargs=None,
              num_outputs=1):
  """Builds a new op with per-type default kwargs and appends it to `ops`.

  Args:
    ops: Mutable sequence of ops to extend in place.
    op_name: Name of the new op.
    op_type: Type of the new op; also the lookup key into DEFAULT_OP_KWARGS.
    input_names: Names of the op's inputs. When empty or None, the op
      consumes the last op currently in `ops`.
    input_kwargs: Forwarded unchanged to `new_op`.
    op_kwargs: Overrides layered on top of the defaults registered for
      `op_type`. None (the default) or an empty dict means no overrides.
    num_outputs: Forwarded unchanged to `new_op`.
  """
  if not input_names:
    input_names = [ops[-1].name]
  # A None default replaces the former shared `{}` default argument, and
  # also lets callers pass `op_kwargs=None` explicitly without a TypeError.
  overrides = op_kwargs or {}
  default_op_kwargs = DEFAULT_OP_KWARGS.get(op_type, {})
  # Always a fresh dict, so neither the defaults table nor the caller's
  # dict is handed to the new op.
  merged_op_kwargs = {**default_op_kwargs, **overrides}
  ops.append(
      new_op(
          op_name=op_name,
          op_type=op_type,
          input_names=input_names,
          input_kwargs=input_kwargs,
          op_kwargs=merged_op_kwargs,
          num_outputs=num_outputs))
def test_equal(self):
  """Tests that equivalent graphs produce the same fingerprint.

  The two graphs name their ops differently and list them in a different
  (but still topologically valid) order.
  """

  def dense(name):
    return new_op(
        op_name=name,
        op_type=OpType.DENSE,
        op_kwargs={"features": 32},
        input_names=["input"])

  def conv(name):
    return new_op(
        op_name=name,
        op_type=OpType.CONV,
        op_kwargs={"features": 32, "kernel_size": [3]},
        input_names=["input"])

  def add(lhs, rhs):
    return new_op(
        op_name="output", op_type=OpType.ADD, input_names=[lhs, rhs])

  # Graph 1: dense listed before conv.
  ops1 = [dense("dense"), conv("conv"), add("dense", "conv")]
  graph1 = new_graph(["input"], ["output"], ops1)

  # Graph 2: conv listed before dense, with renamed ops.
  ops2 = [conv("conv2"), dense("dense2"), add("dense2", "conv2")]
  graph2 = new_graph(["input"], ["output"], ops2)

  input_dict = {"input": jnp.ones((5, 5, 5))}
  fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
  fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)

  self.assertEqual(fingerprint1, fingerprint2)
def test_not_equal(self):
  """Tests that non-equivalent graphs produce different fingerprints."""

  def dense(name):
    return new_op(
        op_name=name,
        op_type=OpType.DENSE,
        op_kwargs={"features": 32},
        input_names=["input"])

  def add(lhs, rhs):
    return new_op(
        op_name="output", op_type=OpType.ADD, input_names=[lhs, rhs])

  # Graph 1 adds two dense branches.
  ops1 = [dense("dense0"), dense("dense1"), add("dense0", "dense1")]
  graph1 = new_graph(["input"], ["output"], ops1)

  # Graph 2 adds a conv branch and a dense branch.
  conv2 = new_op(
      op_name="conv2",
      op_type=OpType.CONV,
      op_kwargs={"features": 32, "kernel_size": [3]},
      input_names=["input"])
  ops2 = [conv2, dense("dense2"), add("dense2", "conv2")]
  graph2 = new_graph(["input"], ["output"], ops2)

  input_dict = {"input": jnp.ones((5, 5, 5))}
  fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
  fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)

  self.assertNotEqual(fingerprint1, fingerprint2)
def test_multi_input(self):
  """Checks linop pairings for a subgraph with several inputs and outputs."""
  # Three parallel dense -> relu branches reading the same input...
  ops = []
  for i in range(3):
    ops.append(
        new_op(
            op_name=f"dense{i}",
            op_type=OpType.DENSE,
            op_kwargs={"features": 32},
            input_names=["input"]))
    ops.append(
        new_op(
            op_name=f"relu{i}",
            op_type=OpType.RELU,
            input_names=[f"dense{i}"]))
  # ...combined pairwise by two adds that share the middle branch.
  ops.append(
      new_op(
          op_name="add0",
          op_type=OpType.ADD,
          input_names=["relu0", "relu1"]))
  ops.append(
      new_op(
          op_name="add1",
          op_type=OpType.ADD,
          input_names=["relu1", "relu2"]))
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  # The subgraph covers the relus and the adds; the dense ops feed it.
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name=f"relu{i}",
              op_type=OpType.RELU,
              input_names=[f"dense{i}"]))
      for i in range(3)
  ]
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add0",
              op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]))
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add1",
              op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]))

  replaced_graph = replace_subgraph(graph, subgraph_spec)
  inp = {"input": jnp.ones((10, 32, 32, 3))}
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp, subgraph_spec)
  pairings = linear.LinopProperty().infer(subgraph_model).pairings

  self.assertLen(pairings, 2)

  self.assertIn("add0:0", pairings)
  add0_pairings = pairings["add0:0"]
  self.assertLen(add0_pairings, 2)
  self.assertIn("dense0:0", add0_pairings)
  self.assertIn("dense1:0", add0_pairings)

  self.assertIn("add1:0", pairings)
  add1_pairings = pairings["add1:0"]
  self.assertLen(add1_pairings, 2)
  self.assertIn("dense1:0", add1_pairings)
  self.assertIn("dense2:0", add1_pairings)
def instantiate(self, input_names, instance_id=None, constants=None):
  """Creates a copy of this block whose names carry a unique prefix.

  Op names, constant names and output names are all derived from the block's
  base graph and base constants (those of the initial definition), not from
  the possibly already-prefixed graph of `self`. Instantiating from a derived
  block therefore has the same effect as instantiating from the initial one:

    init_block = block.__init__(name="conv_layer", ...)
    block0 = init_block.instantiate(instance_id=0, ...)

    # These two are equivalent:
    block1 = init_block.instantiate(instance_id=1, ...)
    block1 = block0.instantiate(instance_id=1, ...)

  The one caveat is that the default values for unspecified constants are
  inherited from the instantiating block (instead of the initial
  definition).

  Args:
    input_names: The input tensor names the instantiated block will consume.
    instance_id: An id to make the names in the instantiated block unique.
      The id should be unique within a graph.
    constants: Updated parameters for the instantiated block.

  Returns:
    An instantiated block.

  Raises:
    ValueError: if the number of input names provided does not equal the
      number of inputs consumed by the graph.
  """
  base_graph = self.base_graph
  if len(input_names) != len(base_graph.input_names):
    raise ValueError("Wrong number of inputs provided.")

  # The prefix is "<name><instance_id>/", or empty if both parts are absent.
  name_part = self.name if self.name else ""
  id_part = "" if instance_id is None else str(instance_id)
  prefix = name_part + id_part
  if prefix:
    prefix = prefix + "/"

  if not constants:
    constants = dict(self.base_constants)

  # Map both the raw and the canonicalized base input names to the new ones.
  canonical_inputs = [
      canonicalize_tensor_name(name) for name in base_graph.input_names
  ]
  updated_names = dict(zip(base_graph.input_names, input_names))
  updated_names.update(zip(canonical_inputs, input_names))

  def rename_input(inp):
    # Graph inputs are swapped for the caller-provided names; any internal
    # input (i.e., anything that is not a graph input) gets the prefix.
    if inp in canonical_inputs:
      return input_names[canonical_inputs.index(inp)]
    return f"{prefix}{inp}"

  def rename_symbolic(kwargs):
    # Update symbolic constant names inside a kwargs dict.
    return {
        key: _prefix_symbolic(value, prefix, constants, updated_names)
        for key, value in kwargs.items()
    }

  # Update ops.
  new_ops = []
  for op in base_graph.ops:
    renamed_inputs = [rename_input(inp) for inp in op.input_names]
    renamed_input_kwargs = rename_symbolic(op.input_kwargs)
    renamed_op_kwargs = rename_symbolic(op.op_kwargs)
    new_ops.append(
        new_op(
            op_name=f"{prefix}{op.name}",
            op_type=op.type,
            input_names=renamed_inputs,
            input_kwargs=renamed_input_kwargs,
            op_kwargs=renamed_op_kwargs,
            num_outputs=op.num_outputs))

  # Update constants and prefix symbolic constant names.
  merged_constants = dict(self.base_constants)
  if constants:
    merged_constants.update(constants)
  new_constants = {
      f"{prefix}{key}": value for key, value in merged_constants.items()
  }

  # Prefix graph output names.
  new_output_names = [f"{prefix}{name}" for name in base_graph.output_names]

  graph = new_graph(
      ops=new_ops, input_names=input_names, output_names=new_output_names)
  return Block(
      name=self.name,
      graph=graph,
      constants=new_constants,
      base_graph=self.base_graph,
      base_constants=merged_constants)
def op_enumerator(
    cls,
    prefix=None,
    kwarg_defaults=None,
    full=True,
    op_types=None,
):
  """Yields one candidate op per supported op type and kwargs combination.

  Every yielded op has the single input "inputs" and is named
  "<prefix><lower-cased op type name>".

  Args:
    prefix: Optional name prefix; a trailing "/" is added when missing.
    kwarg_defaults: Passed to `cls.make_default_kwargs` together with `full`.
    full: Passed through to `cls.make_default_kwargs` and
      `cls.all_kwargs_for_op_type`.
    op_types: Op types to enumerate. Defaults to every member of OpType.

  Yields:
    Ops built with `new_op`. Op types that are not supported for synthesis
    are skipped.

  Raises:
    AssertionError: if an op type belongs to none of the known groups.
  """
  if not prefix:
    prefix = ""
  elif not prefix.endswith("/"):
    prefix = f"{prefix}/"

  kwarg_defaults = cls.make_default_kwargs(kwarg_defaults, full)
  if op_types is None:
    op_types = OpType

  # The groups are loop-invariant, so they are built once up front.
  # Not supported for synthesis.
  unsupported_types = (
      OpType.IDENTITY,
      OpType.NONE,
      OpType.DENSE_GENERAL,
      OpType.ADD,
      OpType.MUL,
      OpType.SCALAR_ADD,
      OpType.DOT_GENERAL,
      OpType.EINSUM,
      OpType.FLATTEN,
      OpType.RESHAPE,
      OpType.TRANSPOSE,
      OpType.PARAM,
      OpType.SELF_ATTENTION,
      OpType.STOCH_DEPTH,
      OpType.MEAN,
  )
  # No kwargs.
  no_kwargs_types = (
      OpType.SCALAR_MUL,
      OpType.BATCH_NORM,
      OpType.LAYER_NORM,
      OpType.RELU,
      OpType.GELU,
      OpType.SWISH,
      OpType.SIGMOID,
      OpType.SOFTMAX,
  )
  # One op per combination of kwargs.
  with_kwargs_types = (
      OpType.DENSE,
      OpType.CONV,
      OpType.GROUP_NORM,
      OpType.AVG_POOL,
      OpType.MAX_POOL,
      OpType.DROPOUT,
  )

  for op_type in op_types:
    name = f"{prefix}{op_type.name.lower()}"
    inputs = ["inputs"]
    if op_type in unsupported_types:
      continue
    if op_type in no_kwargs_types:
      yield new_op(name, op_type, inputs)
    elif op_type in with_kwargs_types:
      op_kwargs_dict, input_kwargs_dict = cls.all_kwargs_for_op_type(
          kwarg_defaults, full, op_type)
      for op_kwargs, input_kwargs in cls.kwargs_for_op_to_product(
          op_kwargs_dict, input_kwargs_dict):
        yield new_op(
            name,
            op_type,
            inputs,
            op_kwargs=op_kwargs,
            input_kwargs=input_kwargs)
    else:
      # Raised explicitly rather than via `assert False`, so that running
      # under `python -O` cannot silently skip an unhandled op type.
      raise AssertionError(f"op_type {op_type} not supported")
def test_multi_input(self):
  """Checks the depth map for a subgraph with several inputs and outputs."""
  # Three parallel dense -> relu branches reading the same input...
  ops = []
  for i in range(3):
    ops.append(
        new_op(
            op_name=f"dense{i}",
            op_type=OpType.DENSE,
            op_kwargs={"features": 32},
            input_names=["input"]))
    ops.append(
        new_op(
            op_name=f"relu{i}",
            op_type=OpType.RELU,
            input_names=[f"dense{i}"]))
  # ...combined pairwise by two adds that share the middle branch.
  ops.append(
      new_op(
          op_name="add0",
          op_type=OpType.ADD,
          input_names=["relu0", "relu1"]))
  ops.append(
      new_op(
          op_name="add1",
          op_type=OpType.ADD,
          input_names=["relu1", "relu2"]))
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  # The subgraph covers the relus and the adds; the dense ops feed it.
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name=f"relu{i}",
              op_type=OpType.RELU,
              input_names=[f"dense{i}"]))
      for i in range(3)
  ]
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add0",
              op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]))
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add1",
              op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]))

  replaced_graph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, {}, subgraph_spec)
  depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

  self.assertLen(depth_map, 3)
  self.assertIn("dense0:0", depth_map)
  self.assertIn("dense1:0", depth_map)
  self.assertIn("dense2:0", depth_map)

  # dense0 only reaches add0.
  from_dense0 = depth_map["dense0:0"]
  self.assertLen(from_dense0, 1)
  self.assertEqual(from_dense0["add0:0"], 2)

  # dense1 reaches both adds.
  from_dense1 = depth_map["dense1:0"]
  self.assertLen(from_dense1, 2)
  self.assertEqual(from_dense1["add0:0"], 2)
  self.assertEqual(from_dense1["add1:0"], 2)

  # dense2 only reaches add1.
  from_dense2 = depth_map["dense2:0"]
  self.assertLen(from_dense2, 1)
  self.assertEqual(from_dense2["add1:0"], 2)
def resolve_op(self, op, intermediate_values, **_):
  """Resolves an op with possibly symbolic arguments to a concrete op.

  Dispatches on the op type and runs the matching `_kv_to_int`,
  `_kv_to_float` and `_kv_resolve_symbolic` helpers over selected keys of the
  op's `input_kwargs` / `op_kwargs`, then rebuilds the op with `new_op`.

  NOTE(review): `input_kwargs` and `op_kwargs` below are the very dicts held
  by `op` (no copy is made here). The `name` entry is therefore written into
  `op.op_kwargs` itself, and the `_kv_*` helpers presumably update the dicts
  in place as well -- confirm that mutating the incoming op is intended.

  Args:
    op: The op to resolve.
    intermediate_values: Mapping from lower-cased tensor name to value. The
      values of this op's inputs are looked up in it, and it is also handed
      to `_kv_resolve_symbolic` wherever the input values are.
    **_: Extra keyword arguments are accepted and ignored.

  Returns:
    An op created by `new_op` with the lower-cased name and the same type,
    input names, kwargs dicts and number of outputs as `op`.

  Raises:
    ValueError: if the op type is not handled by any branch.
  """
  op_name = op.name.lower()
  op_type = op.type
  input_names = op.input_names
  # Input values are looked up by lower-cased input name.
  input_values = [
      intermediate_values[key.lower()] for key in input_names
  ]
  input_kwargs: Dict[str, Any] = op.input_kwargs
  op_kwargs: Dict[str, Any] = op.op_kwargs
  # The (lower-cased) op name is passed along as the "name" op kwarg.
  op_kwargs["name"] = op_name

  if op_type == OpType.NONE:
    pass
  elif op_type == OpType.IDENTITY:
    pass

  # nn.linear

  elif op_type == OpType.DENSE:
    _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
    _kv_resolve_symbolic(op_kwargs, ["features"], input_values,
                         intermediate_values)
  elif op_type == OpType.DENSE_GENERAL:
    _kv_to_int(op_kwargs, ["axis", "batch_dims"])
    _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
    _kv_resolve_symbolic(op_kwargs, ["features"], input_values,
                         intermediate_values)
  elif op_type == OpType.CONV:
    _kv_to_int(op_kwargs, [
        "kernel_size",
        "strides",
        "input_dilation",
        "kernel_dilation",
        "padding",
    ])
    _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
    _kv_resolve_symbolic(op_kwargs, ["features", "feature_group_count"],
                         input_values, intermediate_values)

  # others

  elif op_type == OpType.ADD:
    _kv_to_float(op_kwargs, ["layer_drop_rate"])
  elif op_type == OpType.SCALAR_ADD:
    _kv_to_float(input_kwargs, ["const"])
  elif op_type == OpType.MUL:
    pass
  elif op_type == OpType.SCALAR_MUL:
    _kv_to_float(input_kwargs, ["const"])
  elif op_type == OpType.DOT_GENERAL:
    _kv_to_int(input_kwargs, ["dimension_numbers"])
  elif op_type == OpType.EINSUM:
    pass

  # nn.attention

  elif op_type == OpType.SELF_ATTENTION:
    _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
    _kv_resolve_symbolic(op_kwargs,
                         ["num_heads", "qkv_features", "out_features"],
                         input_values, intermediate_values)

  # nn.activation

  elif op_type in [
      OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
  ]:
    pass
  elif op_type == OpType.SOFTMAX:
    _kv_to_int(input_kwargs, ["axis"])

  # nn.normalization

  elif op_type == OpType.BATCH_NORM:
    _kv_to_int(op_kwargs, ["axis"])
    _kv_resolve_symbolic(op_kwargs, ["scale_init", "bias_init"])
  elif op_type == OpType.LAYER_NORM:
    pass
  elif op_type == OpType.GROUP_NORM:
    _kv_resolve_symbolic(op_kwargs, ["num_groups", "group_size"],
                         input_values, intermediate_values)

  # reshape operators

  elif op_type == OpType.RESHAPE:
    # Symbolic entries are resolved first, then the int conversion is run.
    _kv_resolve_symbolic(input_kwargs, ["new_shape"], input_values,
                         intermediate_values)
    _kv_to_int(input_kwargs, ["new_shape"])
  elif op_type == OpType.FLATTEN:
    pass
  elif op_type == OpType.TRANSPOSE:
    _kv_to_int(input_kwargs, ["axes"])

  # nn.stochastic

  elif op_type == OpType.DROPOUT:
    _kv_to_int(op_kwargs, ["broadcast_dims"])
    _kv_to_float(op_kwargs, ["rate"])
  elif op_type == OpType.STOCH_DEPTH:
    _kv_to_float(op_kwargs, ["layer_drop_rate"])

  # nn.pooling

  elif op_type == OpType.AVG_POOL:
    _kv_to_int(input_kwargs, ["window_shape", "strides"])
  elif op_type == OpType.MAX_POOL:
    _kv_to_int(input_kwargs, ["window_shape", "strides"])
  elif op_type == OpType.MEAN:
    _kv_to_int(input_kwargs, ["axis"])

  # new param

  elif op_type == OpType.PARAM:
    # Unlike RESHAPE, the int conversion runs before symbolic resolution.
    _kv_to_int(input_kwargs, ["shape"])
    _kv_resolve_symbolic(input_kwargs, ["shape", "init_fn"], input_values,
                         intermediate_values)

  else:
    raise ValueError(f"op_type {op_type} not supported...")

  return new_op(op_name,
                op_type,
                input_names,
                input_kwargs,
                op_kwargs,
                num_outputs=op.num_outputs)