def _update_pairings(self, op, in_shapes):
  """Updates pairings with the property for a single op.

  Args:
    op: The op for which to infer the pairing property.
    in_shapes: The shapes of the input tensors.
  """
  assert len(op.input_names) == len(in_shapes)
  input_values = {
      input_name: jnp.ones(in_shape)
      for input_name, in_shape in zip(op.input_names, in_shapes)
  }
  output_names = [f"{op.name}:{i}" for i in range(op.num_outputs)]
  graph = new_graph(op.input_names, output_names, [op])
  model = Model(graph)
  state = model.init(jax.random.PRNGKey(0), input_values)
  pairings = GraphPairings.infer(
      model, input_values, state, abstract=False).pairings
  new_pairings = {}
  for output_idx, output_name in enumerate(output_names):
    new_pairings[output_idx] = {}
    for input_idx, input_name in enumerate(op.input_names):
      new_pairings[output_idx][input_idx] = pairings[output_name][input_name]
  key = self.hash(op, in_shapes)
  self.pairings[key] = new_pairings
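# Illustrative sketch only: the cache built by _update_pairings maps
# hash(op, in_shapes) -> {output_idx: {input_idx: pairing}}, so a
# hypothetical read-side helper (not part of the module above) would be:
#
#   def _lookup_pairing(self, op, in_shapes, output_idx, input_idx):
#     key = self.hash(op, in_shapes)
#     return self.pairings[key][output_idx][input_idx]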
def test_params_b1(self):
  graph, constants, _ = efficientnet.EfficientNetB1(num_classes=1000)
  model = Model(graph, constants)
  state = model.init(random.PRNGKey(0), jnp.ones((1, 240, 240, 3)))
  params = state["params"]
  num_params = compute_num_params(params)
  self.assertEqual(num_params, EFFICIENTNET_PARAMS["b1"])
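# For reference: a parameter count with the semantics the test above relies
# on (sum of element counts over all parameter leaves) can be written with
# standard JAX tree utilities. This is a sketch of compute_num_params'
# presumed behavior, not necessarily the repo's implementation.
import jax


def num_params_sketch(params):
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))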
def make_subgraph_models(self, subgraph_spec, graphs=None):
  """Inserts the new subgraph_spec into the subgraph_models.

  The graphs argument can be used to pass in intermediate results of
  synthesis rather than completing synthesis all at once, e.g., we can call
  make_subgraph_models twice and pass the output of the first call as an
  argument to the second call. This allows us to break the synthesis down
  into multiple steps.

  Args:
    subgraph_spec: The new ops to insert.
    graphs: The graphs into which to insert the new ops.

  Returns:
    A sequence of subgraphs with subgraph_spec inserted.
  """
  new_subgraph_models = []
  if not graphs:
    graphs = [None] * len(self.subgraphs_and_props)
  for graph, (subgraph_model, _) in zip(graphs, self.subgraphs_and_props):
    graph = graph if graph else subgraph_model.graph
    constants = subgraph_model.constants
    state = subgraph_model.state
    inputs = subgraph_model.inputs
    new_graph = subgraph.replace_subgraph(graph, subgraph_spec)
    if not self.abstract:
      # Concrete synthesis initializes the state while inheriting the
      # parent params.
      new_model = Model(new_graph, constants)
      try:
        new_state = new_model.init(random.PRNGKey(0), inputs)
      except Exception as e:  # pylint: disable=broad-except
        # Catch everything else for now... this is the safest way to filter
        # out malformed subgraphs which will not initialize.
        exc_type, exc_value, exc_traceback = sys.exc_info()
        logging.info(
            "%s", "".join(
                traceback.format_exception(exc_type, exc_value,
                                           exc_traceback)))
        raise ValueError(
            "Could not initialize malformed subgraph "
            f"({type(e).__name__}: {e}).") from e
      new_state = flax.core.unfreeze(new_state)
      inherited, frozen = subgraph.inherit_params(
          new_state["params"], state["params"])
      new_state = {**inherited, **frozen}
    else:
      new_state = None
    new_subgraph_model = SubgraphModel(new_graph, constants, new_state,
                                       inputs, subgraph_spec)
    new_subgraph_models.append(new_subgraph_model)
  return new_subgraph_models
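# Hypothetical usage sketch of the two-step synthesis described in the
# docstring above (`synthesizer`, `spec_a`, and `spec_b` are assumed names):
# the graphs produced by the first call are threaded back in through the
# `graphs` argument of the second call.
#
#   models_a = synthesizer.make_subgraph_models(spec_a)
#   graphs_a = [m.graph for m in models_a]
#   models_ab = synthesizer.make_subgraph_models(spec_b, graphs=graphs_a)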
def test_inference(self):
  graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                          input_resolution="small")
  model = Model(graph, constants)
  state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
  self.assertLen(state, 2)
  self.assertIn("params", state)
  self.assertIn("batch_stats", state)

  out = model.apply(state, {"input": jnp.ones((10, 32, 32, 3))})["fc/dense"]
  self.assertEqual(out.shape, (10, 10))

  output_dict, new_state = model.apply(
      state, {"input": jnp.ones((10, 32, 32, 3))}, mutable=["batch_stats"])
  self.assertEqual(output_dict["fc/dense"].shape, (10, 10))
  self.assertIn("batch_stats", new_state)
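# The mutable=["batch_stats"] call above is the standard flax.linen protocol:
# apply() returns an (outputs, updated_variables) pair whenever a variable
# collection is marked mutable. A self-contained illustration, independent of
# this repo's Model wrapper:
import flax.linen as nn
import jax
import jax.numpy as jnp


class BatchNormLayer(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.BatchNorm(use_running_average=False)(x)


bn_vars = BatchNormLayer().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
y, updated = BatchNormLayer().apply(
    bn_vars, jnp.ones((8, 4)), mutable=["batch_stats"])
assert "batch_stats" in updated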
def test_inference(self):
  graph, constants, _ = efficientnet.EfficientNetB0(num_classes=10)
  for op in graph.ops:
    print(f"name={op.name}")
    print(f"input_names={op.input_names}")
    print()
  model = Model(graph, constants)
  state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
  self.assertLen(state, 2)
  self.assertIn("params", state)
  self.assertIn("batch_stats", state)

  inp = {"input": jnp.ones((10, 32, 32, 3))}
  out = model.apply(state, inp)["head/out"]
  self.assertEqual(out.shape, (10, 10))

  output_dict, new_state = model.apply(state, inp, mutable=["batch_stats"])
  self.assertEqual(output_dict["head/out"].shape, (10, 10))
  self.assertIn("batch_stats", new_state)
class ModelTest(test.TestCase):

  def setUp(self):
    super().setUp()
    graph, constants, _ = cnn.CifarNet()
    self.cnn = Model(graph, constants)
    self.cnn_state = self.cnn.init(random.PRNGKey(0),
                                   jnp.ones((1, 32, 32, 3)))

  def test_cnn_inference(self):
    y = self.cnn.apply(self.cnn_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

  def test_cnn_inference_dict(self):
    out = self.cnn.apply(self.cnn_state, {"input": jnp.ones((10, 32, 32, 3))})
    logits = out["fc/logits"]
    self.assertEqual(logits.shape, (10, 10))

  def test_cnn_params(self):
    params = flax.core.unfreeze(self.cnn_state)["params"]
    param_count = parameter_overview.count_parameters(params)
    self.assertEqual(param_count, 2192458)
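# The init/apply pattern used throughout these tests is standard flax.linen;
# a self-contained minimal example, independent of this repo's Model wrapper:
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyNet(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=10)(x)


variables = TinyNet().init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
logits = TinyNet().apply(variables, jnp.ones((4, 32)))  # shape (4, 10)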
class SubgraphTest(test.TestCase):

  def setUp(self):
    super().setUp()
    self.graph, self.constants, _ = cnn.CifarNet()
    self.model = Model(self.graph, self.constants)
    self.state = self.model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

    self.subgraph = [
        subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1]
                },
                input_names=["conv_layer0/avg_pool"])),
        subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/gelu/1",
                op_type=OpType.GELU,
                input_names=["conv_layer1/conv/1"]),
            output_names=["conv_layer1/relu"])
    ]
    self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
    self.new_model = Model(self.new_graph, self.constants)
    self.new_state = self.new_model.init(random.PRNGKey(0),
                                         jnp.ones((1, 32, 32, 3)))

  def test_subgraph_inserted(self):
    """Tests whether subgraph nodes were inserted."""
    for node in self.subgraph:
      found = False
      for op in self.new_graph.ops:
        if op.name == node.op.name:
          found = True
          break
      self.assertTrue(found, f"Did not find {node.op.name} in new graph")

  def test_subgraph_execution(self):
    """Tests whether the new graph can be executed."""
    y = self.new_model.apply(self.new_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

  def test_subgraph_pruning(self):
    """Tests whether the new graph was pruned of old nodes."""
    new_params = flax.core.unfreeze(self.new_state)["params"]
    new_param_count = parameter_overview.count_parameters(new_params)
    params = flax.core.unfreeze(self.state)["params"]
    param_count = parameter_overview.count_parameters(params)
    self.assertLess(new_param_count, param_count)

  def test_weight_inheritance(self):
    """Tests weight inheritance."""
    old_params = flax.core.unfreeze(self.state)["params"]
    new_params = flax.core.unfreeze(self.new_state)["params"]
    frozen_params, trainable_params = subgraph.inherit_params(
        new_params, old_params)
    self.assertLen(new_params, len(trainable_params) + len(frozen_params))
    for param in ["fc/dense", "fc/logits", "conv_layer0/conv"]:
      self.assertIn(param, frozen_params,
                    f"expected param {param} to be frozen")
    self.assertIn("conv_layer1/conv/1", trainable_params,
                  "expected param conv_layer1/conv/1 to be trainable")

  def test_multi_input_output(self):
    """Tests a subgraph substitution on a graph with multiple input / output ops.

    We use a ResNet model, which has skip connections. This test checks that
    the substitution produces the expected number of ops, and also that the
    newly produced graph is still executable.
    """
    graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                            input_resolution="small")
    model = Model(graph, constants)
    state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = model.apply(state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))

    subg = [
        subgraph.SubgraphNode(
            op=new_op(
                op_name="subgraph/conv0",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1]
                },
                input_names=["resnet11/skip/relu1"])),
        subgraph.SubgraphNode(
            op=new_op(
                op_name="subgraph/gelu1",
                op_type=OpType.GELU,
                input_names=["subgraph/conv0"]),
            output_names=["resnet_stride1_filtermul1_basic12/relu2"])
    ]
    new_graph = subgraph.replace_subgraph(graph, subg)
    # The subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu).
    self.assertLen(graph.ops, len(new_graph.ops) + 1)
    new_model = Model(new_graph, constants)
    new_state = new_model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
    y = new_model.apply(new_state, jnp.ones((10, 32, 32, 3)))
    self.assertEqual(y.shape, (10, 10))
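# A minimal sketch of the name-matching inheritance that
# test_weight_inheritance above exercises. This is an assumption about
# subgraph.inherit_params' behavior (params whose names survive the
# substitution are inherited and frozen; newly introduced ones remain
# trainable), not the repo's actual implementation.
def inherit_params_sketch(new_params, old_params):
  frozen = {name: old_params[name]
            for name in new_params if name in old_params}
  trainable = {name: param
               for name, param in new_params.items()
               if name not in old_params}
  return frozen, trainable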
class GraphSynthesizerTest(test.TestCase):

  def setUp(self):
    super().setUp()
    seed = int(time.time())
    logging.info("Seed: %d", seed)
    py_rand.seed(seed)
    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
    self.max_size = int(10e8)

  def _synthesize(self, subg, props):
    ctr = functools.partial(PSS, generation=0, max_len=3,
                            filter_progress=True)
    synthesizer = GraphSynthesizer([(subg, props)], sequential_ctr=ctr,
                                   generation=0)
    subg = synthesizer.synthesize()[0]
    subg_spec = subg.subgraph
    logging.info("synthesized...")
    for node in subg_spec:
      logging.info("%s", node.op.name)
      logging.info("%s", node.output_names)
      logging.info("")
    fingerprint_orig = fingerprint_graph(self.graph, self.constants,
                                         self.input)
    fingerprint_new = fingerprint_graph(subg.graph, self.constants,
                                        self.input)
    self.assertNotEqual(fingerprint_orig, fingerprint_new)

  def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)]."""
    py_rand.seed(10)
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])

  def test_synthesizer_one(self):
    """Replacing [conv3x3(features = 64), ReLU]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:6]]
    subg[-1].output_names = self.graph.ops[6].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
    subg[-1].output_names = self.graph.ops[7].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_resnet_small(self):
    self.graph, self.constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
    self.max_size = int(10e8)

    subg_ops = [self.graph.ops[4]] + self.graph.ops[9:11]
    subg = [subgraph.SubgraphNode(op=o) for o in subg_ops]
    subg[-1].output_names = [f"{subg[-1].op.name}:0"]
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_resnet_big(self):
    self.graph, self.constants, _ = resnetv1.ResNet18(
        num_classes=10, input_resolution="small")
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)[self.graph.output_names[0]]
    self.max_size = int(10e8)

    subg_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
    subg = [subgraph.SubgraphNode(op=o) for o in subg_ops]
    subg[-1].output_names = [f"{subg[-1].op.name}:0"]
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])
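# Sketch of one way a graph fingerprint could work; an assumption about
# fingerprint_graph's semantics (hash the concrete outputs of a forward
# pass, so two graphs compare unequal whenever their outputs differ), not
# the repo's actual implementation.
import hashlib

import numpy as np


def fingerprint_outputs_sketch(outputs):
  """Hashes a {name: array} dict such as the one Model.apply returns."""
  h = hashlib.sha256()
  for name in sorted(outputs):
    h.update(name.encode("utf-8"))
    h.update(np.asarray(outputs[name]).tobytes())
  return h.hexdigest()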
class ProgSequentialTest(test.TestCase):

  def setUp(self):
    super().setUp()
    seed = int(time.time())
    logging.info("Seed: %d", seed)
    py_rand.seed(seed)
    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)["fc/logits"]
    self.max_size = int(10e8)
    self.hard = False

  def _synthesize(self, subg, props):
    synthesizer = ProgressiveSequentialSynthesizer(
        [(subg, props)],
        generation=0,
        mode=ProgressiveSequentialSynthesizer.Mode.WEIGHTED,
        max_len=3)
    subg = synthesizer.synthesize()[0]
    subg_spec = subg.subgraph
    for node in subg_spec:
      print(node.op.name)
      print(node.output_names)
    m = Model(subg.graph, self.constants)
    state = m.init(random.PRNGKey(0), self.input)
    out = m.apply(state, self.input)["fc/logits"]
    self.assertTrue((out != self.out).any())

  def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    # lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])

  def test_synthesizer_one(self):
    """Replacing [conv3x3(features = 64), ReLU]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:6]]
    subg[-1].output_names = self.graph.ops[6].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    # lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
    subg[-1].output_names = self.graph.ops[7].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    # dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_hard(self):
    if not self.hard:
      return
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
    subg[-1].output_names = self.graph.ops[7].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])
class EnumSequentialTest(test.TestCase):

  def setUp(self):
    super().setUp()
    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}
    self.state = self.m.init(random.PRNGKey(0), self.input)
    self.out = self.m.apply(self.state, self.input)["fc/logits"]
    self.max_size = int(10e8)
    self.hard = False

  def test_sequence_generator(self):
    """Tests the sequence_generator function for correctness.

    seq_generator should generate [[0], [1], [2], [0,0], [0,1], [0,2],
    [1,0], [1,1], [1,2], ..., [0,0,0], [0,0,1], [0,0,2], [0,1,0],
    [0,1,1], ...].
    """

    def el_generator():
      i = 0
      while True:
        if i > 2:
          return
        yield i
        i += 1

    seqs = list(sequence_generator(el_generator, 3))
    self.assertLen(seqs, 3 + 3**2 + 3**3)
    for i in range(len(seqs)):
      if i < 3:
        self.assertEqual(seqs[i], [i])
      elif i < 3 + 3**2:
        self.assertEqual(seqs[i], [(i - 3) // 3, i % 3])
      else:
        self.assertEqual(seqs[i],
                         [(i - 12) // 3 // 3, (i - 12) // 3 % 3, i % 3])

  def test_kwargs_for_op_to_product(self):
    op_kwargs = {"a": [1, 2, 3], "b": [1, 2], "c": [5, 6]}
    input_kwargs = {"d": [1], "e": [1, 2], "f": [3, 4]}
    product = EnumerativeSequentialSynthesizer.kwargs_for_op_to_product(
        op_kwargs, input_kwargs)
    expected_length = 1
    for _, v in op_kwargs.items():
      expected_length *= len(v)
    for _, v in input_kwargs.items():
      expected_length *= len(v)
    self.assertLen(product, expected_length)
    op_setting = {"a": 2, "b": 2, "c": 5}
    input_setting = {"d": 1, "e": 2, "f": 3}
    self.assertIn((op_setting, input_setting), product)

  def _synthesize(self, subg, props):
    synthesizer = EnumerativeSequentialSynthesizer([(subg, props)], 0,
                                                   max_len=3)
    subg = synthesizer.synthesize()[0]
    m = Model(subg.graph, self.constants)
    state = m.init(random.PRNGKey(0), self.input)
    out = m.apply(state, self.input)["fc/logits"]
    self.assertTrue((out != self.out).any())

  def test_synthesizer_easy_one(self):
    """Replacing [conv3x3(features = 64)].

    Because we do not test linear, this is replaced by
    dense3x3(features = 64) due to the enumeration order.
    """
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_easy_two(self):
    """Replacing [conv3x3(features = 64)].

    Because we test all three props, this is replaced by
    conv3x3(features = 64) (i.e., an identical op) due to the enumeration
    order.
    """
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:5]]
    subg[-1].output_names = self.graph.ops[5].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])

  def test_synthesizer_one(self):
    """Replacing [conv3x3(features = 64), ReLU].

    Because we do not check for the linear property,
    [dense(features = 64), ReLU] works as well (which is what is
    synthesized due to the enumeration order).
    """
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:6]]
    subg[-1].output_names = self.graph.ops[6].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    # lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp])

  def test_synthesizer_two(self):
    """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

    Because we do not check for the depth property,
    [dense(features = 64), avgpool2x2(strides=2x2)] works as well (which is
    what is synthesized due to the enumeration order).
    """
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
    subg[-1].output_names = self.graph.ops[7].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, lp])

  def test_synthesizer_hard(self):
    if not self.hard:
      return
    subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
    subg[-1].output_names = self.graph.ops[7].input_names
    subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                   self.input, subg)
    sp = shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size)
    dp = depth.DepthProperty().infer(subgraph_model)
    lp = linear.LinopProperty().infer(subgraph_model)
    self._synthesize(subgraph_model, [sp, dp, lp])