Exemplo n.º 1
0
    def setUp(self):
        """Build two single-op graphs and infer a linear property for each.

        Each graph maps "input" to "output" through exactly one op: a SOFTMAX
        op for the ``*_dense`` attributes and a CONV op for the ``*_conv``
        attributes. For both, the model state is initialised with
        ``random.PRNGKey(0)`` on an all-ones input, and the resulting graph,
        ``SubgraphModel`` and inferred ``LinopProperty`` are stored on the
        test case.
        """
        super().setUp()

        def ones_input():
            # Returns a fresh dict (and array) on every call so that model
            # initialisation and the SubgraphModel never share an object.
            return {"input": jnp.ones((5, 5, 5))}

        # NOTE(review): these attributes are named "*_dense" although the op
        # under test is SOFTMAX -- confirm that the naming is intentional.
        softmax_op = new_op(op_name="output",
                            op_type=OpType.SOFTMAX,
                            input_names=["input"])
        self.graph_dense = new_graph(["input"], ["output"], [softmax_op])
        state_dense = Model(self.graph_dense).init(random.PRNGKey(0),
                                                   ones_input())
        self.subgraph_dense = SubgraphModel(self.graph_dense, None,
                                            state_dense, ones_input())
        self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

        conv_op = new_op(op_name="output",
                         op_type=OpType.CONV,
                         op_kwargs={"features": 10, "kernel_size": [3, 3]},
                         input_names=["input"])
        self.graph_conv = new_graph(["input"], ["output"], [conv_op])
        state_conv = Model(self.graph_conv).init(random.PRNGKey(0),
                                                 ones_input())
        self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                           ones_input())
        self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
Exemplo n.º 2
0
    def test_abstract(self):
        """Check that abstract and concrete inference yield equal pairings.

        Builds every two-op chain that combines one of {CONV, DENSE} with one
        of {RELU, SOFTMAX, LAYER_NORM, BATCH_NORM}, in both orders, and
        asserts that the "output"/"input" pairing mappings inferred with
        ``abstract=True`` and ``abstract=False`` are elementwise equal.
        """
        conv_op = functools.partial(
            new_op,
            op_type=OpType.CONV,
            op_kwargs={"features": 10, "kernel_size": [3, 3]},
        )
        dense_op = functools.partial(
            new_op,
            op_type=OpType.DENSE,
            op_kwargs={"features": 10},
        )
        companion_types = [
            OpType.RELU,
            OpType.SOFTMAX,
            OpType.LAYER_NORM,
            OpType.BATCH_NORM,
        ]

        # Each entry is the op list of one graph: "input" -> "other" ->
        # "output", with the conv/dense op placed first and then second.
        op_lists = []
        for companion_type in companion_types:
            for make_op in (conv_op, dense_op):
                op_lists.extend([
                    [
                        make_op(input_names=["input"], op_name="other"),
                        new_op(op_name="output",
                               op_type=companion_type,
                               input_names=["other"]),
                    ],
                    [
                        new_op(op_name="other",
                               op_type=companion_type,
                               input_names=["input"]),
                        make_op(input_names=["other"], op_name="output"),
                    ],
                ])

        input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
        for op_list in op_lists:
            graph = new_graph(["input"], ["output"], op_list)
            params = Model(graph).init(random.PRNGKey(1), input_tensor)

            # Make all the kernels positive; otherwise the ReLU might zero
            # out the entire tensor.
            params = jax.tree_util.tree_map(abs, params)

            model = SubgraphModel(graph, None, params, input_tensor)
            abstract_prop = linear.LinopProperty().infer(model, abstract=True)
            concrete_prop = linear.LinopProperty().infer(model,
                                                         abstract=False)
            concrete_map = concrete_prop.pairings["output"]["input"].mappings
            abstract_map = abstract_prop.pairings["output"]["input"].mappings

            print("concrete:", concrete_map)
            print("abstract:", abstract_map)
            difference = abstract_map - concrete_map
            self.assertTrue((difference == 0).all())
Exemplo n.º 3
0
 def test_synthesizer_two(self):
     """Replace [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

     The subgraph is ``self.graph.ops[4:7]``; synthesis is constrained by
     the shape and linear properties only.
     """
     nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:7]]
     # The last node's outputs are named after the inputs of the op that
     # follows the subgraph.
     nodes[-1].output_names = self.graph.ops[7].input_names
     model = SubgraphModel(self.graph, self.constants, self.state,
                           self.input, nodes)
     shape_prop = shape.ShapeProperty().infer(model,
                                              max_size=self.max_size)
     linear_prop = linear.LinopProperty().infer(model)
     self._synthesize(model, [shape_prop, linear_prop])
 def test_synthesizer_easy_two(self):
     """Replace [conv3x3(features = 64)].

     The subgraph is the single op ``self.graph.ops[4]``; synthesis is
     constrained by the shape, depth and linear properties.
     """
     nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:5]]
     # The last node's outputs are named after the inputs of the op that
     # follows the subgraph.
     nodes[-1].output_names = self.graph.ops[5].input_names
     model = SubgraphModel(self.graph, self.constants, self.state,
                           self.input, nodes)
     shape_prop = shape.ShapeProperty().infer(model,
                                              max_size=self.max_size)
     depth_prop = depth.DepthProperty().infer(model)
     linear_prop = linear.LinopProperty().infer(model)
     self._synthesize(model, [shape_prop, depth_prop, linear_prop])
 def test_synthesizer_hard(self):
     """Synthesize a replacement for ``self.graph.ops[4:7]`` with all props.

     Synthesis is constrained by the shape, depth and linear properties.
     The test only runs when ``self.hard`` is truthy; otherwise it is
     reported as skipped instead of silently counting as a pass.
     """
     if not self.hard:
         # A bare ``return`` would make the runner report a pass even though
         # nothing was exercised; skipping keeps the report honest.
         self.skipTest("hard synthesizer tests are disabled")
     subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[4:7]]
     # The last node's outputs are named after the inputs of the op that
     # follows the subgraph.
     subg[-1].output_names = self.graph.ops[7].input_names
     subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                    self.input, subg)
     sp = shape.ShapeProperty().infer(subgraph_model,
                                      max_size=self.max_size)
     dp = depth.DepthProperty().infer(subgraph_model)
     lp = linear.LinopProperty().infer(subgraph_model)
     self._synthesize(subgraph_model, [sp, dp, lp])
    def test_synthesizer_two(self):
        """Replace [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Because the depth property is not checked, [dense(features = 64),
        avgpool2x2(strides=2x2)] works as well (which is what is synthesized
        due to the enumeration order).
        """
        nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:7]]
        # The last node's outputs are named after the inputs of the op that
        # follows the subgraph.
        nodes[-1].output_names = self.graph.ops[7].input_names
        model = SubgraphModel(self.graph, self.constants, self.state,
                              self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(model,
                                                 max_size=self.max_size)
        linear_prop = linear.LinopProperty().infer(model)
        self._synthesize(model, [shape_prop, linear_prop])
    def test_synthesizer_easy_two(self):
        """Replace [conv3x3(features = 64)].

        Because all three properties are tested, this is replaced by
        conv3x3(features = 64) (i.e., an identical op) due to the enumeration
        order.
        """
        nodes = [subgraph.SubgraphNode(op=op) for op in self.graph.ops[4:5]]
        # The last node's outputs are named after the inputs of the op that
        # follows the subgraph.
        nodes[-1].output_names = self.graph.ops[5].input_names
        model = SubgraphModel(self.graph, self.constants, self.state,
                              self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(model,
                                                 max_size=self.max_size)
        depth_prop = depth.DepthProperty().infer(model)
        linear_prop = linear.LinopProperty().infer(model)
        self._synthesize(model, [shape_prop, depth_prop, linear_prop])
Exemplo n.º 8
0
    def test_synthesizer_resnet_big(self):
        """Synthesize over a non-contiguous op selection of a ResNet18 graph.

        Rebuilds the fixture attributes (``graph``, ``constants``, ``m``,
        ``input``, ``state``, ``out``, ``max_size``) for a ResNet18 created
        with ``num_classes=10`` and ``input_resolution="small"``, then runs
        synthesis on ops[3:5] + ops[8:12] constrained by the shape and linear
        properties.
        """
        self.graph, self.constants, _ = resnetv1.ResNet18(
            num_classes=10, input_resolution="small")
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        outputs = self.m.apply(self.state, self.input)
        self.out = outputs[self.graph.output_names[0]]
        self.max_size = 10**9  # Same integer value as int(10e8).

        selected_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
        nodes = [subgraph.SubgraphNode(op=op) for op in selected_ops]
        # Name the final node's single output "<op name>:0".
        last_node = nodes[-1]
        last_node.output_names = [f"{last_node.op.name}:0"]
        model = SubgraphModel(self.graph, self.constants, self.state,
                              self.input, nodes)
        shape_prop = shape.ShapeProperty().infer(model,
                                                 max_size=self.max_size)
        linear_prop = linear.LinopProperty().infer(model)
        self._synthesize(model, [shape_prop, linear_prop])
Exemplo n.º 9
0
    def test_multi_input(self):
        """Check linear-property pairings for a multi-input, multi-output subgraph.

        The graph has three dense -> relu branches fed by the same "input",
        combined by two ADD ops ("add0" = relu0 + relu1, "add1" = relu1 +
        relu2). The subgraph consists of the three relus and both adds, so
        the inferred pairings are expected to be keyed by the two add outputs,
        each paired with the two dense outputs feeding it.
        """
        branch_ids = range(3)
        # (add op name, names of the two relu ops it sums)
        add_specs = (
            ("add0", ("relu0", "relu1")),
            ("add1", ("relu1", "relu2")),
        )

        ops = []
        for i in branch_ids:
            ops.append(
                new_op(op_name=f"dense{i}",
                       op_type=OpType.DENSE,
                       op_kwargs={"features": 32},
                       input_names=["input"]))
            ops.append(
                new_op(op_name=f"relu{i}",
                       op_type=OpType.RELU,
                       input_names=[f"dense{i}"]))
        for add_name, operands in add_specs:
            ops.append(
                new_op(op_name=add_name,
                       op_type=OpType.ADD,
                       input_names=list(operands)))
        graph = new_graph(input_names=["input"],
                          output_names=[name for name, _ in add_specs],
                          ops=ops)

        # The subgraph nodes are built from freshly constructed ops; only the
        # add nodes are given explicit output names.
        subgraph_spec = [
            SubgraphNode(op=new_op(op_name=f"relu{i}",
                                   op_type=OpType.RELU,
                                   input_names=[f"dense{i}"]))
            for i in branch_ids
        ]
        subgraph_spec += [
            SubgraphNode(op=new_op(op_name=add_name,
                                   op_type=OpType.ADD,
                                   input_names=list(operands)),
                         output_names=[add_name])
            for add_name, operands in add_specs
        ]

        replaced_graph = replace_subgraph(graph, subgraph_spec)
        inp = {"input": jnp.ones((10, 32, 32, 3))}
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp,
                                       subgraph_spec)
        pairings = linear.LinopProperty().infer(subgraph_model).pairings

        # (subgraph output, subgraph inputs it is expected to be paired with)
        expected_pairings = (
            ("add0:0", ("dense0:0", "dense1:0")),
            ("add1:0", ("dense1:0", "dense2:0")),
        )
        self.assertLen(pairings, 2)
        for output_name, input_names in expected_pairings:
            self.assertIn(output_name, pairings)
            self.assertLen(pairings[output_name], 2)
            for input_name in input_names:
                self.assertIn(input_name, pairings[output_name])
Exemplo n.º 10
0
 def test_full(self):
     """Infer the linear property on a whole CifarNet graph.

     No subgraph spec and no state are passed; the test only checks that
     inference completes without raising.
     """
     net_graph, net_constants, _ = cnn.CifarNet()
     sample_input = {"input": jnp.ones((10, 32, 32, 3))}
     full_model = SubgraphModel(net_graph, net_constants, None, sample_input)
     linear.LinopProperty().infer(full_model)