Beispiel #1
0
    def setUp(self):
        """Builds two one-op graphs and infers a linop property for each.

        Each graph has a single op named "output" reading from "input". Both
        models are initialized with PRNGKey(0) on a (5, 5, 5) tensor of ones.
        """
        super().setUp()

        # NOTE: the "*_dense" attributes wrap a single SOFTMAX op (there is no
        # dense op in this graph); the attribute names are kept as they are.
        softmax_ops = [
            new_op(op_name="output",
                   op_type=OpType.SOFTMAX,
                   input_names=["input"])
        ]
        self.graph_dense = new_graph(["input"], ["output"], softmax_ops)
        dense_model = Model(self.graph_dense)
        state_dense = dense_model.init(random.PRNGKey(0),
                                       {"input": jnp.ones((5, 5, 5))})
        self.subgraph_dense = SubgraphModel(
            self.graph_dense, None, state_dense,
            {"input": jnp.ones((5, 5, 5))})
        self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

        # Single CONV op with 10 features and a 3x3 kernel.
        conv_ops = [
            new_op(op_name="output",
                   op_type=OpType.CONV,
                   op_kwargs={"features": 10, "kernel_size": [3, 3]},
                   input_names=["input"])
        ]
        self.graph_conv = new_graph(["input"], ["output"], conv_ops)
        conv_model = Model(self.graph_conv)
        state_conv = conv_model.init(random.PRNGKey(0),
                                     {"input": jnp.ones((5, 5, 5))})
        self.subgraph_conv = SubgraphModel(
            self.graph_conv, None, state_conv,
            {"input": jnp.ones((5, 5, 5))})
        self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
 def test_abstract_sequential_synthesizer_output_features(self):
   """Checks output_features_mul/div for a conv with symbolic features.

   The replaced conv uses the features string "S:-1*2"; the synthesizer built
   from the resulting subgraph is expected to report a multiplier of 2 and a
   divisor of 1.
   """
   graph, constants, _ = cnn.CifarNet()

   # 1x1 conv whose "features" kwarg is the symbolic string "S:-1*2".
   conv_node = SubgraphNode(
       op=new_op(
           op_name="conv_layer1/conv",
           op_type=OpType.CONV,
           op_kwargs={"features": "S:-1*2", "kernel_size": [1, 1]},
           input_names=["conv_layer0/avg_pool"]),)
   # ReLU on top of the conv; it is the only node that declares output names.
   relu_node = SubgraphNode(
       op=new_op(
           op_name="conv_layer1/relu",
           op_type=OpType.RELU,
           input_names=["conv_layer1/conv"]),
       output_names=["conv_layer1/relu"])
   subgraph_spec = [conv_node, relu_node]

   replaced = replace_subgraph(graph, subgraph_spec)
   model_inputs = {"input": jnp.zeros((5, 32, 32, 10))}
   subgraph_model = SubgraphModel(replaced, constants, None, model_inputs,
                                  subgraph_spec)
   shape_prop = shape.ShapeProperty().infer(subgraph_model)
   syn = TestSequentialSynthesizer([(subgraph_model, [shape_prop])], 0)

   self.assertEqual(syn.output_features_mul, 2)
   self.assertEqual(syn.output_features_div, 1)
 def test_abstract_sequential_synthesizer_fail(self):
   """Constructing the synthesizer from this spec must raise a ValueError.

   The error message is expected to match ".*exactly one input.*".
   """
   graph, constants, _ = cnn.CifarNet()

   conv_node = SubgraphNode(
       op=new_op(
           op_name="conv_layer1/conv/1",
           op_type=OpType.CONV,
           op_kwargs={"features": 64, "kernel_size": [1, 1]},
           input_names=["conv_layer0/avg_pool"]),
       output_names=["conv_layer1/conv"])
   gelu_node = SubgraphNode(
       op=new_op(
           op_name="conv_layer1/gelu/1",
           op_type=OpType.GELU,
           input_names=["conv_layer1/conv"]),
       output_names=["conv_layer1/relu"])
   subgraph_spec = [conv_node, gelu_node]

   # The spec is attached to the original graph (no replace_subgraph call).
   model = SubgraphModel(graph, constants, None,
                         {"input": jnp.zeros((5, 32, 32, 10))},
                         subgraph_spec)
   with self.assertRaisesRegex(ValueError, ".*exactly one input.*"):
     TestSequentialSynthesizer([(model, [])], 0)
Beispiel #4
0
    def test_rewire(self):
        """Depth inference after swapping conv+relu for conv+gelu.

        orig: conv, relu, pool, conv, relu, pool, flatten, dense, relu, dense
        new:  conv, relu, pool, conv, gelu, pool, flatten, dense, relu, dense
        """
        conv_node = SubgraphNode(op=new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"]), )
        # The gelu is rewired to produce the name the old relu used to produce.
        gelu_node = SubgraphNode(op=new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv/1"]),
                                 output_names=["conv_layer1/relu"])
        subgraph_spec = [conv_node, gelu_node]

        rewired = replace_subgraph(self.graph, subgraph_spec)
        subgraph_model = SubgraphModel(rewired, self.constants, {}, {},
                                       subgraph_spec)
        depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

        # Exactly one subgraph input is expected ...
        source = "conv_layer0/avg_pool:0"
        self.assertLen(depth_map, 1)
        self.assertIn(source, depth_map)

        # ... reaching two output names, each at depth 1.
        self.assertLen(depth_map[source], 2)
        for sink in ("conv_layer1/relu:0", "conv_layer1/gelu/1:0"):
            self.assertIn(sink, depth_map[source])
            self.assertEqual(depth_map[source][sink], 1)
Beispiel #5
0
    def test_rewire(self):
        """Shape inference after swapping conv+relu for conv+gelu.

        The inferred property should list one input shape and two output
        shapes (the gelu's own name and the rewired relu name).
        """
        conv_node = SubgraphNode(op=new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"]), )
        gelu_node = SubgraphNode(op=new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv/1"]),
                                 output_names=["conv_layer1/relu"])
        subgraph_spec = [conv_node, gelu_node]

        rewired = replace_subgraph(self.graph, subgraph_spec)
        model = Model(rewired, self.constants)
        state = model.init(random.PRNGKey(0),
                           {"input": jnp.ones((5, 32, 32, 3))})
        subgraph_model = SubgraphModel(rewired, self.constants, state,
                                       {"input": jnp.ones((5, 32, 32, 3))},
                                       subgraph_spec)
        sp = shape.ShapeProperty().infer(subgraph_model)

        self.assertLen(sp.input_shapes, 1)
        self.assertIn("conv_layer0/avg_pool:0", sp.input_shapes)

        self.assertLen(sp.output_shapes, 2)
        for output_name in ("conv_layer1/gelu/1:0", "conv_layer1/relu:0"):
            self.assertIn(output_name, sp.output_shapes)
Beispiel #6
0
 def test_unsatisfy(self):
     """A graph whose output is "fc/relu" must not satisfy self.dp.

     Per the original comment, this drops the last dense layer, so the new
     graph should be less deep than the one self.dp was inferred from.
     """
     # Same ops as the reference graph; only the output name differs.
     shallower_graph = new_graph(input_names=["input"],
                                 output_names=["fc/relu"],
                                 ops=self.graph.ops)
     candidate = SubgraphModel(shallower_graph, self.constants, {}, {})
     satisfied = self.dp.verify(candidate)
     self.assertFalse(satisfied)
Beispiel #7
0
    def test_abstract(self):
        """Abstract and concrete linop inference must produce equal mappings.

        Every pairing of {RELU, SOFTMAX, LAYER_NORM, BATCH_NORM} with
        {conv, dense} is tried in both orders (parametric op first, then
        parametric op last).
        """
        conv_op = functools.partial(new_op,
                                    op_type=OpType.CONV,
                                    op_kwargs={
                                        "features": 10,
                                        "kernel_size": [3, 3]
                                    })
        dense_op = functools.partial(new_op,
                                     op_type=OpType.DENSE,
                                     op_kwargs={
                                         "features": 10,
                                     })
        parametric_ctrs = [conv_op, dense_op]
        other_op_types = [
            OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
        ]

        op_lists = []
        for op_type in other_op_types:
            for make_parametric in parametric_ctrs:
                # Parametric op first, then the op under test.
                parametric_first = [
                    make_parametric(input_names=["input"], op_name="other"),
                    new_op(op_name="output",
                           op_type=op_type,
                           input_names=["other"]),
                ]
                op_lists.append(parametric_first)
                # Op under test first, then the parametric op.
                parametric_last = [
                    new_op(op_name="other",
                           op_type=op_type,
                           input_names=["input"]),
                    make_parametric(input_names=["other"], op_name="output"),
                ]
                op_lists.append(parametric_last)

        input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
        for ops in op_lists:
            graph = new_graph(["input"], ["output"], ops)
            state = Model(graph).init(random.PRNGKey(1), input_tensor)

            # Make all the kernels positive, otherwise, the ReLU might zero out
            # the entire tensor.
            state = jax.tree_util.tree_map(abs, state)

            subg_model = SubgraphModel(graph, None, state, input_tensor)
            lp_abstract = linear.LinopProperty().infer(subg_model,
                                                       abstract=True)
            lp_concrete = linear.LinopProperty().infer(subg_model,
                                                       abstract=False)
            pairings_concrete = lp_concrete.pairings["output"][
                "input"].mappings
            pairings_abstract = lp_abstract.pairings["output"][
                "input"].mappings

            print("concrete:", pairings_concrete)
            print("abstract:", pairings_abstract)
            difference = pairings_abstract - pairings_concrete
            self.assertTrue((difference == 0).all())
Beispiel #8
0
 def test_synthesizer_easy_one(self):
     """Replacing [conv3x3(features = 64)] under shape and depth properties."""
     # The subgraph is the single op at index 4; its outputs are wired to
     # whatever the op at index 5 consumes.
     selected_ops = self.graph.ops[4:5]
     successor = self.graph.ops[5]
     nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
     nodes[-1].output_names = successor.input_names

     subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                    self.input, nodes)
     shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                              max_size=self.max_size)
     depth_prop = depth.DepthProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [shape_prop, depth_prop])
Beispiel #9
0
 def test_synthesizer_two(self):
     """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

     Synthesis is constrained by the shape and linop properties.
     """
     # The subgraph covers ops 4..6; its outputs are wired to whatever the op
     # at index 7 consumes.
     selected_ops = self.graph.ops[4:7]
     successor = self.graph.ops[7]
     nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
     nodes[-1].output_names = successor.input_names

     subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                    self.input, nodes)
     shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                              max_size=self.max_size)
     linop_prop = linear.LinopProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [shape_prop, linop_prop])
Beispiel #10
0
    def make_subgraph_models(self, subgraph_spec, graphs=None):
        """Builds one new SubgraphModel per entry of self.subgraphs_and_props.

        Each new model is the corresponding parent graph with subgraph_spec
        substituted in via subgraph.replace_subgraph.

        The graphs argument allows synthesis to proceed in several steps: the
        graphs of a previous call's results can be passed back in, so the new
        ops are inserted into those intermediate graphs instead of the
        original parent graphs.

        Args:
          subgraph_spec: The new ops to insert.
          graphs: Optional graphs into which to insert the new ops, paired
            positionally with self.subgraphs_and_props. A falsy entry (or a
            falsy `graphs`) falls back to the parent model's own graph.

        Returns:
          A list of SubgraphModels with subgraph_spec inserted.

        Raises:
          ValueError: in concrete mode (self.abstract is falsy), if the new
            model fails to initialize for any reason.
        """
        if not graphs:
            graphs = [None] * len(self.subgraphs_and_props)

        results = []
        for base_graph, (parent_model, _) in zip(graphs,
                                                 self.subgraphs_and_props):
            if not base_graph:
                base_graph = parent_model.graph
            constants = parent_model.constants
            parent_state = parent_model.state
            inputs = parent_model.inputs

            replaced_graph = subgraph.replace_subgraph(base_graph,
                                                       subgraph_spec)
            if self.abstract:
                # Abstract synthesis carries no state.
                new_state = None
            else:
                # Concrete synthesis initializes the state while inheriting
                # the parent params.
                new_model = Model(replaced_graph, constants)
                try:
                    new_state = new_model.init(random.PRNGKey(0), inputs)
                except Exception as e:  # pylint: disable=broad-except
                    # Catch everything for now: this is the safest way to
                    # filter out malformed subgraphs which will not
                    # initialize. The full traceback goes to the log.
                    logging.info("%s", traceback.format_exc())
                    raise ValueError(
                        "Could not initialized malformed subgraph "
                        f"({type(e).__name__}: {e}).") from e

                new_state = flax.core.unfreeze(new_state)
                inherited, frozen = subgraph.inherit_params(
                    new_state["params"], parent_state["params"])
                new_state = {**inherited, **frozen}

            results.append(
                SubgraphModel(replaced_graph, constants, new_state, inputs,
                              subgraph_spec))
        return results
Beispiel #11
0
 def test_satisfy(self):
     """The full graph must satisfy a depth property from a shallower graph.

     Per the original comment, this drops the last dense layer, so the old
     graph should be deeper than the one the property is inferred from.
     """
     # All ops but the last; the new final op is renamed to "fc/logits".
     # NOTE(review): the slice is a new list but the renamed op object looks
     # shared with self.graph -- confirm that this aliasing is intended.
     shallower_ops = self.graph.ops[:-1]
     shallower_ops[-1].name = "fc/logits"
     shallower_graph = new_graph(input_names=["input"],
                                 output_names=["fc/logits"],
                                 ops=shallower_ops)
     shallower_model = SubgraphModel(shallower_graph, self.constants, {}, {})
     shallower_dp = depth.DepthProperty().infer(shallower_model)
     self.assertTrue(shallower_dp.verify(self.subgraph_model))
Beispiel #12
0
    def setUp(self):
        """Builds CifarNet, initializes it, and infers its shape property.

        The model is initialized with PRNGKey(0) on a (5, 32, 32, 3) tensor of
        ones; the resulting property is stored as self.sp.
        """
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        model = Model(self.graph, self.constants)
        state = model.init(random.PRNGKey(0),
                           {"input": jnp.ones((5, 32, 32, 3))})
        self.subgraph_model = SubgraphModel(
            self.graph, self.constants, state,
            {"input": jnp.ones((5, 32, 32, 3))})
        self.sp = shape.ShapeProperty().infer(self.subgraph_model)
 def test_synthesizer_hard(self):
     """Synthesizes a replacement for ops 4..6 under all three properties.

     The test only runs when self.hard is truthy. Otherwise it is reported as
     skipped, so a disabled run is not mistaken for a passing one.
     """
     if not self.hard:
         self.skipTest("hard synthesizer tests are disabled (self.hard is falsy)")

     # The subgraph covers ops 4..6; its outputs are wired to whatever the op
     # at index 7 consumes.
     subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
     subg[-1].output_names = self.graph.ops[7].input_names
     subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                    self.input, subg)
     sp = shape.ShapeProperty().infer(subgraph_model,
                                      max_size=self.max_size)
     dp = depth.DepthProperty().infer(subgraph_model)
     lp = linear.LinopProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [sp, dp, lp])
Beispiel #14
0
 def test_unsatisfy(self):
     """A graph whose output is "fc/relu" must not satisfy self.sp.

     Per the original comment, this drops the last dense layer, so the new
     graph should have a different shape and therefore violate the inferred
     shape property.
     """
     # Same ops as the reference graph; only the output name differs.
     truncated_graph = new_graph(input_names=["input"],
                                 output_names=["fc/relu"],
                                 ops=self.graph.ops)
     model = Model(truncated_graph, self.constants)
     state = model.init(random.PRNGKey(0),
                        {"input": jnp.ones((5, 32, 32, 3))})
     candidate = SubgraphModel(truncated_graph, self.constants, state,
                               {"input": jnp.ones((5, 32, 32, 3))})
     satisfied = self.sp.verify(candidate)
     self.assertFalse(satisfied)
    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)].

        Because we do not test linear, this is replaced by
        dense3x3(features = 64) due to the enumeration order.
        """
        # The subgraph is the single op at index 4; its outputs are wired to
        # whatever the op at index 5 consumes.
        selected_ops = self.graph.ops[4:5]
        successor = self.graph.ops[5]
        nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
        nodes[-1].output_names = successor.input_names

        subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                       self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                                 max_size=self.max_size)
        depth_prop = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, depth_prop])
    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Because we do not check for the depth property,
        [dense(features = 64), avgpool2x2(strides=2x2)] works as well (which
        is what is synthesized due to the enumeration order).
        """
        # The subgraph covers ops 4..6; its outputs are wired to whatever the
        # op at index 7 consumes.
        selected_ops = self.graph.ops[4:7]
        successor = self.graph.ops[7]
        nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
        nodes[-1].output_names = successor.input_names

        subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                       self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                                 max_size=self.max_size)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, linop_prop])
    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)].

        Because we test all three props, this is replaced by
        conv3x3(features = 64) (i.e., an identical op) due to the enumeration
        order.
        """
        # The subgraph is the single op at index 4; its outputs are wired to
        # whatever the op at index 5 consumes.
        selected_ops = self.graph.ops[4:5]
        successor = self.graph.ops[5]
        nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
        nodes[-1].output_names = successor.input_names

        subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                       self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                                 max_size=self.max_size)
        depth_prop = depth.DepthProperty().infer(subgraph_model)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model,
                         [shape_prop, depth_prop, linop_prop])
Beispiel #18
0
    def test_synthesizer_resnet_big(self):
        """Synthesizes a replacement for a six-op subgraph of ResNet18.

        The fixture attributes (graph, constants, model, input, state, output,
        max_size) are rebuilt on self for a 10-class, "small"-resolution
        ResNet18 before the subgraph is selected.
        """
        # Rebuild the fixture for ResNet18.
        self.graph, self.constants, _ = resnetv1.ResNet18(
            num_classes=10, input_resolution="small")
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        first_output = self.graph.output_names[0]
        self.out = self.m.apply(self.state, self.input)[first_output]
        self.max_size = int(10e8)

        # The subgraph is ops 3..4 followed by ops 8..11; the last node
        # exposes its own first output.
        all_ops = self.graph.ops
        selected_ops = all_ops[3:5] + all_ops[8:12]
        nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
        last_node = nodes[-1]
        last_node.output_names = [f"{last_node.op.name}:0"]

        subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                       self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(subgraph_model,
                                                 max_size=self.max_size)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, linop_prop])
Beispiel #19
0
    def mutate(self, graph, contexts, abstract=True):
        """Selects a subgraph and mutates its abstract properties.

        Note that this method does not mutate the graph, only the properties
        of a subgraph. Therefore the caller is responsible for synthesizing a
        subgraph satisfying the mutated properties.

        Args:
          graph: The input graph to mutate.
          contexts: A list of (constants, state, inputs) tuples, each
            specifying a different instantiation of the graph. The state
            entry is ignored; every subgraph model is built with state=None.
          abstract: Whether to infer the subgraph properties abstractly or
            concretely. Only abstract inference is implemented.

        Returns:
          For each instantiation, a (subgraph model, properties) pair: the
          model specifies the selected subgraph and the properties are the
          mutated properties for one randomly chosen instantiation, and only
          freshly inferred shape properties for all the others.

        Raises:
          NotImplementedError: for concrete inference of properties. This is
            raised before any subgraph is selected.
          ValueError: if contexts is empty (from random.randrange).
        """
        if not abstract:
            # Concrete inference would need an initialized model state. Fail
            # fast instead of selecting and replacing a subgraph first.
            raise NotImplementedError

        subgraph_spec = self.select_subgraph(graph)
        replaced_graph = replace_subgraph(graph, subgraph_spec)

        models_and_props = []
        to_mutate = random.randrange(len(contexts))
        for idx, (constants, _, inputs) in enumerate(contexts):
            subgraph_model = SubgraphModel(replaced_graph,
                                           constants,
                                           state=None,
                                           inputs=inputs,
                                           subgraph=subgraph_spec)
            # Only mutate the properties of one randomly selected instance.
            # The other instances simply need to satisfy the shape property
            # (which is never mutated).
            # The alternative would be to make sure that all the properties
            # are mutated "in the same way" for every instance of the graph,
            # but that requires overly complex logic matching input and output
            # names.
            if idx == to_mutate:
                new_properties = [
                    prop.infer(subgraph_model, abstract=abstract).mutate()
                    for prop in self.properties
                ]
            else:
                new_properties = []
                for prop in self.properties:
                    if type(prop) is ShapeProperty:  # pylint: disable=unidiomatic-typecheck
                        new_properties.append(ShapeProperty().infer(
                            subgraph_model, abstract=abstract))
            models_and_props.append((subgraph_model, new_properties))

        # Match mutated shapes: drop from every shape property any output
        # that the mutated instance's shape property no longer lists.
        for prop in models_and_props[to_mutate][1]:
            if type(prop) is not ShapeProperty:  # pylint: disable=unidiomatic-typecheck
                continue
            output_shapes = prop.output_shapes
            for _, other_props in models_and_props:
                for other_prop in other_props:
                    if type(other_prop) is not ShapeProperty:  # pylint: disable=unidiomatic-typecheck
                        continue
                    # Iterate over a snapshot of the keys because entries are
                    # deleted inside the loop.
                    for k in list(other_prop.output_shapes):
                        if k not in output_shapes:
                            del other_prop.output_shapes[k]
        return models_and_props
Beispiel #20
0
 def test_full(self):
     """Linop inference on the whole CifarNet graph must run without error.

     No state is supplied and nothing is asserted about the result.
     """
     graph, constants, _ = cnn.CifarNet()
     model_inputs = {"input": jnp.ones((10, 32, 32, 3))}
     full_model = SubgraphModel(graph, constants, None, model_inputs)
     linear.LinopProperty().infer(full_model)
Beispiel #21
0
    def test_multi_input(self):
        """Depth inference on a subgraph with three inputs and two outputs.

        Three dense->relu branches read "input"; add0 sums relu0 and relu1,
        and add1 sums relu1 and relu2. The subgraph consists of the relus and
        the adds, so its inputs are the three dense outputs.
        """
        branches = [("dense0", "relu0"), ("dense1", "relu1"),
                    ("dense2", "relu2")]
        adders = [("add0", ["relu0", "relu1"]), ("add1", ["relu1", "relu2"])]

        # Full graph: dense/relu pairs in branch order, then the two adds.
        ops = []
        for dense_name, relu_name in branches:
            ops.append(
                new_op(op_name=dense_name,
                       op_type=OpType.DENSE,
                       op_kwargs={"features": 32},
                       input_names=["input"]))
            ops.append(
                new_op(op_name=relu_name,
                       op_type=OpType.RELU,
                       input_names=[dense_name]))
        for add_name, summands in adders:
            ops.append(
                new_op(op_name=add_name,
                       op_type=OpType.ADD,
                       input_names=list(summands)))
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)

        # Subgraph: the relus (no declared outputs), then the adds, each
        # exposing its own name as output.
        subgraph_spec = [
            SubgraphNode(op=new_op(op_name=relu_name,
                                   op_type=OpType.RELU,
                                   input_names=[dense_name]))
            for dense_name, relu_name in branches
        ]
        for add_name, summands in adders:
            subgraph_spec.append(
                SubgraphNode(op=new_op(op_name=add_name,
                                       op_type=OpType.ADD,
                                       input_names=list(summands)),
                             output_names=[add_name]))

        replaced_graph = replace_subgraph(graph, subgraph_spec)
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, {},
                                       subgraph_spec)
        depth_map = depth.DepthProperty().infer(subgraph_model).depth_map

        # Expected depth from every subgraph input to each output it reaches.
        expected = {
            "dense0:0": {"add0:0": 2},
            "dense1:0": {"add0:0": 2, "add1:0": 2},
            "dense2:0": {"add1:0": 2},
        }
        self.assertLen(depth_map, 3)
        for input_name in expected:
            self.assertIn(input_name, depth_map)
        for input_name, reachable in expected.items():
            self.assertLen(depth_map[input_name], len(reachable))
            for output_name, expected_depth in reachable.items():
                self.assertEqual(depth_map[input_name][output_name],
                                 expected_depth)
Beispiel #22
0
    def test_multi_input(self):
        """Linop inference on a subgraph with three inputs and two outputs.

        Three dense->relu branches read "input"; add0 sums relu0 and relu1,
        and add1 sums relu1 and relu2. Each add output is expected to be
        paired with exactly the two dense outputs feeding it.
        """
        branches = [("dense0", "relu0"), ("dense1", "relu1"),
                    ("dense2", "relu2")]
        adders = [("add0", ["relu0", "relu1"]), ("add1", ["relu1", "relu2"])]

        # Full graph: dense/relu pairs in branch order, then the two adds.
        ops = []
        for dense_name, relu_name in branches:
            ops.append(
                new_op(op_name=dense_name,
                       op_type=OpType.DENSE,
                       op_kwargs={"features": 32},
                       input_names=["input"]))
            ops.append(
                new_op(op_name=relu_name,
                       op_type=OpType.RELU,
                       input_names=[dense_name]))
        for add_name, summands in adders:
            ops.append(
                new_op(op_name=add_name,
                       op_type=OpType.ADD,
                       input_names=list(summands)))
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)

        # Subgraph: the relus (no declared outputs), then the adds, each
        # exposing its own name as output.
        subgraph_spec = [
            SubgraphNode(op=new_op(op_name=relu_name,
                                   op_type=OpType.RELU,
                                   input_names=[dense_name]))
            for dense_name, relu_name in branches
        ]
        for add_name, summands in adders:
            subgraph_spec.append(
                SubgraphNode(op=new_op(op_name=add_name,
                                       op_type=OpType.ADD,
                                       input_names=list(summands)),
                             output_names=[add_name]))

        replaced_graph = replace_subgraph(graph, subgraph_spec)
        inp = {"input": jnp.ones((10, 32, 32, 3))}
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp,
                                       subgraph_spec)
        pairings = linear.LinopProperty().infer(subgraph_model).pairings

        # Expected subgraph inputs paired with each subgraph output.
        expected = {
            "add0:0": ["dense0:0", "dense1:0"],
            "add1:0": ["dense1:0", "dense2:0"],
        }
        self.assertLen(pairings, 2)
        for output_name, input_names in expected.items():
            self.assertIn(output_name, pairings)
            self.assertLen(pairings[output_name], 2)
            for input_name in input_names:
                self.assertIn(input_name, pairings[output_name])
Beispiel #23
0
    def setUp(self):
        """Builds CifarNet and infers its depth property as self.dp.

        The subgraph model is created with empty state and empty inputs.
        """
        super().setUp()

        cifar_graph, cifar_constants, _ = cnn.CifarNet()
        self.graph = cifar_graph
        self.constants = cifar_constants
        self.subgraph_model = SubgraphModel(self.graph, self.constants, {}, {})
        self.dp = depth.DepthProperty().infer(self.subgraph_model)