Esempio n. 1
0
  def _update_pairings(self, op, in_shapes):
    """Infers the pairing property of a single op and caches the result.

    The op is wrapped into a standalone one-op graph whose model is initialized
    on all-ones inputs of the requested shapes. The pairings reported by
    `GraphPairings.infer` are keyed by tensor name; they are re-keyed here by
    (output index, input index) and stored in `self.pairings` under the hash of
    the op and its input shapes.

    Args:
      op: The op for which to infer the pairing property.
      in_shapes: The shapes of the input tensors, one per entry of
        `op.input_names`.
    """
    assert len(op.input_names) == len(in_shapes)

    # Every input is fed a tensor of ones with the requested shape.
    input_values = {}
    for name, dims in zip(op.input_names, in_shapes):
      input_values[name] = jnp.ones(dims)

    output_names = [f"{op.name}:{i}" for i in range(op.num_outputs)]
    model = Model(new_graph(op.input_names, output_names, [op]))
    state = model.init(jax.random.PRNGKey(0), input_values)
    inferred = GraphPairings.infer(
        model, input_values, state, abstract=False).pairings

    # Re-key the name-based pairings by position.
    positional = {
        out_idx: {
            in_idx: inferred[out_name][in_name]
            for in_idx, in_name in enumerate(op.input_names)
        }
        for out_idx, out_name in enumerate(output_names)
    }

    key = self.hash(op, in_shapes)
    self.pairings[key] = positional
Esempio n. 2
0
 def test_params_b1(self):
     """Checks the parameter count of EfficientNet-B1 against the reference."""
     graph, constants, _ = efficientnet.EfficientNetB1(num_classes=1000)
     dummy_input = jnp.ones((1, 240, 240, 3))
     variables = Model(graph, constants).init(random.PRNGKey(0), dummy_input)
     self.assertEqual(compute_num_params(variables["params"]),
                      EFFICIENTNET_PARAMS["b1"])
Esempio n. 3
0
    def make_subgraph_models(self, subgraph_spec, graphs=None):
        """Inserts the new subgraph_spec into the subgraph_models.

        The graphs argument can be used to pass in intermediate results of
        synthesis rather than completing synthesis all at once, e.g., we can
        call make_subgraph_models twice and pass the output of the first call
        as an argument to the second call. This allows us to break the
        synthesis down into multiple steps.

        Args:
          subgraph_spec: The new ops to insert.
          graphs: The graphs into which to insert the new ops. When omitted or
            empty, the graph of each parent subgraph model is used instead;
            the same fallback applies to any individual falsy entry.

        Returns:
          A list of SubgraphModel instances, one per parent subgraph model,
          with subgraph_spec inserted.

        Raises:
          ValueError: If synthesis is concrete (self.abstract is falsy) and a
            model built from the new graph fails to initialize. The original
            exception is attached as the cause.
        """
        new_subgraph_models = []
        if not graphs:
            graphs = [None] * len(self.subgraphs_and_props)
        for graph, (subgraph_model, _) in zip(graphs,
                                              self.subgraphs_and_props):
            graph = graph if graph else subgraph_model.graph
            constants = subgraph_model.constants
            state = subgraph_model.state
            inputs = subgraph_model.inputs

            new_graph = subgraph.replace_subgraph(graph, subgraph_spec)
            if not self.abstract:
                # Concrete synthesis initializes the state while inheriting the
                # parent params.
                new_model = Model(new_graph, constants)
                try:
                    new_state = new_model.init(random.PRNGKey(0), inputs)
                except Exception as e:  # pylint: disable=broad-except
                    # Catch everything for now: this is the safest way to
                    # filter out malformed subgraphs which will not initialize.
                    logging.info("%s", traceback.format_exc())
                    raise ValueError(
                        "Could not initialize malformed subgraph "
                        f"({type(e).__name__}: {e}).") from e

                new_state = flax.core.unfreeze(new_state)
                inherited, frozen = subgraph.inherit_params(
                    new_state["params"], state["params"])
                # NOTE(review): the merged dict replaces the whole state, so
                # it has no top-level "params" entry unless inherit_params
                # returns one -- confirm this is what consumers of
                # SubgraphModel.state expect.
                new_state = {**inherited, **frozen}
            else:
                # Abstract synthesis carries no concrete state.
                new_state = None
            new_subgraph_model = SubgraphModel(new_graph, constants, new_state,
                                               inputs, subgraph_spec)
            new_subgraph_models.append(new_subgraph_model)
        return new_subgraph_models
    def _synthesize(self, subg, props):
        """Synthesizes a replacement for subg and checks the logits changed.

        Args:
          subg: The subgraph model to replace.
          props: The properties the synthesized subgraph has to satisfy.
        """
        enumerator = EnumerativeSequentialSynthesizer([(subg, props)],
                                                      0,
                                                      max_len=3)
        synthesized = enumerator.synthesize()[0]

        # Re-initialize the synthesized graph and compare its logits against
        # the reference output stored on the test instance.
        model = Model(synthesized.graph, self.constants)
        variables = model.init(random.PRNGKey(0), self.input)
        logits = model.apply(variables, self.input)["fc/logits"]
        differs = (logits != self.out).any()
        self.assertTrue(differs)
Esempio n. 5
0
    def test_multi_input_output(self):
        """Tests a substitution on a graph with multiple input / output ops.

        We use a ResNet model, which has skip connections. This test checks
        that the substitution produces the expected number of ops, and also
        that the newly produced graph is still executable.
        """
        init_input = jnp.ones((1, 32, 32, 3))
        batch = jnp.ones((10, 32, 32, 3))

        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        baseline = Model(graph, constants)
        baseline_state = baseline.init(random.PRNGKey(0), init_input)
        self.assertEqual(baseline.apply(baseline_state, batch).shape, (10, 10))

        conv_node = subgraph.SubgraphNode(
            op=new_op(op_name="subgraph/conv0",
                      op_type=OpType.CONV,
                      op_kwargs={
                          "features": 64,
                          "kernel_size": [1, 1]
                      },
                      input_names=["resnet11/skip/relu1"]))
        gelu_node = subgraph.SubgraphNode(
            op=new_op(op_name="subgraph/gelu1",
                      op_type=OpType.GELU,
                      input_names=["subgraph/conv0"]),
            output_names=["resnet_stride1_filtermul1_basic12/relu2"])
        replaced_graph = subgraph.replace_subgraph(graph,
                                                   [conv_node, gelu_node])

        # The subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn /
        # relu), so the new graph is exactly one op shorter.
        self.assertLen(graph.ops, len(replaced_graph.ops) + 1)

        # The rewritten graph must still initialize and run end to end.
        replaced = Model(replaced_graph, constants)
        replaced_state = replaced.init(random.PRNGKey(0), init_input)
        self.assertEqual(replaced.apply(replaced_state, batch).shape, (10, 10))
    def _synthesize(self, subg, props):
        """Synthesizes a replacement for subg and checks the logits changed.

        The name and output names of every synthesized node are printed.

        Args:
          subg: The subgraph model to replace.
          props: The properties the synthesized subgraph has to satisfy.
        """
        weighted = ProgressiveSequentialSynthesizer.Mode.WEIGHTED
        progressive = ProgressiveSequentialSynthesizer([(subg, props)],
                                                       generation=0,
                                                       mode=weighted,
                                                       max_len=3)
        synthesized = progressive.synthesize()[0]

        for synthesized_node in synthesized.subgraph:
            print(synthesized_node.op.name)
            print(synthesized_node.output_names)

        # Re-initialize the synthesized graph and compare its logits against
        # the reference output stored on the test instance.
        model = Model(synthesized.graph, self.constants)
        variables = model.init(random.PRNGKey(0), self.input)
        logits = model.apply(variables, self.input)["fc/logits"]
        differs = (logits != self.out).any()
        self.assertTrue(differs)
Esempio n. 7
0
    def test_inference(self):
        """Runs ResNet18 with and without a mutable batch_stats collection."""
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        variables = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

        # The initialized state holds exactly two collections.
        self.assertLen(variables, 2)
        for collection in ("params", "batch_stats"):
            self.assertIn(collection, variables)

        batch = {"input": jnp.ones((10, 32, 32, 3))}

        # Without mutable collections, apply yields only the output dict.
        outputs = model.apply(variables, batch)
        self.assertEqual(outputs["fc/dense"].shape, (10, 10))

        # With a mutable collection, apply also yields the updated state.
        outputs, updated_state = model.apply(variables,
                                             batch,
                                             mutable=["batch_stats"])
        self.assertEqual(outputs["fc/dense"].shape, (10, 10))
        self.assertIn("batch_stats", updated_state)
Esempio n. 8
0
    def test_inference(self):
        """Runs EfficientNetB0 with and without a mutable batch_stats collection.

        The name and input names of every op in the graph are printed first.
        """
        graph, constants, _ = efficientnet.EfficientNetB0(num_classes=10)
        for graph_op in graph.ops:
            print(f"name={graph_op.name}")
            print(f"input_names={graph_op.input_names}")
        print()

        model = Model(graph, constants)
        variables = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

        # The initialized state holds exactly two collections.
        self.assertLen(variables, 2)
        for collection in ("params", "batch_stats"):
            self.assertIn(collection, variables)

        batch = {"input": jnp.ones((10, 32, 32, 3))}

        # Without mutable collections, apply yields only the output dict.
        outputs = model.apply(variables, batch)
        self.assertEqual(outputs["head/out"].shape, (10, 10))

        # With a mutable collection, apply also yields the updated state.
        outputs, updated_state = model.apply(variables,
                                             batch,
                                             mutable=["batch_stats"])
        self.assertEqual(outputs["head/out"].shape, (10, 10))
        self.assertIn("batch_stats", updated_state)
Esempio n. 9
0
class ModelTest(test.TestCase):
    """Inference and parameter-count tests for Model on the CifarNet graph."""

    def setUp(self):
        super().setUp()

        cifar_graph, cifar_constants, _ = cnn.CifarNet()
        self.cnn = Model(cifar_graph, cifar_constants)
        init_input = jnp.ones((1, 32, 32, 3))
        self.cnn_state = self.cnn.init(random.PRNGKey(0), init_input)

    def test_cnn_inference(self):
        """A bare array input yields the output array directly."""
        batch = jnp.ones((10, 32, 32, 3))
        result = self.cnn.apply(self.cnn_state, batch)
        self.assertEqual(result.shape, (10, 10))

    def test_cnn_inference_dict(self):
        """A dict input yields a dict of outputs keyed by name."""
        batch = {"input": jnp.ones((10, 32, 32, 3))}
        outputs = self.cnn.apply(self.cnn_state, batch)
        self.assertEqual(outputs["fc/logits"].shape, (10, 10))

    def test_cnn_params(self):
        """The initialized CifarNet has the expected number of parameters."""
        unfrozen = flax.core.unfreeze(self.cnn_state)
        total = parameter_overview.count_parameters(unfrozen["params"])
        self.assertEqual(total, 2192458)
Esempio n. 10
0
class SubgraphTest(test.TestCase):
    """Tests for subgraph replacement and parameter inheritance."""

    def setUp(self):
        super().setUp()

        # Reference model: the unmodified CifarNet.
        self.graph, self.constants, _ = cnn.CifarNet()
        self.model = Model(self.graph, self.constants)
        self.state = self.model.init(random.PRNGKey(0), jnp.ones(
            (1, 32, 32, 3)))

        # Replacement subgraph: a 1x1 conv followed by a GELU, reading from
        # "conv_layer0/avg_pool" and standing in for "conv_layer1/relu".
        self.subgraph = [
            subgraph.SubgraphNode(op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1]
                },
                input_names=["conv_layer0/avg_pool"])),
            subgraph.SubgraphNode(op=new_op(op_name="conv_layer1/gelu/1",
                                            op_type=OpType.GELU,
                                            input_names=["conv_layer1/conv/1"
                                                         ]),
                                  output_names=["conv_layer1/relu"])
        ]
        self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
        self.new_model = Model(self.new_graph, self.constants)
        self.new_state = self.new_model.init(random.PRNGKey(0),
                                             jnp.ones((1, 32, 32, 3)))

    def test_subgraph_inserted(self):
        """Tests whether subgraph nodes were inserted."""
        new_op_names = {op.name for op in self.new_graph.ops}
        for node in self.subgraph:
            self.assertIn(node.op.name, new_op_names,
                          f"Did not find {node.op.name} in new graph")

    def test_subgraph_execution(self):
        """Tests whether new graph can be executed."""
        y = self.new_model.apply(self.new_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))

    def test_subgraph_pruning(self):
        """Tests whether new graph was pruned of old nodes."""
        new_params = flax.core.unfreeze(self.new_state)["params"]
        new_param_count = parameter_overview.count_parameters(new_params)

        params = flax.core.unfreeze(self.state)["params"]
        param_count = parameter_overview.count_parameters(params)
        self.assertLess(new_param_count, param_count)

    def test_weight_inheritance(self):
        """Tests weight inheritance."""
        old_params = flax.core.unfreeze(self.state)["params"]
        new_params = flax.core.unfreeze(self.new_state)["params"]

        frozen_params, trainable_params = subgraph.inherit_params(
            new_params, old_params)

        # Every new param ends up in exactly one of the two groups.
        self.assertLen(new_params, len(trainable_params) + len(frozen_params))

        # Use assertIn rather than a bare assert so the checks survive
        # `python -O` and report like the other assertions in this class.
        for param in ["fc/dense", "fc/logits", "conv_layer0/conv"]:
            self.assertIn(param, frozen_params,
                          f"expected param {param} to be frozen")
        self.assertIn("conv_layer1/conv/1", trainable_params,
                      "expected param conv_layer1/conv/1 to be trainable")

    def test_multi_input_output(self):
        """Tests a substitution on a graph with multiple input / output ops.

        We use a ResNet model, which has skip connections. This test checks
        that the substitution produces the expected number of ops, and also
        that the newly produced graph is still executable.
        """
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        y = model.apply(state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))

        subg = [
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/conv0",
                          op_type=OpType.CONV,
                          op_kwargs={
                              "features": 64,
                              "kernel_size": [1, 1]
                          },
                          input_names=["resnet11/skip/relu1"])),
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/gelu1",
                          op_type=OpType.GELU,
                          input_names=["subgraph/conv0"]),
                output_names=["resnet_stride1_filtermul1_basic12/relu2"])
        ]
        new_graph = subgraph.replace_subgraph(graph, subg)

        # The subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn /
        # relu), so the new graph is exactly one op shorter.
        self.assertLen(graph.ops, len(new_graph.ops) + 1)

        new_model = Model(new_graph, constants)
        new_state = new_model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

        y = new_model.apply(new_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))
Esempio n. 11
0
class GraphSynthesizerTest(test.TestCase):
    """End-to-end tests for GraphSynthesizer on CifarNet and ResNet18."""

    def setUp(self):
        super().setUp()

        # Seed from the clock and log the seed that was used.
        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        graph, constants, _ = cnn.CifarNet()
        self._use_network(graph, constants)

    def _use_network(self, graph, constants):
        """Makes the given network the one under test.

        Stores the graph and constants, builds and initializes the model on a
        batch of ones, and records the first graph output as the reference.
        """
        self.graph, self.constants = graph, constants
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        outputs = self.m.apply(self.state, self.input)
        self.out = outputs[self.graph.output_names[0]]
        self.max_size = int(10e8)

    def _wrap(self, nodes):
        """Bundles subgraph nodes with the current network into a model."""
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, nodes)

    def _contiguous_subgraph_model(self, stop):
        """Selects ops[4:stop], feeding the inputs of ops[stop]."""
        nodes = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:stop]]
        nodes[-1].output_names = self.graph.ops[stop].input_names
        return self._wrap(nodes)

    def _named_output_subgraph_model(self, ops):
        """Selects the given ops; the output is the last op's ":0" tensor."""
        nodes = [subgraph.SubgraphNode(op=o) for o in ops]
        nodes[-1].output_names = [f"{nodes[-1].op.name}:0"]
        return self._wrap(nodes)

    def _shape_property(self, subgraph_model):
        """Infers the shape property under the configured size bound."""
        return shape.ShapeProperty().infer(subgraph_model,
                                           max_size=self.max_size)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement and checks that its fingerprint differs.

        Args:
          subg: The subgraph model to replace.
          props: The properties the synthesized subgraph has to satisfy.
        """
        sequential_ctr = functools.partial(PSS,
                                           generation=0,
                                           max_len=3,
                                           filter_progress=True)
        graph_synthesizer = GraphSynthesizer([(subg, props)],
                                             sequential_ctr=sequential_ctr,
                                             generation=0)

        synthesized = graph_synthesizer.synthesize()[0]
        synthesized_spec = synthesized.subgraph
        logging.info("synthesized...")
        for synthesized_node in synthesized_spec:
            logging.info("%s", synthesized_node.op.name)
            logging.info("%s", synthesized_node.output_names)
        logging.info("")

        original_fp = fingerprint_graph(self.graph, self.constants,
                                        self.input)
        synthesized_fp = fingerprint_graph(synthesized.graph, self.constants,
                                           self.input)
        self.assertNotEqual(original_fp, synthesized_fp)

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)]."""
        subgraph_model = self._contiguous_subgraph_model(5)
        props = [
            self._shape_property(subgraph_model),
            depth.DepthProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)]."""
        py_rand.seed(10)
        subgraph_model = self._contiguous_subgraph_model(5)
        props = [
            self._shape_property(subgraph_model),
            depth.DepthProperty().infer(subgraph_model),
            linear.LinopProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU]."""
        subgraph_model = self._contiguous_subgraph_model(6)
        props = [
            self._shape_property(subgraph_model),
            depth.DepthProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
        subgraph_model = self._contiguous_subgraph_model(7)
        props = [
            self._shape_property(subgraph_model),
            linear.LinopProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)

    def test_synthesizer_resnet_small(self):
        """Replacing ops 4, 9 and 10 of a small-resolution ResNet18."""
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        self._use_network(graph, constants)

        selected_ops = [self.graph.ops[4]] + self.graph.ops[9:11]
        subgraph_model = self._named_output_subgraph_model(selected_ops)
        props = [
            self._shape_property(subgraph_model),
            linear.LinopProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)

    def test_synthesizer_resnet_big(self):
        """Replacing ops 3-4 and 8-11 of a small-resolution ResNet18."""
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        self._use_network(graph, constants)

        selected_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
        subgraph_model = self._named_output_subgraph_model(selected_ops)
        props = [
            self._shape_property(subgraph_model),
            linear.LinopProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)
class ProgSequentialTest(test.TestCase):
    """Tests for ProgressiveSequentialSynthesizer on the CifarNet graph."""

    def setUp(self):
        super().setUp()

        # Seed from the clock and log the seed that was used.
        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        self.out = self.m.apply(self.state, self.input)["fc/logits"]
        self.max_size = int(10e8)

        # Set to True to run test_synthesizer_hard.
        self.hard = False

    def _subgraph_model(self, stop):
        """Builds a subgraph model over ops[4:stop] feeding ops[stop]."""
        subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:stop]]
        subg[-1].output_names = self.graph.ops[stop].input_names
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, subg)

    def _shape_property(self, subgraph_model):
        """Infers the shape property under the configured size bound."""
        return shape.ShapeProperty().infer(subgraph_model,
                                           max_size=self.max_size)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement for subg and checks the logits changed.

        Args:
          subg: The subgraph model to replace.
          props: The properties the synthesized subgraph has to satisfy.
        """
        synthesizer = ProgressiveSequentialSynthesizer(
            [(subg, props)],
            generation=0,
            mode=ProgressiveSequentialSynthesizer.Mode.WEIGHTED,
            max_len=3)

        subg = synthesizer.synthesize()[0]
        subg_spec = subg.subgraph
        # Report through logging, like the seed in setUp, instead of printing
        # to stdout.
        for node in subg_spec:
            logging.info("%s", node.op.name)
            logging.info("%s", node.output_names)

        m = Model(subg.graph, self.constants)
        state = m.init(random.PRNGKey(0), self.input)
        out = m.apply(state, self.input)["fc/logits"]
        self.assertTrue((out != self.out).any())

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)]."""
        subgraph_model = self._subgraph_model(5)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)]."""
        subgraph_model = self._subgraph_model(5)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU]."""
        subgraph_model = self._subgraph_model(6)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
        subgraph_model = self._subgraph_model(7)
        sp = self._shape_property(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, lp])

    def test_synthesizer_hard(self):
        """Replacing ops[4:7] under all three properties (opt-in)."""
        if not self.hard:
            # Report a skip instead of silently passing without running.
            self.skipTest("hard synthesis test is disabled (self.hard is False)")
        subgraph_model = self._subgraph_model(7)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])
class EnumSequentialTest(test.TestCase):
    """Tests for EnumerativeSequentialSynthesizer on the CifarNet graph."""

    def setUp(self):
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        self.out = self.m.apply(self.state, self.input)["fc/logits"]
        self.max_size = int(10e8)

        # Set to True to run test_synthesizer_hard.
        self.hard = False

    def _subgraph_model(self, stop):
        """Builds a subgraph model over ops[4:stop] feeding ops[stop]."""
        subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:stop]]
        subg[-1].output_names = self.graph.ops[stop].input_names
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, subg)

    def _shape_property(self, subgraph_model):
        """Infers the shape property under the configured size bound."""
        return shape.ShapeProperty().infer(subgraph_model,
                                           max_size=self.max_size)

    def test_sequence_generator(self):
        """Test the sequence_generator function for correctness.

        seq_generator should generate
          [[0], [1], [2],
           [0,0], [0,1], [0,2], [1,0], [1,1], [1,2], ...,
           [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ...,]
        """
        def el_generator():
            yield from range(3)

        seqs = list(sequence_generator(el_generator, 3))
        self.assertLen(seqs, 3 + 3**2 + 3**3)
        for i, seq in enumerate(seqs):
            if i < 3:
                # Length-1 sequences occupy indices 0..2.
                self.assertEqual(seq, [i])
            elif i < 3 + 3**2:
                # Length-2 sequences start at index 3.
                self.assertEqual(seq, [(i - 3) // 3, i % 3])
            else:
                # Length-3 sequences start at index 12.
                self.assertEqual(seq, [(i - 12) // 3 // 3,
                                       (i - 12) // 3 % 3, i % 3])

    def test_kwargs_for_op_to_product(self):
        """The product has one entry per combination of op and input kwargs."""
        op_kwargs = {"a": [1, 2, 3], "b": [1, 2], "c": [5, 6]}
        input_kwargs = {"d": [1], "e": [1, 2], "f": [3, 4]}
        product = EnumerativeSequentialSynthesizer.kwargs_for_op_to_product(
            op_kwargs, input_kwargs)

        expected_length = 1
        for choices in (*op_kwargs.values(), *input_kwargs.values()):
            expected_length *= len(choices)
        self.assertLen(product, expected_length)

        op_setting = {"a": 2, "b": 2, "c": 5}
        input_setting = {"d": 1, "e": 2, "f": 3}
        self.assertIn((op_setting, input_setting), product)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement for subg and checks the logits changed.

        Args:
          subg: The subgraph model to replace.
          props: The properties the synthesized subgraph has to satisfy.
        """
        synthesizer = EnumerativeSequentialSynthesizer([(subg, props)],
                                                       0,
                                                       max_len=3)

        subg = synthesizer.synthesize()[0]

        m = Model(subg.graph, self.constants)
        state = m.init(random.PRNGKey(0), self.input)
        out = m.apply(state, self.input)["fc/logits"]
        self.assertTrue((out != self.out).any())

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)].

        Because we do not test linear, this is replaced by
        dense3x3(features = 64) due to the enumeration order.
        """
        subgraph_model = self._subgraph_model(5)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)].

        Because we test all three props, this is replaced by
        conv3x3(features = 64) (i.e., an identical op) due to the enumeration
        order.
        """
        subgraph_model = self._subgraph_model(5)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU].

        Because we do not check for the linear property,
        [dense(features = 64), ReLU] works as well (which is what is
        synthesized due to the enumeration order).
        """
        subgraph_model = self._subgraph_model(6)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Because we do not check for the depth property,
        [dense(features = 64), avgpool2x2(strides=2x2)] works as well (which
        is what is synthesized due to the enumeration order).
        """
        subgraph_model = self._subgraph_model(7)
        sp = self._shape_property(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, lp])

    def test_synthesizer_hard(self):
        """Replacing ops[4:7] under all three properties (opt-in)."""
        if not self.hard:
            # Report a skip instead of silently passing without running.
            self.skipTest("hard synthesis test is disabled (self.hard is False)")
        subgraph_model = self._subgraph_model(7)
        sp = self._shape_property(subgraph_model)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])