def test_abstract_sequential_synthesizer_output_features(self):
    """Checks the feature multiplier derived from a symbolic conv width.

    conv_layer1 is replaced by a 1x1 conv whose ``features`` is the symbolic
    string ``"S:-1*2"``, followed by a ReLU. The synthesizer built from the
    resulting abstract model is expected to report
    ``output_features_mul == 2`` and ``output_features_div == 1``.
    """
    graph, constants, _ = cnn.CifarNet()

    conv_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/conv",
            op_type=OpType.CONV,
            op_kwargs={
                "features": "S:-1*2",
                "kernel_size": [1, 1],
            },
            input_names=["conv_layer0/avg_pool"]))
    relu_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/relu",
            op_type=OpType.RELU,
            input_names=["conv_layer1/conv"]),
        output_names=["conv_layer1/relu"])
    subgraph_spec = [conv_node, relu_node]

    new_graph = replace_subgraph(graph, subgraph_spec)
    # The state argument is None here, so the model is built without
    # initialized parameters.
    dummy_inputs = {"input": jnp.zeros((5, 32, 32, 10))}
    subgraph_model = SubgraphModel(new_graph, constants, None, dummy_inputs,
                                   subgraph_spec)
    shape_prop = shape.ShapeProperty().infer(subgraph_model)

    synthesizer = TestSequentialSynthesizer([(subgraph_model, [shape_prop])],
                                            0)
    self.assertEqual(synthesizer.output_features_mul, 2)
    self.assertEqual(synthesizer.output_features_div, 1)
    def setUp(self):
        """Builds CifarNet and a variant whose conv_layer1 is replaced.

        The replacement is a 64-feature 1x1 conv ("conv_layer1/conv/1")
        followed by a GELU ("conv_layer1/gelu/1"), whose output is exposed
        under the original name "conv_layer1/relu". Both the original and
        the rewritten model are initialized with PRNG key 0 on an all-ones
        input of shape (1, 32, 32, 3).
        """
        super().setUp()

        input_shape = (1, 32, 32, 3)
        init_key = random.PRNGKey(0)

        self.graph, self.constants, _ = cnn.CifarNet()
        self.model = Model(self.graph, self.constants)
        self.state = self.model.init(init_key, jnp.ones(input_shape))

        conv_node = subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1],
                },
                input_names=["conv_layer0/avg_pool"]))
        gelu_node = subgraph.SubgraphNode(
            op=new_op(
                op_name="conv_layer1/gelu/1",
                op_type=OpType.GELU,
                input_names=["conv_layer1/conv/1"]),
            output_names=["conv_layer1/relu"])
        self.subgraph = [conv_node, gelu_node]

        self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
        self.new_model = Model(self.new_graph, self.constants)
        self.new_state = self.new_model.init(init_key, jnp.ones(input_shape))
 def test_abstract_sequential_synthesizer_fail(self):
     """Expects the synthesizer to reject this subgraph spec with a ValueError.

     Unlike the output-features test, the spec is handed to SubgraphModel
     together with the unmodified CifarNet graph (no replace_subgraph call).
     The synthesizer is given no properties. The raised ValueError must
     match ".*exactly one input.*".
     """
     graph, constants, _ = cnn.CifarNet()

     conv_node = SubgraphNode(
         op=new_op(
             op_name="conv_layer1/conv/1",
             op_type=OpType.CONV,
             op_kwargs={
                 "features": 64,
                 "kernel_size": [1, 1],
             },
             input_names=["conv_layer0/avg_pool"]),
         output_names=["conv_layer1/conv"])
     gelu_node = SubgraphNode(
         op=new_op(
             op_name="conv_layer1/gelu/1",
             op_type=OpType.GELU,
             input_names=["conv_layer1/conv"]),
         output_names=["conv_layer1/relu"])
     subgraph_spec = [conv_node, gelu_node]

     dummy_inputs = {"input": jnp.zeros((5, 32, 32, 10))}
     model = SubgraphModel(graph, constants, None, dummy_inputs, subgraph_spec)

     with self.assertRaisesRegex(ValueError, ".*exactly one input.*"):
         TestSequentialSynthesizer([(model, [])], 0)
Example #4
0
    def setUp(self):
        """Builds CifarNet and stores the model and its initialized state.

        ``self.cnn`` is the Model. ``self.cnn_state`` is the result of
        ``init`` with PRNG key 0 on an all-ones input of shape
        (1, 32, 32, 3).
        """
        super().setUp()

        net_graph, net_constants, _ = cnn.CifarNet()
        self.cnn = Model(net_graph, net_constants)

        init_key = random.PRNGKey(0)
        sample_input = jnp.ones((1, 32, 32, 3))
        self.cnn_state = self.cnn.init(init_key, sample_input)
Example #5
0
    def setUp(self):
        """Initializes CifarNet and infers its shape property.

        The full graph is wrapped in a SubgraphModel that carries the
        initialized state and an all-ones "input" of shape (5, 32, 32, 3).
        ShapeProperty is inferred from that model and stored as ``self.sp``.
        """
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        input_shape = (5, 32, 32, 3)

        full_model = Model(self.graph, self.constants)
        state = full_model.init(random.PRNGKey(0),
                                {"input": jnp.ones(input_shape)})

        self.subgraph_model = SubgraphModel(self.graph, self.constants, state,
                                            {"input": jnp.ones(input_shape)})
        self.sp = shape.ShapeProperty().infer(self.subgraph_model)
    def setUp(self):
        """Builds, initializes and runs CifarNet once for reference outputs.

        ``self.out`` holds the "fc/logits" entry of the model output for an
        all-ones input of shape (5, 32, 32, 3). ``self.max_size`` (1e9) and
        ``self.hard`` (False) are fixed settings; their meaning is defined by
        the tests that read them.
        """
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)

        outputs = self.m.apply(self.state, self.input)
        self.out = outputs["fc/logits"]

        self.max_size = 10**9  # same integer as int(10e8)
        self.hard = False
Example #7
0
    def setUp(self):
        """Seeds py_rand from the clock and computes a reference output.

        The seed is logged so a failing run can be reproduced. ``self.out``
        is the model output named by the graph's first output name, for an
        all-ones input of shape (5, 32, 32, 3).
        """
        super().setUp()

        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)

        outputs = self.m.apply(self.state, self.input)
        self.out = outputs[self.graph.output_names[0]]

        self.max_size = 10**9  # same integer as int(10e8)
Example #8
0
 def test_construct(self):
     """Abstract random mutation of CifarNet yields exactly one result.

     ``random`` is seeded with 0 once up front. Each of the 20 trials builds
     a fresh CifarNet and a fresh RandomSequentialMutator (shape, depth and
     linop properties; p=0.5, max_len=3). It checks that ``mutate`` with
     ``abstract=True`` returns a single entry.
     """
     random.seed(0)
     num_trials = 20
     for _trial in range(num_trials):
         graph, constants, _ = cnn.CifarNet()
         properties = [ShapeProperty(), DepthProperty(), LinopProperty()]
         mutator = RandomSequentialMutator(properties, p=0.5, max_len=3)
         # One (constants, state, inputs) tuple; state is None.
         model_args = [(constants, None, {"input": jnp.ones((1, 32, 32, 3))})]
         mutated = mutator.mutate(graph, model_args, abstract=True)
         self.assertLen(mutated, 1)
Example #9
0
    def setUp(self):
        """Infers DepthProperty for CifarNet without state or inputs.

        The SubgraphModel wraps the full graph with an empty state dict and
        an empty inputs dict. The inferred property is stored as ``self.dp``.
        """
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        empty_state = {}
        empty_inputs = {}
        self.subgraph_model = SubgraphModel(self.graph, self.constants,
                                            empty_state, empty_inputs)
        self.dp = depth.DepthProperty().infer(self.subgraph_model)
Example #10
0
 def test_full(self):
     """Smoke test: LinopProperty inference over all of CifarNet runs.

     The model has no state (None) and an all-ones "input" of shape
     (10, 32, 32, 3). The inferred property is not inspected; the test only
     checks that inference completes without raising.
     """
     graph, constants, _ = cnn.CifarNet()
     inputs = {"input": jnp.ones((10, 32, 32, 3))}
     model = SubgraphModel(graph, constants, None, inputs)
     linear.LinopProperty().infer(model)
Example #11
0
 def cnn_model_fn(*args, **kwargs):
     """Builds CifarNet, forwarding all positional and keyword arguments.

     Returns whatever ``cnn.CifarNet`` returns, unchanged.
     """
     cifar_net = cnn.CifarNet(*args, **kwargs)
     return cifar_net