def setUp(self):
  super().setUp()

  self.graph_dense = new_graph(
      ["input"], ["output"],
      [
          new_op(
              op_name="output",
              op_type=OpType.SOFTMAX,
              # op_kwargs={"features": 10},
              input_names=["input"])
      ])
  state_dense = Model(self.graph_dense).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_dense = SubgraphModel(self.graph_dense, None, state_dense,
                                      {"input": jnp.ones((5, 5, 5))})
  self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

  self.graph_conv = new_graph(["input"], ["output"], [
      new_op(
          op_name="output",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 10,
              "kernel_size": [3, 3]
          },
          input_names=["input"])
  ])
  state_conv = Model(self.graph_conv).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
  self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                     {"input": jnp.ones((5, 5, 5))})
  self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
def mlp_block(dropout, mlp_factor):
  """MLP block in the encoder block."""
  ops = []
  append = functools.partial(append_op, ops)

  use_dropout = dropout > 1e-3
  input_name = "input"

  append(
      op_name="dense0",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": f"S:-1*{mlp_factor}",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:normal:stddev:1e-6"
      },
      input_names=[input_name])
  append(op_name="gelu1", op_type=OpType.GELU)
  if use_dropout:
    append(
        op_name="dropout",
        op_type=OpType.DROPOUT,
        op_kwargs={"rate": dropout})
  append(
      op_name="dense2",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": f"S:-1%{mlp_factor}",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:normal:stddev:1e-6"
      })

  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return Block(
      name=f"mlp_block{'_dropout' if use_dropout else ''}",
      graph=graph,
      constants={})
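# Hedged usage sketch (not part of the original module): builds the MLP block
# above and instantiates it on a hypothetical input tensor name. The symbolic
# feature specs appear to scale the incoming width: "S:-1*{mlp_factor}" widens
# the last dimension by mlp_factor, and "S:-1%{mlp_factor}" seems to divide it
# back down to the original width.
def _mlp_block_example():
  block = mlp_block(dropout=0.1, mlp_factor=4)
  # "encoder0/layernorm0" is an illustrative tensor name only.
  return block.instantiate(input_names=["encoder0/layernorm0"], instance_id=0)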
def encoder(
    input_name,
    blocks,
    dropout,
):
  """Encoder for ViT."""
  ops = []
  constants = {}
  new_blocks = []
  append = functools.partial(append_op, ops)

  encoder_input_name = input_name
  for block_id, block in enumerate(blocks):
    graph, new_constants, new_block = encoder_block(
        input_name=encoder_input_name,
        block_id=block_id,
        block=block,
        dropout=dropout)
    encoder_input_name = graph.output_names[0]
    constants.update(new_constants)
    new_blocks.append(new_block)
    ops.extend(graph.ops)
  append(op_name="transformer/encoder_norm", op_type=OpType.LAYER_NORM)

  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return graph, constants, new_blocks
def _update_pairings(self, op, in_shapes):
  """Updates pairings with the property for a single op.

  Args:
    op: The op for which to infer the pairing property.
    in_shapes: The shapes of the input tensors.
  """
  assert len(op.input_names) == len(in_shapes)
  input_values = {
      input_name: jnp.ones(in_shape)
      for input_name, in_shape in zip(op.input_names, in_shapes)
  }

  output_names = [f"{op.name}:{i}" for i in range(op.num_outputs)]
  graph = new_graph(op.input_names, output_names, [op])
  model = Model(graph)
  state = model.init(jax.random.PRNGKey(0), input_values)
  pairings = GraphPairings.infer(
      model, input_values, state, abstract=False).pairings

  new_pairings = {}
  for output_idx, output_name in enumerate(output_names):
    new_pairings[output_idx] = {}
    for input_idx, input_name in enumerate(op.input_names):
      new_pairings[output_idx][input_idx] = pairings[output_name][input_name]

  key = self.hash(op, in_shapes)
  self.pairings[key] = new_pairings
def test_unsatisfy(self):
  # This test truncates the graph at the last dense layer, so the new graph
  # should be shallower.
  graph = new_graph(
      input_names=["input"], output_names=["fc/relu"], ops=self.graph.ops)
  subgraph_model = SubgraphModel(graph, self.constants, {}, {})
  self.assertFalse(self.dp.verify(subgraph_model))
def test_equal(self):
  """Tests whether the fingerprint is the same for equivalent graphs.

  The ops have different names and a different topological sort.
  """
  ops1 = [
      new_op(
          op_name="dense",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="conv",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 32,
              "kernel_size": [3]
          },
          input_names=["input"]),
      new_op(
          op_name="output", op_type=OpType.ADD,
          input_names=["dense", "conv"]),
  ]
  graph1 = new_graph(["input"], ["output"], ops1)

  ops2 = [
      new_op(
          op_name="conv2",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 32,
              "kernel_size": [3]
          },
          input_names=["input"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="output", op_type=OpType.ADD,
          input_names=["dense2", "conv2"]),
  ]
  graph2 = new_graph(["input"], ["output"], ops2)

  input_dict = {"input": jnp.ones((5, 5, 5))}
  fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
  fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
  self.assertEqual(fingerprint1, fingerprint2)
def mb_conv_block(expand_ratio, stride, kernel_size, output_filters):
  """Returns an EfficientNet MBConvBlock."""
  input_name = "input"
  ops = []
  append = functools.partial(append_op, ops)

  # Expand.
  if expand_ratio > 1:
    append(
        op_name="mb_conv/expand/conv0",
        op_type=OpType.CONV,
        op_kwargs={
            "features": f"S:-1*{expand_ratio}",
            "kernel_size": 1,
            "strides": 1,
        },
        input_names=[input_name])
    append(op_name="mb_conv/expand/bn1", op_type=OpType.BATCH_NORM)
    append(op_name="mb_conv/expand/swish2", op_type=OpType.SWISH)
    input_name = ops[-1].name

  # Depthwise conv.
  append(
      op_name="mb_conv/dw/conv0",
      op_type=OpType.CONV,
      op_kwargs={
          "features": "S:-1",
          "feature_group_count": "S:-1",
          "kernel_size": kernel_size,
          "strides": stride,
      },
      input_names=[input_name])
  append(op_name="mb_conv/dw/bn1", op_type=OpType.BATCH_NORM)
  append(op_name="mb_conv/dw/swish2", op_type=OpType.SWISH)

  # Squeeze and excitation.
  input_name = ops[-1].name
  se_ops = squeeze_excite(input_name, expand_ratio * 4)
  ops.extend(se_ops)

  # Output.
  append(
      op_name="mb_conv/output/conv0",
      op_type=OpType.CONV,
      op_kwargs={
          "features": "S:input:-1" if not output_filters else output_filters,
          "kernel_size": 1,
          "strides": 1,
      })
  append(op_name="mb_conv/output/bn1", op_type=OpType.BATCH_NORM)

  output_name = ops[-1].name
  graph = new_graph(input_names=["input"], output_names=[output_name], ops=ops)
  return Block(
      name=f"mbconv_expand{expand_ratio}_stride{stride}_kernel{kernel_size}_"
      f"outputfilters{output_filters}_",
      graph=graph)
def __init__(self, graph, constants, state, inputs, subgraph=None):
  self.graph = graph
  self.constants = constants
  self.state = state
  self.inputs = inputs
  self.subgraph: SubgraphSpec = subgraph if subgraph else []

  self.input_names = None
  self.output_names = None
  self.original_outputs = graph.output_names

  if subgraph:
    self._subgraph_to_names()

    # graph for graph inputs -> subg inputs
    self.subg_inputs_graph = copy.deepcopy(graph)
    self.subg_inputs_graph.output_names = self.input_names
    self.subg_inputs_model = Model(self.subg_inputs_graph, self.constants)
    self.subg_inputs = None

    # graph for graph inputs -> subg outputs
    self.subg_outputs_graph = copy.deepcopy(graph)
    self.subg_outputs_graph.output_names = self.output_names
    self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
    self.subg_outputs = None

    # graph for subg inputs -> subg outputs
    subg_ops = [node.op for node in subgraph]
    self.subg_graph = new_graph(self.input_names, self.output_names, subg_ops)
    self.subg_model = Model(self.subg_graph, self.constants)
  else:
    self.input_names = [
        canonicalize_tensor_name(name) for name in graph.input_names
    ]
    self.output_names = [
        canonicalize_tensor_name(name) for name in graph.output_names
    ]

    # subg inputs = inputs to the graph
    self.subg_inputs_graph = None
    self.subg_inputs_model = None
    self.subg_inputs = inputs

    # graph for graph inputs -> subg outputs
    self.subg_outputs_graph = copy.deepcopy(graph)
    self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
    self.subg_outputs = None

    # subg outputs = full graph outputs
    self.subg_graph = self.subg_outputs_graph
    self.subg_model = self.subg_outputs_model
def test_abstract(self):
  graphs = []
  conv_op = functools.partial(
      new_op,
      op_type=OpType.CONV,
      op_kwargs={
          "features": 10,
          "kernel_size": [3, 3]
      })
  dense_op = functools.partial(
      new_op, op_type=OpType.DENSE, op_kwargs={
          "features": 10,
      })
  for op_type in [
      OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
  ]:
    for op_ctr in [conv_op, dense_op]:
      graphs.append([
          op_ctr(input_names=["input"], op_name="other"),
          new_op(op_name="output", op_type=op_type, input_names=["other"])
      ])
      graphs.append([
          new_op(op_name="other", op_type=op_type, input_names=["input"]),
          op_ctr(input_names=["other"], op_name="output"),
      ])

  input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
  for graph in graphs:
    graph = new_graph(["input"], ["output"], graph)
    state = Model(graph).init(random.PRNGKey(1), input_tensor)

    # Make all the kernels positive, otherwise the ReLU might zero out the
    # entire tensor.
    state = jax.tree_util.tree_map(abs, state)

    subg_model = SubgraphModel(graph, None, state, input_tensor)
    lp_abstract = linear.LinopProperty().infer(subg_model, abstract=True)
    lp_concrete = linear.LinopProperty().infer(subg_model, abstract=False)
    pairings_concrete = lp_concrete.pairings["output"]["input"].mappings
    pairings_abstract = lp_abstract.pairings["output"]["input"].mappings

    print("concrete:", pairings_concrete)
    print("abstract:", pairings_abstract)
    self.assertTrue(((pairings_abstract - pairings_concrete) == 0).all())
def test_satisfy(self):
  # This test removes the last dense layer, so the old graph should be
  # deeper.
  ops = self.graph.ops[:-1]
  ops[-1].name = "fc/logits"
  graph = new_graph(
      input_names=["input"], output_names=["fc/logits"], ops=ops)
  subgraph_model = SubgraphModel(graph, self.constants, {}, {})
  dp = depth.DepthProperty().infer(subgraph_model)
  self.assertTrue(dp.verify(self.subgraph_model))
def test_unsatisfy(self):
  # This test removes the last dense layer, so the new graph should have a
  # different shape (and therefore not satisfy the inferred shape property).
  graph = new_graph(
      input_names=["input"], output_names=["fc/relu"], ops=self.graph.ops)
  state = Model(graph, self.constants).init(
      random.PRNGKey(0), {"input": jnp.ones((5, 32, 32, 3))})
  subgraph_model = SubgraphModel(graph, self.constants, state,
                                 {"input": jnp.ones((5, 32, 32, 3))})
  self.assertFalse(self.sp.verify(subgraph_model))
def test_not_equal(self):
  """Tests whether the fingerprint is different for non-equivalent graphs."""
  ops1 = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="output", op_type=OpType.ADD,
          input_names=["dense0", "dense1"]),
  ]
  graph1 = new_graph(["input"], ["output"], ops1)

  ops2 = [
      new_op(
          op_name="conv2",
          op_type=OpType.CONV,
          op_kwargs={
              "features": 32,
              "kernel_size": [3]
          },
          input_names=["input"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="output", op_type=OpType.ADD,
          input_names=["dense2", "conv2"]),
  ]
  graph2 = new_graph(["input"], ["output"], ops2)

  input_dict = {"input": jnp.ones((5, 5, 5))}
  fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
  fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
  self.assertNotEqual(fingerprint1, fingerprint2)
def conv_net(in_features, out_features, num_classes, blocks=None):
  """Graph for 3-layer CNN."""
  if not blocks:
    blocks = [block_type() for block_type in BLOCK_TYPES]

  input_name = "input"
  new_blocks = []
  ops = [
      new_op(
          op_name="proj",
          op_type=OpType.CONV,
          op_kwargs={
              "features": in_features,
              "kernel_size": 1,
          },
          input_names=[input_name])
  ]
  constants = {}

  block_input_name = ops[-1].name
  for idx, block in enumerate(blocks):
    block = block.instantiate(input_names=[block_input_name], instance_id=idx)
    new_blocks.append(block)
    constants.update(block.constants)
    ops.extend(block.graph.ops)
    block_input_name = ops[-1].name

  constants.update({
      "out_features": out_features,
      "num_classes": num_classes
  })
  ops.extend([
      new_op(
          op_name="flatten",
          op_type=OpType.FLATTEN,
          input_names=[ops[-1].name]),
      new_op(
          op_name="fc/dense",
          op_type=OpType.DENSE,
          op_kwargs={"features": "K:out_features"},
          input_names=["flatten"]),
      new_op(
          op_name="fc/relu", op_type=OpType.RELU, input_names=["fc/dense"]),
      new_op(
          op_name="fc/logits",
          op_type=OpType.DENSE,
          op_kwargs={"features": "K:num_classes"},
          input_names=["fc/relu"])
  ])

  graph = new_graph(
      input_names=[input_name], output_names=["fc/logits"], ops=ops)
  return graph, constants, new_blocks
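# Hedged usage sketch (not part of the original module): constructs the CNN
# graph and initializes its parameters via the Model wrapper used elsewhere in
# this codebase. The NHWC input shape and feature counts are illustrative only.
def _conv_net_example():
  graph, constants, _ = conv_net(in_features=16, out_features=64,
                                 num_classes=10)
  state = Model(graph, constants).init(random.PRNGKey(0),
                                       {"input": jnp.ones((1, 32, 32, 3))})
  return graph, state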
def test_identical(self):
  """Tests whether the fingerprint is the same for identical graphs."""
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(
          op_name="output", op_type=OpType.ADD,
          input_names=["dense0", "dense1"]),
  ]
  graph = new_graph(["input"], ["output"], ops)

  input_dict = {"input": jnp.ones((5, 5, 5))}
  fingerprint1 = fingerprint.fingerprint_graph(graph, {}, input_dict)
  fingerprint2 = fingerprint.fingerprint_graph(graph, {}, input_dict)
  self.assertEqual(fingerprint1, fingerprint2)
def mbconv_layer(block, input_name, block_id, output_filters, stride,
                 layer_drop_rate):
  """Returns an MBConv layer."""
  prefix = f"mbconv{block_id}"

  block = block.instantiate(input_names=[input_name], instance_id=block_id)
  constants = block.constants
  ops = list(block.graph.ops)

  if not output_filters and stride == 1:
    append_op(
        ops,
        op_name=f"{prefix}/skip",
        op_type=OpType.ADD,
        input_kwargs={"layer_drop_rate": layer_drop_rate},
        input_names=[ops[-1].name, input_name])

  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return graph, constants, block
def encoder_block(
    input_name,
    block_id,
    block,
    dropout,
):
  """Returns an encoder block."""
  prefix = f"encoder{block_id}"
  ops = []
  append = functools.partial(append_op, ops)
  use_dropout = dropout > 1e-3

  res_input = input_name
  append(
      op_name=f"{prefix}/layernorm0",
      op_type=OpType.LAYER_NORM,
      input_names=[input_name])
  block = block.instantiate(input_names=[ops[-1].name], instance_id=block_id)
  ops.extend(block.graph.ops)
  constants = block.constants
  if use_dropout:
    append(
        op_name=f"{prefix}/dropout1",
        op_type=OpType.DROPOUT,
        op_kwargs={"rate": dropout})
  append(
      op_name=f"{prefix}/residual1",
      op_type=OpType.ADD,
      input_names=[res_input, ops[-1].name])

  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return graph, constants, block
def conv_block():
  """Makes a conv block parameterized by the number of features."""
  ops = [
      new_op(
          op_name="conv",
          op_type=OpType.CONV,
          op_kwargs={
              "features": "S:-1*2",
              "kernel_size": 3
          },
          input_names=["input"]),
      new_op(op_name="relu", op_type=OpType.RELU, input_names=["conv"]),
      new_op(
          op_name="avg_pool",
          op_type=OpType.AVG_POOL,
          input_names=["relu"],
          input_kwargs={
              "window_shape": 2,
              "strides": 2
          }),
  ]
  graph = new_graph(input_names=["input"], output_names=["avg_pool"], ops=ops)
  return Block(name="conv_layer", graph=graph)
def mhdpa_block(spatial):
  """Multi-headed dot-product (self) attention block in the encoder block."""
  input_name = "input"
  constants = {"num_heads": None, "head_dim": None}

  width = "S:input:-1"
  num_heads = "K:num_heads"
  head_dim = "K:head_dim"

  ops = []
  append = functools.partial(append_op, ops)

  append(
      op_name="value/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])
  append(
      op_name="key/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])
  append(
      op_name="query/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])

  # spatial: [b, h, w, width] -> [b, h*w, num_heads, head_dim]
  # original: [b, h*w, width] -> [b, h*w, num_heads, head_dim]
  new_shape = ["B", -1, num_heads, head_dim]
  append(
      op_name="query",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["query/pre"])
  append(
      op_name="key",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["key/pre"])
  append(
      op_name="value",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["value/pre"])

  append(
      op_name="query/scale", op_type=OpType.SCALAR_MUL, input_names=["query"])

  # attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  append(
      op_name="attn_weight",
      op_type=OpType.EINSUM,
      input_kwargs={"sum": "...qhd,...khd->...hqk"},
      input_names=["query/scale", "key"])
  append(
      op_name="attn_weight/softmax",
      op_type=OpType.SOFTMAX,
      input_kwargs={"axis": -1})

  # attn_values = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)
  append(
      op_name="attn_value",
      op_type=OpType.EINSUM,
      input_kwargs={"sum": "...hqk,...khd->...qhd"},
      input_names=["attn_weight/softmax", "value"])

  # Reshape back to the original input dimensions.
  if spatial:
    # [b, h*w, num_heads, head_dim] -> [b, h, w, width]
    new_shape = ["B", "S:input:1", "S:input:2", width]
  else:
    # [b, h*w, num_heads, head_dim] -> [b, h*w, width]
    new_shape = ["B", -1, width]
  append(
      op_name="attn_value/reshape",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape})
  append(
      op_name="out",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros",
      })

  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return Block(
      name=f"mhdpa_block{'_spatial' if spatial else ''}",
      graph=graph,
      constants=constants)
def instantiate(self, input_names, instance_id=None, constants=None):
  """Instantiates a version of the block with unique names.

  This method uses the names of the graph and constants from the initial
  definition of the block (__init__), so one can instantiate from any derived
  block with the same effect. E.g., if we have:
    init_block = block.__init__(name="conv_layer", ...)
    block0 = init_block.instantiate(instance_id=0, ...)
  then:
    block1 = init_block.instantiate(instance_id=1, ...)
  will have the same effect as:
    block1 = block0.instantiate(instance_id=1, ...)
  The one caveat is that the default values for unspecified constants are
  inherited from the instantiating block (instead of the initial definition).

  Args:
    input_names: The input tensor names the instantiated block will consume.
    instance_id: An id to make the names in the instantiated block unique.
      The id should be unique within a graph.
    constants: Updated parameters for the instantiated block.

  Returns:
    An instantiated block.

  Raises:
    ValueError: if the number of input names provided does not equal the
      number of inputs consumed by the graph.
  """
  if len(input_names) != len(self.base_graph.input_names):
    raise ValueError("Wrong number of inputs provided.")

  prefix = ""
  if self.name:
    prefix += self.name
  if instance_id is not None:
    prefix += str(instance_id)
  if prefix:
    prefix += "/"

  if not constants:
    constants = dict(self.base_constants)

  new_input_names = input_names
  updated_names = {
      o: n for o, n in zip(self.base_graph.input_names, new_input_names)
  }
  inputs_names = [
      canonicalize_tensor_name(n) for n in self.base_graph.input_names
  ]
  updated_names.update(
      {o: n for o, n in zip(inputs_names, new_input_names)})

  # Update ops.
  new_ops = []
  for op in self.base_graph.ops:
    # Update all input tensor names.
    # Any internal input (i.e., anything that is not a graph input) needs to
    # be updated with the prefix.
    new_inputs = []
    for inp in op.input_names:
      try:
        idx = inputs_names.index(inp)
        new_inputs.append(new_input_names[idx])
      except ValueError:
        new_inputs.append(f"{prefix}{inp}")

    # Update symbolic constant names in input_kwargs and op_kwargs.
    new_kwargs = []
    for kwargs in [op.input_kwargs, op.op_kwargs]:
      nk = {
          k: _prefix_symbolic(v, prefix, constants, updated_names)
          for k, v in kwargs.items()
      }
      new_kwargs.append(nk)

    new_ops.append(
        new_op(
            op_name=f"{prefix}{op.name}",
            op_type=op.type,
            input_names=new_inputs,
            input_kwargs=new_kwargs[0],
            op_kwargs=new_kwargs[1],
            num_outputs=op.num_outputs))

  # Update constants and prefix symbolic constant names.
  old_constants = dict(self.base_constants)
  if constants:
    old_constants.update(constants)
  new_constants = {f"{prefix}{k}": v for k, v in old_constants.items()}

  # Prefix graph output names.
  new_output_names = [f"{prefix}{on}" for on in self.base_graph.output_names]

  graph = new_graph(
      ops=new_ops,
      input_names=new_input_names,
      output_names=new_output_names)
  return Block(
      name=self.name,
      graph=graph,
      constants=new_constants,
      base_graph=self.base_graph,
      base_constants=old_constants)
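# Hedged usage sketch (not part of the original module): illustrates the name
# prefixing described in the docstring, assuming the conv_block() helper
# defined earlier is in scope. Each instantiation prefixes op and output names
# with "<block name><instance_id>/".
def _instantiate_example():
  block0 = conv_block().instantiate(input_names=["input"], instance_id=0)
  # Feed the first block's (prefixed) output into the second instance.
  block1 = conv_block().instantiate(
      input_names=[block0.graph.output_names[0]], instance_id=1)
  assert block0.graph.ops[0].name == "conv_layer0/conv"
  assert block1.graph.output_names[0] == "conv_layer1/avg_pool"
  return block0, block1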
def vit(
    blocks,
    patch_size,
    image_size,
    width,
    dropout,
    num_classes,
    spatial,
):
  """Graph for ViT."""
  assert image_size % patch_size == 0

  ops = []
  append = functools.partial(append_op, ops)
  use_dropout = dropout > 1e-3

  append(
      op_name="embedding",
      op_type=OpType.CONV,
      op_kwargs={
          "features": width,
          "kernel_size": [patch_size, patch_size],
          "strides": [patch_size, patch_size],
          "padding": "VALID",
      },
      input_names=["input"])

  if spatial:
    # Could also be [1, sequence_size, sequence_size, width].
    pos_embedding_shape = [
        1, f"S:{ops[-1].name}:1", f"S:{ops[-1].name}:2", f"S:{ops[-1].name}:3"
    ]
  else:
    append(
        op_name="reshape",
        op_type=OpType.RESHAPE,
        input_kwargs={"new_shape": ["B", -1, "S:-1"]})
    # Could also be [1, sequence_size**2, width].
    pos_embedding_shape = [1, f"S:{ops[-1].name}:1", f"S:{ops[-1].name}:2"]

  append(
      op_name="transformer/pos_embedding",
      op_type=OpType.PARAM,
      input_kwargs={
          "shape": pos_embedding_shape,
          "init_fn": f"I:normal:stddev:{1/math.sqrt(width):.03f}"
      },
      input_names=[])
  append(
      op_name="transformer/pos_embedding/add",
      op_type=OpType.ADD,
      input_names=[ops[-2].name, ops[-1].name])
  if use_dropout:
    append(
        op_name="transformer/dropout",
        op_type=OpType.DROPOUT,
        op_kwargs={"rate": dropout})

  graph, constants, blocks = encoder(
      input_name=ops[-1].name, blocks=blocks, dropout=dropout)
  ops.extend(graph.ops)

  if spatial:
    append(
        op_name="reshape",
        op_type=OpType.RESHAPE,
        input_kwargs={"new_shape": ["B", -1, "S:-1"]})
  append(op_name="gap", op_type=OpType.MEAN, input_kwargs={"axis": 1})
  append(
      op_name="head",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": num_classes,
          "kernel_init": "I:zeros"
      })

  graph = new_graph(input_names=["input"], output_names=["head"], ops=ops)
  return graph, constants, blocks
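# Hedged usage sketch (not part of the original module): assembles a small ViT
# graph from the attention blocks defined above; all sizes are illustrative
# only. Note that the "num_heads"/"head_dim" constants produced by
# mhdpa_block() default to None and would still need concrete values before a
# Model could be initialized from this graph.
def _vit_example():
  blocks = [mhdpa_block(spatial=False) for _ in range(2)]
  graph, constants, _ = vit(
      blocks=blocks, patch_size=4, image_size=32, width=64, dropout=0.0,
      num_classes=10, spatial=False)
  return graph, constants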
def test_multi_input(self):
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu0", op_type=OpType.RELU, input_names=["dense0"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu1", op_type=OpType.RELU, input_names=["dense1"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu2", op_type=OpType.RELU, input_names=["dense2"]),
      new_op(
          op_name="add0", op_type=OpType.ADD, input_names=["relu0", "relu1"]),
      new_op(
          op_name="add1", op_type=OpType.ADD, input_names=["relu1", "relu2"]),
  ]
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
      SubgraphNode(
          op=new_op(
              op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
      SubgraphNode(
          op=new_op(
              op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
      SubgraphNode(
          op=new_op(
              op_name="add0", op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]),
      SubgraphNode(
          op=new_op(
              op_name="add1", op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]),
  ]
  replaced_graph = replace_subgraph(graph, subgraph_spec)
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, {}, subgraph_spec)
  dp = depth.DepthProperty().infer(subgraph_model)
  depth_map = dp.depth_map

  self.assertLen(depth_map, 3)
  self.assertIn("dense0:0", depth_map)
  self.assertIn("dense1:0", depth_map)
  self.assertIn("dense2:0", depth_map)
  self.assertLen(depth_map["dense0:0"], 1)
  self.assertEqual(depth_map["dense0:0"]["add0:0"], 2)
  self.assertLen(depth_map["dense1:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add0:0"], 2)
  self.assertEqual(depth_map["dense1:0"]["add1:0"], 2)
  self.assertLen(depth_map["dense2:0"], 1)
  self.assertEqual(depth_map["dense2:0"]["add1:0"], 2)
def test_multi_input(self):
  ops = [
      new_op(
          op_name="dense0",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu0", op_type=OpType.RELU, input_names=["dense0"]),
      new_op(
          op_name="dense1",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu1", op_type=OpType.RELU, input_names=["dense1"]),
      new_op(
          op_name="dense2",
          op_type=OpType.DENSE,
          op_kwargs={"features": 32},
          input_names=["input"]),
      new_op(op_name="relu2", op_type=OpType.RELU, input_names=["dense2"]),
      new_op(
          op_name="add0", op_type=OpType.ADD, input_names=["relu0", "relu1"]),
      new_op(
          op_name="add1", op_type=OpType.ADD, input_names=["relu1", "relu2"]),
  ]
  graph = new_graph(
      input_names=["input"], output_names=["add0", "add1"], ops=ops)

  subgraph_spec = [
      SubgraphNode(
          op=new_op(
              op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
      SubgraphNode(
          op=new_op(
              op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
      SubgraphNode(
          op=new_op(
              op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
      SubgraphNode(
          op=new_op(
              op_name="add0", op_type=OpType.ADD,
              input_names=["relu0", "relu1"]),
          output_names=["add0"]),
      SubgraphNode(
          op=new_op(
              op_name="add1", op_type=OpType.ADD,
              input_names=["relu1", "relu2"]),
          output_names=["add1"]),
  ]
  replaced_graph = replace_subgraph(graph, subgraph_spec)
  inp = {"input": jnp.ones((10, 32, 32, 3))}
  subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp, subgraph_spec)
  lp = linear.LinopProperty().infer(subgraph_model)
  pairings = lp.pairings

  self.assertLen(pairings, 2)
  self.assertIn("add0:0", pairings)
  self.assertLen(pairings["add0:0"], 2)
  self.assertIn("dense0:0", pairings["add0:0"])
  self.assertIn("dense1:0", pairings["add0:0"])
  self.assertIn("add1:0", pairings)
  self.assertLen(pairings["add1:0"], 2)
  self.assertIn("dense1:0", pairings["add1:0"])
  self.assertIn("dense2:0", pairings["add1:0"])
def efficietnet(num_classes, config, blocks):
  """Returns a graph for EfficientNet."""
  drop_connect_rate = .2

  ops = []
  constants = {}
  new_blocks = []
  append = functools.partial(append_op, ops)

  # Stem.
  stem_filters = round_filters(32, config)
  append(
      op_name="stem/conv0",
      op_type=OpType.CONV,
      op_kwargs={
          "features": stem_filters,
          "kernel_size": 3,
          "strides": 2,
      },
      input_names=["input"])
  append(op_name="stem/bn1", op_type=OpType.BATCH_NORM)
  append(op_name="stem/swish2", op_type=OpType.SWISH)

  # MBConv blocks.
  input_name = ops[-1].name
  block_num = 0
  num_blocks_total = len(blocks)
  for block in blocks:
    drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
    _, stride, _, output_filters = _extract_block_info(block.name)
    graph, new_constants, new_block = mbconv_layer(
        block=block,
        input_name=input_name,
        block_id=block_num,
        output_filters=output_filters,
        stride=stride,
        layer_drop_rate=drop_rate)
    input_name = graph.output_names[0]
    constants.update(new_constants)
    new_blocks.append(new_block)
    ops.extend(graph.ops)
    block_num += 1

  # Head.
  top_filters = round_filters(1280, config)
  append(
      op_name="head/conv0",
      op_type=OpType.CONV,
      op_kwargs={
          "features": top_filters,
          "kernel_size": 1,
          "strides": 1,
      })
  append(op_name="head/bn1", op_type=OpType.BATCH_NORM)
  append(op_name="head/swish2", op_type=OpType.SWISH)
  append(
      op_name="head/pool3",
      op_type=OpType.AVG_POOL,
      input_kwargs={"window_shape": 0})
  if config.dropout_rate and config.dropout_rate > 0:
    append(
        op_name="head/dropout4",
        op_type=OpType.DROPOUT,
        op_kwargs={"rate": config.dropout_rate})
  append(
      op_name="head/dense5",
      op_type=OpType.DENSE,
      op_kwargs={"features": num_classes})
  append(op_name="head/out", op_type=OpType.FLATTEN)

  graph = new_graph(input_names=["input"], output_names=["head/out"], ops=ops)
  return graph, constants, new_blocks