def test_abstract_sequential_synthesizer_output_features(self):
    """Check the output-feature factors derived from a symbolic conv spec.

    A conv -> relu subgraph spec, whose conv declares
    ``"features": "S:-1*2"``, is handed to ``replace_subgraph`` together
    with the CifarNet graph. The synthesizer built from the resulting
    subgraph model must report ``output_features_mul == 2`` and
    ``output_features_div == 1``.
    """
    graph, constants, _ = cnn.CifarNet()

    # Nodes are created one after the other, in spec order.
    conv_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/conv",
            op_type=OpType.CONV,
            op_kwargs={"features": "S:-1*2", "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"],
        ),
    )
    relu_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/relu",
            op_type=OpType.RELU,
            input_names=["conv_layer1/conv"],
        ),
        output_names=["conv_layer1/relu"],
    )
    subgraph_spec = [conv_node, relu_node]

    replaced_graph = replace_subgraph(graph, subgraph_spec)
    dummy_inputs = {"input": jnp.zeros((5, 32, 32, 10))}
    subgraph_model = SubgraphModel(
        replaced_graph, constants, None, dummy_inputs, subgraph_spec
    )

    shape_property = shape.ShapeProperty().infer(subgraph_model)
    syn = TestSequentialSynthesizer([(subgraph_model, [shape_property])], 0)

    self.assertEqual(syn.output_features_mul, 2)
    self.assertEqual(syn.output_features_div, 1)
def setUp(self):
    """Prepare the stock CifarNet model and a conv -> gelu variant.

    Stores the original graph/constants/model/state on ``self``, then
    builds a two-node subgraph (1x1 conv with 64 features followed by
    GELU), applies ``subgraph.replace_subgraph`` and initialises a second
    model from the resulting graph.
    """
    super().setUp()

    self.graph, self.constants, _ = cnn.CifarNet()
    self.model = Model(self.graph, self.constants)
    self.state = self.model.init(
        random.PRNGKey(0),
        jnp.ones((1, 32, 32, 3)),
    )

    conv_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"],
        ),
    )
    # NOTE(review): output_names reuses "conv_layer1/relu" -- presumably so
    # consumers of the original relu output stay connected; confirm
    # against replace_subgraph.
    gelu_node = subgraph.SubgraphNode(
        op=new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv/1"],
        ),
        output_names=["conv_layer1/relu"],
    )
    self.subgraph = [conv_node, gelu_node]

    self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
    self.new_model = Model(self.new_graph, self.constants)
    self.new_state = self.new_model.init(
        random.PRNGKey(0),
        jnp.ones((1, 32, 32, 3)),
    )
def test_abstract_sequential_synthesizer_fail(self):
    """Expect synthesizer construction to be rejected with a ValueError.

    A ``SubgraphModel`` is built directly from the unmodified CifarNet
    graph plus a conv -> gelu spec, and paired with an empty property
    list. Constructing ``TestSequentialSynthesizer`` from it must raise a
    ``ValueError`` whose message matches ``".*exactly one input.*"``.
    """
    graph, constants, _ = cnn.CifarNet()

    conv_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/conv/1",
            op_type=OpType.CONV,
            op_kwargs={"features": 64, "kernel_size": [1, 1]},
            input_names=["conv_layer0/avg_pool"],
        ),
        output_names=["conv_layer1/conv"],
    )
    gelu_node = SubgraphNode(
        op=new_op(
            op_name="conv_layer1/gelu/1",
            op_type=OpType.GELU,
            input_names=["conv_layer1/conv"],
        ),
        output_names=["conv_layer1/relu"],
    )
    subgraph_spec = [conv_node, gelu_node]

    dummy_inputs = {"input": jnp.zeros((5, 32, 32, 10))}
    subgraph_model = SubgraphModel(
        graph, constants, None, dummy_inputs, subgraph_spec
    )

    with self.assertRaisesRegex(ValueError, ".*exactly one input.*"):
        TestSequentialSynthesizer([(subgraph_model, [])], 0)
def setUp(self):
    """Create a CifarNet ``Model`` and its initial state for the tests.

    The model is stored as ``self.cnn`` and the result of ``init`` (called
    with ``PRNGKey(0)`` and a ones array of shape ``(1, 32, 32, 3)``) as
    ``self.cnn_state``.
    """
    super().setUp()

    net_graph, net_constants, _ = cnn.CifarNet()
    self.cnn = Model(net_graph, net_constants)

    rng = random.PRNGKey(0)
    sample_batch = jnp.ones((1, 32, 32, 3))
    self.cnn_state = self.cnn.init(rng, sample_batch)
def setUp(self):
    """Build a ``SubgraphModel`` over CifarNet and infer its shape property.

    Populates ``self.graph``, ``self.constants``, ``self.subgraph_model``
    and ``self.sp`` (the result of ``ShapeProperty().infer``).
    """
    super().setUp()
    self.graph, self.constants, _ = cnn.CifarNet()

    full_model = Model(self.graph, self.constants)
    state = full_model.init(
        random.PRNGKey(0),
        {"input": jnp.ones((5, 32, 32, 3))},
    )

    # A separate inputs dict is built for the subgraph model, as in the
    # init call above, so the two calls never share a mapping.
    subgraph_inputs = {"input": jnp.ones((5, 32, 32, 3))}
    self.subgraph_model = SubgraphModel(
        self.graph, self.constants, state, subgraph_inputs
    )

    shape_property = shape.ShapeProperty()
    self.sp = shape_property.infer(self.subgraph_model)
def setUp(self):
    """Initialise CifarNet, run it once, and record the test limits.

    Sets ``self.graph``, ``self.constants``, ``self.m``, ``self.input``,
    ``self.state`` and ``self.out`` (the ``"fc/logits"`` entry of the
    model's ``apply`` result), plus ``self.max_size`` and ``self.hard``.
    """
    super().setUp()

    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}

    rng = random.PRNGKey(0)
    self.state = self.m.init(rng, self.input)

    outputs = self.m.apply(self.state, self.input)
    self.out = outputs["fc/logits"]

    self.max_size = int(10e8)
    self.hard = False
def setUp(self):
    """Seed ``py_rand`` from the clock, then initialise and run CifarNet.

    The seed is logged so a failing run can be reproduced. Sets
    ``self.graph``, ``self.constants``, ``self.m``, ``self.input``,
    ``self.state``, ``self.out`` (the entry of the ``apply`` result keyed
    by the graph's first output name) and ``self.max_size``.
    """
    super().setUp()

    clock_seed = int(time.time())
    logging.info("Seed: %d", clock_seed)
    py_rand.seed(clock_seed)

    self.graph, self.constants, _ = cnn.CifarNet()
    self.m = Model(self.graph, self.constants)
    self.input = {"input": jnp.ones((5, 32, 32, 3))}

    rng = random.PRNGKey(0)
    self.state = self.m.init(rng, self.input)

    outputs = self.m.apply(self.state, self.input)
    self.out = outputs[self.graph.output_names[0]]

    self.max_size = int(10e8)
def test_construct(self):
    """Mutate CifarNet 20 times and expect one result from each call.

    Each iteration builds a fresh graph and a fresh
    ``RandomSequentialMutator`` (shape, depth and linop properties,
    ``p=0.5``, ``max_len=3``), calls ``mutate`` with ``abstract=True`` and
    asserts the returned collection has length 1.
    """
    random.seed(0)

    for _ in range(20):
        graph, constants, _ = cnn.CifarNet()

        properties = [ShapeProperty(), DepthProperty(), LinopProperty()]
        mutator = RandomSequentialMutator(properties, p=0.5, max_len=3)

        # One (constants, state, inputs) tuple; state is None here.
        model_inputs = {"input": jnp.ones((1, 32, 32, 3))}
        contexts = [(constants, None, model_inputs)]

        mutated = mutator.mutate(graph, contexts, abstract=True)
        self.assertLen(mutated, 1)
def setUp(self):
    """Build a ``SubgraphModel`` over CifarNet and infer its depth property.

    The subgraph model receives two empty dicts as its third and fourth
    positional arguments. The inferred property is stored as ``self.dp``.
    """
    super().setUp()
    self.graph, self.constants, _ = cnn.CifarNet()

    # NOTE(review): third/fourth positional args look like state and
    # inputs -- confirm against SubgraphModel's signature.
    empty_state = {}
    empty_inputs = {}
    self.subgraph_model = SubgraphModel(
        self.graph, self.constants, empty_state, empty_inputs
    )

    depth_property = depth.DepthProperty()
    self.dp = depth_property.infer(self.subgraph_model)
def test_full(self):
    """Smoke test: ``LinopProperty`` inference over the whole CifarNet graph.

    There are no assertions; the test passes if ``infer`` completes
    without raising.
    """
    graph, constants, _ = cnn.CifarNet()

    inputs = {"input": jnp.ones((10, 32, 32, 3))}
    subgraph_model = SubgraphModel(graph, constants, None, inputs)

    linop_property = linear.LinopProperty()
    linop_property.infer(subgraph_model)
def cnn_model_fn(*args, **kwargs):
    """Forward all arguments to ``cnn.CifarNet`` and return its result."""
    built = cnn.CifarNet(*args, **kwargs)
    return built