def setUp(self):
        """Builds a CifarNet fixture and a variant with a replaced subgraph.

        Stores the original graph/model/state and, after splicing in a
        1x1 conv (64 features) followed by GELU, the rewritten
        graph/model/state. Both models are initialised with the same PRNG
        key and the same (1, 32, 32, 3) all-ones input.
        """
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        self.model = Model(self.graph, self.constants)

        # Shared initialisation inputs; JAX arrays/keys are immutable, so
        # reusing them for both models is equivalent to recreating them.
        init_rng = random.PRNGKey(0)
        dummy_batch = jnp.ones((1, 32, 32, 3))
        self.state = self.model.init(init_rng, dummy_batch)

        # Replacement: conv reading conv_layer0/avg_pool, then GELU whose
        # output is exposed under the name "conv_layer1/relu".
        conv_op = new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"],
        )
        gelu_op = new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv/1"],
        )
        self.subgraph = [
            subgraph.SubgraphNode(op=conv_op),
            subgraph.SubgraphNode(op=gelu_op,
                                  output_names=["conv_layer1/relu"]),
        ]

        self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
        self.new_model = Model(self.new_graph, self.constants)
        self.new_state = self.new_model.init(init_rng, dummy_batch)
Example #2
0
 def test_synthesizer_easy_one(self):
     """Replacing [conv3x3(features = 64)].

     Synthesizes a replacement for graph op 4 on its own, constrained only
     by the shape and depth properties.
     """
     # One-node subgraph; its outputs take the names op 5 reads as inputs.
     node = subgraph.SubgraphNode(op=self.graph.ops[4])
     node.output_names = self.graph.ops[5].input_names
     subgraph_model = SubgraphModel(
         self.graph, self.constants, self.state, self.input, [node])
     properties = [
         shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size),
         depth.DepthProperty().infer(subgraph_model),
     ]
     self._synthesize(subgraph_model, properties)
Example #3
0
 def test_synthesizer_two(self):
     """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

     Synthesizes a replacement for graph ops 4-6, constrained by the shape
     and LinopProperty properties.
     """
     first, stop = 4, 7
     nodes = [
         subgraph.SubgraphNode(op=op) for op in self.graph.ops[first:stop]
     ]
     # The final node's outputs take the names that op 7 reads as inputs.
     nodes[-1].output_names = self.graph.ops[stop].input_names
     subgraph_model = SubgraphModel(
         self.graph, self.constants, self.state, self.input, nodes)
     shape_prop = shape.ShapeProperty().infer(
         subgraph_model, max_size=self.max_size)
     linop_prop = linear.LinopProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [shape_prop, linop_prop])
 def test_synthesizer_hard(self):
     """Synthesizes a replacement for graph ops 4-6 under all three properties.

     Enforces ShapeProperty, DepthProperty and LinopProperty together. The
     test only runs when `self.hard` is truthy; otherwise it is reported as
     skipped (a bare `return` would make it show up as passed even though
     nothing was checked).
     """
     if not self.hard:
         self.skipTest("hard synthesizer test disabled (self.hard is falsy)")
     subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
     # The final node's outputs take the names that op 7 reads as inputs.
     subg[-1].output_names = self.graph.ops[7].input_names
     subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                    self.input, subg)
     sp = shape.ShapeProperty().infer(subgraph_model,
                                      max_size=self.max_size)
     dp = depth.DepthProperty().infer(subgraph_model)
     lp = linear.LinopProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [sp, dp, lp])
    def test_multi_input_output(self):
        """Tests a subgraph substitution on a graph with multiple inputs / output ops.

        A ResNet18 (which has skip connections) has a conv / bn / relu run
        replaced by a 1x1 conv followed by GELU. The test checks that the
        substitution produces the expected number of ops, and that the
        rewritten graph can still be initialised and applied.
        """
        init_rng = random.PRNGKey(0)
        init_batch = jnp.ones((1, 32, 32, 3))
        eval_batch = jnp.ones((10, 32, 32, 3))

        graph, constants, _ = resnetv1.ResNet18(
            num_classes=10, input_resolution="small")
        model = Model(graph, constants)
        state = model.init(init_rng, init_batch)
        # Baseline: the unmodified network yields a (10, 10) output.
        self.assertEqual(model.apply(state, eval_batch).shape, (10, 10))

        conv_op = new_op(
            op_name="subgraph/conv0",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["resnet11/skip/relu1"],
        )
        gelu_op = new_op(
            op_name="subgraph/gelu1",
            op_type=OpType.GELU,
            input_names=["subgraph/conv0"],
        )
        replacement = [
            subgraph.SubgraphNode(op=conv_op),
            subgraph.SubgraphNode(
                op=gelu_op,
                output_names=["resnet_stride1_filtermul1_basic12/relu2"]),
        ]
        new_graph = subgraph.replace_subgraph(graph, replacement)

        # The subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu).
        self.assertLen(graph.ops, len(new_graph.ops) + 1)

        new_model = Model(new_graph, constants)
        new_state = new_model.init(init_rng, init_batch)
        self.assertEqual(new_model.apply(new_state, eval_batch).shape,
                         (10, 10))
    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)].

        Only the shape and depth properties are enforced. Because linear is
        not tested, this is replaced by dense3x3(features = 64) due to the
        enumeration order.
        """
        node = subgraph.SubgraphNode(op=self.graph.ops[4])
        # Publish the result under the names that op 5 consumes.
        node.output_names = self.graph.ops[5].input_names
        subgraph_model = SubgraphModel(
            self.graph, self.constants, self.state, self.input, [node])
        props = []
        props.append(
            shape.ShapeProperty().infer(subgraph_model, max_size=self.max_size))
        props.append(depth.DepthProperty().infer(subgraph_model))
        self._synthesize(subgraph_model, props)
    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Only the shape and linear properties are enforced. Because the depth
        property is not checked, [dense(features = 64), avgpool2x2(strides=2x2)]
        works as well (which is what is synthesized due to the enumeration
        order).
        """
        window = self.graph.ops[4:7]
        nodes = [subgraph.SubgraphNode(op=op) for op in window]
        # The last node's outputs take the names that op 7 reads as inputs.
        nodes[-1].output_names = self.graph.ops[7].input_names
        subgraph_model = SubgraphModel(
            self.graph, self.constants, self.state, self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(
            subgraph_model, max_size=self.max_size)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, linop_prop])
    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)].

        Shape, depth and linear properties are all enforced. Because all
        three props are tested, this is replaced by conv3x3(features = 64)
        (i.e., an identical op) due to the enumeration order.
        """
        node = subgraph.SubgraphNode(op=self.graph.ops[4])
        consumer = self.graph.ops[5]
        # Publish the result under the names that op 5 reads as inputs.
        node.output_names = consumer.input_names
        subgraph_model = SubgraphModel(
            self.graph, self.constants, self.state, self.input, [node])
        shape_prop = shape.ShapeProperty().infer(
            subgraph_model, max_size=self.max_size)
        depth_prop = depth.DepthProperty().infer(subgraph_model)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, depth_prop, linop_prop])
Example #9
0
    def test_synthesizer_resnet_big(self):
        """Synthesizes a replacement for a non-contiguous slice of ResNet18 ops.

        Rebuilds the fixture attributes (graph, constants, m, input, state,
        out, max_size) around a small-resolution ResNet18, then enforces the
        shape and linear properties on graph ops 3-4 and 8-11.
        """
        self.graph, self.constants, _ = resnetv1.ResNet18(
            num_classes=10, input_resolution="small")
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        outputs = self.m.apply(self.state, self.input)
        self.out = outputs[self.graph.output_names[0]]
        self.max_size = int(10e8)

        picked = self.graph.ops[3:5] + self.graph.ops[8:12]
        nodes = [subgraph.SubgraphNode(op=op) for op in picked]
        # The replacement's output is named after the last picked op's ":0".
        tail = nodes[-1]
        tail.output_names = [f"{tail.op.name}:0"]

        subgraph_model = SubgraphModel(
            self.graph, self.constants, self.state, self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(
            subgraph_model, max_size=self.max_size)
        linop_prop = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [shape_prop, linop_prop])
Example #10
0
  def make_subgraph_spec(self, subg, adjust_features=False):
    """Builds a sequential subgraph spec from a list of ops.

    Each op is deep-copied and rewired so that it reads only the previous
    op's ``"<name>:0"`` output. The first op reads ``self.input_name``, and
    the last node's ``output_names`` is set to ``[self.output_name]``.

    If the op names do not all already end in distinct, purely numeric
    ``/<k>`` components, every copied op gets ``/<position>`` appended to its
    name.

    Args:
      subg: the ops to chain, in order.
      adjust_features: when true and ``self.get_output_features(subg)``
        returns a truthy value, the last DENSE or CONV node in the spec has
        its ``features`` kwarg overwritten with that value.

    Returns:
      The list of ``subgraph.SubgraphNode`` objects, one per op.
    """
    # Names are considered already unique only if each has a "/" and a
    # numeric last component not shared with any earlier op.
    needs_suffix = False
    seen_suffixes = set()
    for op in subg:
      _, sep, last = op.name.rpartition("/")
      if not sep or last in seen_suffixes or not last.isdigit():
        needs_suffix = True
        break
      seen_suffixes.add(last)

    spec = []
    upstream = self.input_name
    for position, source_op in enumerate(subg):
      node_op = copy.deepcopy(source_op)
      if needs_suffix:
        node_op.name = f"{node_op.name}/{position}"
      node_op.input_names = [upstream]
      upstream = f"{node_op.name}:0"
      spec.append(subgraph.SubgraphNode(node_op))
    spec[-1].output_names = [self.output_name]

    if adjust_features:
      output_features = self.get_output_features(subg)
      if output_features:
        last_sized = next(
            (node for node in reversed(spec)
             if node.op.type in (OpType.DENSE, OpType.CONV)),
            None)
        if last_sized is not None:
          last_sized.op.op_kwargs["features"] = output_features
    return spec