def test_rewire(self):
  """Checks depth inference over a rewired subgraph.

  Original model: conv, relu, pool, conv, relu, pool, flatten, dense, relu,
  dense. Rewired model: conv, relu, pool, conv, gelu, pool, flatten, dense,
  relu, dense (the second relu is replaced by a gelu).
  """
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv/1"]),
          output_names=["conv_layer1/relu"]),
  ]
  rewired = replace_subgraph(self.graph, subgraph_spec)
  subgraph_model = SubgraphModel(rewired, self.constants, {}, {},
                                 subgraph_spec)
  depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

  # The only subgraph input is the pooling output; it should reach both the
  # gelu output and the rewired relu alias at depth 1.
  self.assertLen(depth_map, 1)
  self.assertIn("conv_layer0/avg_pool:0", depth_map)
  depths_from_pool = depth_map["conv_layer0/avg_pool:0"]
  self.assertLen(depths_from_pool, 2)
  self.assertIn("conv_layer1/relu:0", depths_from_pool)
  self.assertEqual(depths_from_pool["conv_layer1/relu:0"], 1)
  self.assertIn("conv_layer1/gelu/1:0", depths_from_pool)
  self.assertEqual(depths_from_pool["conv_layer1/gelu/1:0"], 1)
def test_rewire(self):
  """Checks shape inference over a rewired subgraph (relu -> gelu swap)."""
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv/1"]),
          output_names=["conv_layer1/relu"]),
  ]
  rewired = replace_subgraph(self.graph, subgraph_spec)
  state = Model(rewired, self.constants).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 32, 32, 3))})
  subgraph_model = SubgraphModel(rewired, self.constants, state,
                                 {"input": jnp.ones((5, 32, 32, 3))},
                                 subgraph_spec)
  inferred = shape.ShapeProperty().infer(subgraph_model)

  # One input (the pooling output) and two outputs: the gelu tensor itself
  # plus the rewired relu alias.
  self.assertLen(inferred.input_shapes, 1)
  self.assertIn("conv_layer0/avg_pool:0", inferred.input_shapes)
  self.assertLen(inferred.output_shapes, 2)
  self.assertIn("conv_layer1/gelu/1:0", inferred.output_shapes)
  self.assertIn("conv_layer1/relu:0", inferred.output_shapes)
def test_abstract_sequential_synthesizer_output_features(self):
  """Checks that the synthesizer infers the output-feature multiplier.

  The replacement conv uses the symbolic feature spec "S:-1*2" (double the
  incoming feature count), so the synthesizer should report a multiplier of
  2 and a divisor of 1.
  """
  graph, constants, _ = cnn.CifarNet()
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": "S:-1*2",
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/relu",
              op_type=OpType.RELU,
              input_names=["conv_layer1/conv"]),
          output_names=["conv_layer1/relu"]),
  ]
  rewired = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(rewired, constants, None,
                                 {"input": jnp.zeros((5, 32, 32, 10))},
                                 subgraph_spec)
  shape_prop = shape.ShapeProperty().infer(subgraph_model)
  synthesizer = TestSequentialSynthesizer([(subgraph_model, [shape_prop])], 0)
  self.assertEqual(synthesizer.output_features_mul, 2)
  self.assertEqual(synthesizer.output_features_div, 1)
def test_abstract_sequential_synthesizer_fail(self):
  """The synthesizer must reject a subgraph whose input is not unique."""
  graph, constants, _ = cnn.CifarNet()
  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/conv/1",
              op_type=OpType.CONV,
              op_kwargs={
                  "features": 64,
                  "kernel_size": [1, 1]
              },
              input_names=["conv_layer0/avg_pool"]),
          output_names=["conv_layer1/conv"]),
      SubgraphNode(
          op=new_op(
              op_name="conv_layer1/gelu/1",
              op_type=OpType.GELU,
              input_names=["conv_layer1/conv"]),
          output_names=["conv_layer1/relu"]),
  ]
  subgraph_model = SubgraphModel(graph, constants, None,
                                 {"input": jnp.zeros((5, 32, 32, 10))},
                                 subgraph_spec)
  # Constructing the synthesizer on this spec must raise.
  self.assertRaisesRegex(ValueError, ".*exactly one input.*",
                         TestSequentialSynthesizer, [(subgraph_model, [])], 0)
def test_multi_input(self):
  """Checks depth inference on a subgraph with several inputs and outputs.

  Three parallel dense/relu branches feed two adds; the subgraph covers the
  relus and the adds, so each dense output is a subgraph input.
  """
  ops = []
  for i in range(3):
    ops.append(
        new_op(
            op_name=f"dense{i}",
            op_type=OpType.DENSE,
            op_kwargs={"features": 32},
            input_names=["input"]))
    ops.append(
        new_op(
            op_name=f"relu{i}", op_type=OpType.RELU,
            input_names=[f"dense{i}"]))
  ops.append(
      new_op(op_name="add0", op_type=OpType.ADD,
             input_names=["relu0", "relu1"]))
  ops.append(
      new_op(op_name="add1", op_type=OpType.ADD,
             input_names=["relu1", "relu2"]))
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name=f"relu{i}", op_type=OpType.RELU,
              input_names=[f"dense{i}"])) for i in range(3)
  ]
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add0", op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]))
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add1", op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]))

  replaced_graph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, {}, subgraph_spec)
  depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

  self.assertLen(depth_map, 3)
  for i in range(3):
    self.assertIn(f"dense{i}:0", depth_map)
  # dense0 reaches only add0; dense1 reaches both adds; dense2 only add1.
  self.assertLen(depth_map["dense0:0"], 1)
  self.assertEqual(depth_map["dense0:0"]["add0:0"], 2)
  self.assertLen(depth_map["dense1:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add0:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add1:0"], 2)
  self.assertLen(depth_map["dense2:0"], 1)
  self.assertEqual(depth_map["dense2:0"]["add1:0"], 2)
def test_multi_input(self):
  """Checks linear-operator pairing inference on a multi-input subgraph.

  Three parallel dense/relu branches feed two adds; each add output should
  pair with exactly the two dense inputs feeding it.
  """
  ops = []
  for i in range(3):
    ops.append(
        new_op(
            op_name=f"dense{i}",
            op_type=OpType.DENSE,
            op_kwargs={"features": 32},
            input_names=["input"]))
    ops.append(
        new_op(
            op_name=f"relu{i}", op_type=OpType.RELU,
            input_names=[f"dense{i}"]))
  ops.append(
      new_op(op_name="add0", op_type=OpType.ADD,
             input_names=["relu0", "relu1"]))
  ops.append(
      new_op(op_name="add1", op_type=OpType.ADD,
             input_names=["relu1", "relu2"]))
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name=f"relu{i}", op_type=OpType.RELU,
              input_names=[f"dense{i}"])) for i in range(3)
  ]
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add0", op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]))
  subgraph_spec.append(
      SubgraphNode(
          op=new_op(
              op_name="add1", op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]))

  replaced_graph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(replaced_graph, {}, {},
                                 {"input": jnp.ones((10, 32, 32, 3))},
                                 subgraph_spec)
  pairings = linear.LinopProperty().infer(subgraph_model).pairings

  self.assertLen(pairings, 2)
  self.assertIn("add0:0", pairings)
  self.assertLen(pairings["add0:0"], 2)
  self.assertIn("dense0:0", pairings["add0:0"])
  self.assertIn("dense1:0", pairings["add0:0"])
  self.assertIn("add1:0", pairings)
  self.assertLen(pairings["add1:0"], 2)
  self.assertIn("dense1:0", pairings["add1:0"])
  self.assertIn("dense2:0", pairings["add1:0"])
def select_subgraph(self, graph):
  """Selects a random connected subgraph of `graph` for mutation.

  Grows a subgraph by repeatedly sampling a neighboring op, until either
  the maximum length is reached, a coin flip (probability 1 - self.p)
  stops the growth, or no neighbors remain.

  Args:
    graph: the graph to sample a subgraph from.

  Returns:
    A list of SubgraphNodes (in the same order as graph.ops, which
    preserves the topological sort), with output_names set on the tensors
    that are visible outside the subgraph.
  """
  ops_to_add = list(graph.ops)
  subgraph_ops = []
  # Sets for O(1) membership tests during the neighbor scans below.
  produced_outputs = set()
  consumed_inputs = set()
  max_len = self.max_len
  if max_len <= 0:
    max_len = math.ceil(self.max_perc * len(graph.ops))
    max_len = max(max_len, 1)

  # Sampling loop.
  while True:
    # Select a random op and add it to the current subgraph.
    op = random.choice(ops_to_add)
    subgraph_ops.append(op)

    # Update the outputs produced and the inputs consumed by the current
    # subgraph.
    for idx in range(op.num_outputs):
      produced_outputs.add(f"{op.name}:{idx}")
    for input_name in op.input_names:
      consumed_inputs.add(input_name)

    # Find all ops which are neighbors of subgraph_ops (the current
    # subgraph).
    ops_to_add = []
    for op in graph.ops:
      # Skip any ops which are already in the subgraph.
      if op in subgraph_ops:
        continue
      # If any of the outputs of op are consumed by the subgraph, it is a
      # neighbor.
      for idx in range(op.num_outputs):
        if f"{op.name}:{idx}" in consumed_inputs:
          ops_to_add.append(op)
          break
      if op in ops_to_add:
        continue
      # If any of the inputs of op are produced by the subgraph, it is a
      # neighbor.
      for input_name in op.input_names:
        if input_name in produced_outputs:
          ops_to_add.append(op)
          break

    # Break if maximum length of subgraph has been reached.
    if max_len > 0 and len(subgraph_ops) >= max_len:
      break
    # Break if the subgraph has no neighbors left to absorb; otherwise the
    # next random.choice would raise IndexError on an empty list.
    if not ops_to_add:
      break
    # Break with probability (1-p).
    if random.random() > self.p:
      break

  # Get all externally visible outputs, i.e., all tensors which are produced
  # by ops inside the subgraph, and consumed by ops outside the subgraph
  # (or which are outputs of the whole graph).
  externally_visible_outputs = {
      canonicalize_tensor_name(n) for n in graph.output_names
  }
  for op in graph.ops:
    if op in subgraph_ops:
      continue
    for input_name in op.input_names:
      if input_name in produced_outputs:
        externally_visible_outputs.add(input_name)

  # Create the subgraph spec.
  # N.B. adding the subgraph_ops in order by graph.ops preserves the
  # topological sort.
  subgraph_spec = []
  for op in graph.ops:
    if op not in subgraph_ops:
      continue
    output_names = []
    for idx in range(op.num_outputs):
      if f"{op.name}:{idx}" in externally_visible_outputs:
        output_names.append(f"{op.name}:{idx}")
      else:
        output_names.append(None)
    subg_node = SubgraphNode(op, output_names=output_names)
    subgraph_spec.append(subg_node)
  return subgraph_spec
def select_sequential_subgraph(self, graph):
  """Selects a random sequential (chain-shaped) subgraph of `graph`."""
  max_len = self.max_len
  if max_len <= 0:
    max_len = math.ceil(self.max_perc * len(graph.ops))
    max_len = max(max_len, 1)
  if self.p == 0:
    # Uniform random length from [1, max_len].
    max_len = random.randrange(max_len) + 1

  # Restrict to ops with a single input and a single output, or a single
  # input and multiple outputs when that is allowed; then drop unsupported
  # op types.
  candidates = [
      op for op in graph.ops
      if (len(op.input_names) == 1 and
          (op.num_outputs == 1 or self.allow_multi_outputs))
  ]
  candidates = [op for op in candidates if op.type not in UNSUPPORTED_OP_TYPES]
  if not candidates:
    raise ValueError("No sequential subgraphs exist.")

  # Seed the chain with a random candidate op.
  seed = random.choice(candidates)
  chain = [seed]
  seed_idx = candidates.index(seed)
  right = seed_idx + 1
  right_ok = right < len(candidates)
  left = seed_idx - 1
  left_ok = left >= 0

  # Grow the chain left and/or right until a stop condition fires.
  while True:
    # Stop once the maximum chain length is reached.
    if max_len > 0 and len(chain) >= max_len:
      break
    # Stop with probability (1 - p).
    if self.p > 0 and random.random() > self.p:
      break

    # The left candidate must exist and its output must feed the head.
    if left_ok:
      if left < 0 or f"{candidates[left].name}:0" != chain[0].input_names[0]:
        left_ok = False
    # The right candidate must exist and consume the tail's output.
    if right_ok:
      if (right >= len(candidates) or
          f"{chain[-1].name}:0" != candidates[right].input_names[0]):
        right_ok = False
    if not left_ok and not right_ok:
      break

    # When both directions are viable, flip a coin; otherwise take the one
    # that is.
    if left_ok and right_ok:
      grow_right = random.random() < 0.5
    else:
      grow_right = right_ok
    if grow_right:
      chain.append(candidates[right])
      right += 1
    else:
      chain.insert(0, candidates[left])
      left -= 1

  subgraph_spec = [SubgraphNode(op) for op in chain]
  tail = chain[-1]

  # Select one of the tail op's outputs to rewire; only outputs that are
  # actually consumed (by another op or as a graph output) are eligible.
  usable = set()
  for idx in range(tail.num_outputs):
    tensor = f"{tail.name}:{idx}"
    for consumer in graph.ops:
      if tensor in consumer.input_names:
        usable.add(idx)
  for output_name in graph.output_names:
    if output_name == tail.name:
      # A bare op name refers to its first output.
      usable.add(0)
    else:
      for idx in range(tail.num_outputs):
        if output_name == f"{tail.name}:{idx}":
          usable.add(idx)

  chosen = random.choice(list(usable))
  subgraph_spec[-1].output_names = [f"{tail.name}:{chosen}"]
  return subgraph_spec