Example 1
    def setUp(self):
        super().setUp()

        self.graph_dense = new_graph(
            ["input"],
            ["output"],
            [
                new_op(
                    op_name="output",
                    op_type=OpType.SOFTMAX,
                    # op_kwargs={"features": 10},
                    input_names=["input"])
            ])
        state_dense = Model(self.graph_dense).init(
            random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
        self.subgraph_dense = SubgraphModel(self.graph_dense, None,
                                            state_dense,
                                            {"input": jnp.ones((5, 5, 5))})
        self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

        self.graph_conv = new_graph(["input"], ["output"], [
            new_op(op_name="output",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 10,
                       "kernel_size": [3, 3]
                   },
                   input_names=["input"])
        ])
        state_conv = Model(self.graph_conv).init(
            random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
        self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                           {"input": jnp.ones((5, 5, 5))})
        self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
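
As a usage note, the properties cached by this setUp can be checked by later tests. The following follow-up test is a hypothetical sketch, not part of the original class; it only assumes that the inferred property exposes the pairings attribute used elsewhere in these examples.

    def test_pairings_present(self):
        # Hypothetical follow-up: each inferred linear-operator property
        # should carry at least one output-to-input pairing.
        for lp in (self.lp_dense, self.lp_conv):
            self.assertNotEmpty(lp.pairings)
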
Example 2
def mlp_block(dropout, mlp_factor):
  """MLP block in the encoder block."""
  ops = []
  append = functools.partial(append_op, ops)
  use_dropout = dropout > 1e-3

  input_name = "input"
  append(op_name="dense0",
         op_type=OpType.DENSE,
         op_kwargs={
             "features": f"S:-1*{mlp_factor}",
             "kernel_init": "I:xavier_uniform",
             "bias_init": "I:normal:stddev:1e-6"
         },
         input_names=[input_name])

  append(op_name="gelu1",
         op_type=OpType.GELU)
  if use_dropout:
    append(op_name="dropout",
           op_type=OpType.DROPOUT,
           op_kwargs={"rate": dropout})
  append(op_name="dense2",
         op_type=OpType.DENSE,
         op_kwargs={
             "features": f"S:-1%{mlp_factor}",
             "kernel_init": "I:xavier_uniform",
             "bias_init": "I:normal:stddev:1e-6"
         })
  output_name = ops[-1].name
  graph = new_graph(input_names=[input_name], output_names=[output_name],
                    ops=ops)
  return Block(name=f"mlp_block{'_dropout' if use_dropout else ''}",
               graph=graph, constants={})
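
The append_op helper is defined elsewhere in the library. Judging from its uses above (ops such as "gelu1" omit input_names and implicitly consume the previous op's output), it plausibly behaves like the sketch below. This is an assumption for readability, not the library's actual implementation.

def append_op(ops, op_name, op_type, op_kwargs=None, input_kwargs=None,
              input_names=None, num_outputs=1):
  """Appends a new op to `ops`, wiring it to the previous op by default."""
  if input_names is None:
    # Default: consume the output of the most recently appended op.
    input_names = [ops[-1].name]
  ops.append(
      new_op(op_name=op_name,
             op_type=op_type,
             op_kwargs=op_kwargs or {},
             input_kwargs=input_kwargs or {},
             input_names=input_names,
             num_outputs=num_outputs))
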
Example 3
def encoder(
    input_name,
    blocks,
    dropout,
):
  """Encoder for ViT."""
  ops = []
  constants = {}
  new_blocks = []
  append = functools.partial(append_op, ops)

  encoder_input_name = input_name
  for block_id, block in enumerate(blocks):
    graph, new_constants, new_block = encoder_block(
        input_name=encoder_input_name,
        block_id=block_id,
        block=block,
        dropout=dropout)
    encoder_input_name = graph.output_names[0]
    constants.update(new_constants)
    new_blocks.append(new_block)
    ops.extend(graph.ops)
  append(op_name="transformer/encoder_norm",
         op_type=OpType.LAYER_NORM)
  output_name = ops[-1].name
  graph = new_graph(
      input_names=[input_name], output_names=[output_name], ops=ops)
  return graph, constants, new_blocks
Example 4
  def _update_pairings(self, op, in_shapes):
    """Updates pairings with the property for a single op.

    Args:
      op: The op for which to infer the pairing property.
      in_shapes: The shapes of the input tensors.
    """
    assert len(op.input_names) == len(in_shapes)
    input_values = {
        input_name: jnp.ones(in_shape)
        for input_name, in_shape in zip(op.input_names, in_shapes)
    }
    output_names = [f"{op.name}:{i}" for i in range(op.num_outputs)]
    graph = new_graph(op.input_names, output_names, [op])
    model = Model(graph)
    state = model.init(jax.random.PRNGKey(0), input_values)
    pairings = GraphPairings.infer(
        model, input_values, state, abstract=False).pairings

    new_pairings = {}
    for output_idx, output_name in enumerate(output_names):
      new_pairings[output_idx] = {}
      for input_idx, input_name in enumerate(op.input_names):
        new_pairings[output_idx][input_idx] = pairings[output_name][input_name]

    key = self.hash(op, in_shapes)
    self.pairings[key] = new_pairings
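
The method above only populates the cache; a lookup counterpart would presumably key on the same self.hash(op, in_shapes). A minimal hypothetical sketch (the method name and the compute-on-miss behavior are assumptions, not the library's API):

  def _get_pairings(self, op, in_shapes):
    """Hypothetical lookup: return cached pairings, computing them on a miss."""
    key = self.hash(op, in_shapes)
    if key not in self.pairings:
      self._update_pairings(op, in_shapes)
    return self.pairings[key]
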
Example 5
  def test_unsatisfy(self):
    # This test removes the last dense layer, so the new graph should be
    # shallower.
    graph = new_graph(input_names=["input"],
                      output_names=["fc/relu"],
                      ops=self.graph.ops)
    subgraph_model = SubgraphModel(graph, self.constants, {}, {})
    self.assertFalse(self.dp.verify(subgraph_model))
Example 6
    def test_equal(self):
        """Tests whether the fingerprint is the same for equivalent graphs.

        The ops have different names and a different topological order.
        """
        ops1 = [
            new_op(op_name="dense",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="conv",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense", "conv"]),
        ]
        graph1 = new_graph(["input"], ["output"], ops1)

        ops2 = [
            new_op(op_name="conv2",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense2", "conv2"]),
        ]
        graph2 = new_graph(["input"], ["output"], ops2)

        input_dict = {"input": jnp.ones((5, 5, 5))}
        fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
        fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
        self.assertEqual(fingerprint1, fingerprint2)
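
Fingerprints like these are natural keys for deduplicating functionally equivalent candidate graphs. Below is a small hypothetical helper built only on fingerprint.fingerprint_graph as called above, assuming the returned fingerprint is hashable; it is not part of the original code.

def dedup_graphs(graphs, constants, input_dict):
    """Hypothetical helper: keep one graph per functional fingerprint."""
    seen = set()
    unique = []
    for graph in graphs:
        fp = fingerprint.fingerprint_graph(graph, constants, input_dict)
        if fp not in seen:
            seen.add(fp)
            unique.append(graph)
    return unique
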
Example 7
def mb_conv_block(expand_ratio, stride, kernel_size, output_filters):
    """Returns an Efficientnet MBConvBlock."""
    input_name = "input"
    ops = []
    append = functools.partial(append_op, ops)

    # Expand.
    if expand_ratio > 1:
        append(op_name="mb_conv/expand/conv0",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": f"S:-1*{expand_ratio}",
                   "kernel_size": 1,
                   "strides": 1,
               },
               input_names=[input_name])
        append(op_name="mb_conv/expand/bn1", op_type=OpType.BATCH_NORM)
        append(op_name="mb_conv/expand/swish2", op_type=OpType.SWISH)
        input_name = ops[-1].name

    # Depthwise conv.
    append(op_name="mb_conv/dw/conv0",
           op_type=OpType.CONV,
           op_kwargs={
               "features": "S:-1",
               "feature_group_count": "S:-1",
               "kernel_size": kernel_size,
               "strides": stride,
           },
           input_names=[input_name])
    append(op_name="mb_conv/dw/bn1", op_type=OpType.BATCH_NORM)
    append(op_name="mb_conv/dw/swish2", op_type=OpType.SWISH)

    # Squeeze and excitation.
    input_name = ops[-1].name
    se_ops = squeeze_excite(input_name, expand_ratio * 4)
    ops.extend(se_ops)

    # Output.
    append(op_name="mb_conv/output/conv0",
           op_type=OpType.CONV,
           op_kwargs={
               "features":
               "S:input:-1" if not output_filters else output_filters,
               "kernel_size": 1,
               "strides": 1,
           })
    append(op_name="mb_conv/output/bn1", op_type=OpType.BATCH_NORM)
    output_name = ops[-1].name
    graph = new_graph(input_names=["input"],
                      output_names=[output_name],
                      ops=ops)
    return Block(
        name=f"mbconv_expand{expand_ratio}_stride{stride}_kernel{kernel_size}_"
        f"outputfilters{output_filters}_",
        graph=graph)
Example 8
  def __init__(self,
               graph,
               constants,
               state,
               inputs,
               subgraph=None):
    self.graph = graph
    self.constants = constants
    self.state = state
    self.inputs = inputs
    self.subgraph: SubgraphSpec = subgraph if subgraph else []

    self.input_names = None
    self.output_names = None
    self.original_outputs = graph.output_names

    if subgraph:
      self._subgraph_to_names()

      # graph for graph inputs -> subg inputs
      self.subg_inputs_graph = copy.deepcopy(graph)
      self.subg_inputs_graph.output_names = self.input_names
      self.subg_inputs_model = Model(self.subg_inputs_graph, self.constants)
      self.subg_inputs = None

      # graph for graph inputs -> subg outputs
      self.subg_outputs_graph = copy.deepcopy(graph)
      self.subg_outputs_graph.output_names = self.output_names
      self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
      self.subg_outputs = None

      # graph for subg inputs -> subg outputs
      subg_ops = [node.op for node in subgraph]
      self.subg_graph = new_graph(self.input_names, self.output_names, subg_ops)
      self.subg_model = Model(self.subg_graph, self.constants)
    else:
      self.input_names = [
          canonicalize_tensor_name(name) for name in graph.input_names
      ]
      self.output_names = [
          canonicalize_tensor_name(name) for name in graph.output_names
      ]

      # subg inputs = inputs to the graph
      self.subg_inputs_graph = None
      self.subg_inputs_model = None
      self.subg_inputs = inputs

      # graph for graph inputs -> subg outputs
      self.subg_outputs_graph = copy.deepcopy(graph)
      self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
      self.subg_outputs = None

      # subg outputs = full graph outputs
      self.subg_graph = self.subg_outputs_graph
      self.subg_model = self.subg_outputs_model
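
For reference, the two construction modes this constructor supports, condensed from the tests in this section (the variable names refer to objects built in those tests; this is a usage recap, not new API):

# Full-graph mode (no subgraph spec): the subgraph is the whole graph and the
# provided inputs are its inputs, as in the LinopProperty setUp above.
whole = SubgraphModel(graph, constants, state, {"input": jnp.ones((5, 5, 5))})

# Subgraph mode: a subgraph spec selects the nodes whose boundary defines the
# subgraph; state and inputs may be left empty for purely structural
# properties, as in the depth-property test later in this section.
part = SubgraphModel(replaced_graph, {}, {}, {}, subgraph_spec)
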
Example 9
    def test_abstract(self):
        graphs = []

        conv_op = functools.partial(new_op,
                                    op_type=OpType.CONV,
                                    op_kwargs={
                                        "features": 10,
                                        "kernel_size": [3, 3]
                                    })
        dense_op = functools.partial(new_op,
                                     op_type=OpType.DENSE,
                                     op_kwargs={
                                         "features": 10,
                                     })

        for op_type in [
                OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM,
                OpType.BATCH_NORM
        ]:
            for op_ctr in [conv_op, dense_op]:
                graphs.append([
                    op_ctr(input_names=["input"], op_name="other"),
                    new_op(op_name="output",
                           op_type=op_type,
                           input_names=["other"])
                ])
                graphs.append([
                    new_op(op_name="other",
                           op_type=op_type,
                           input_names=["input"]),
                    op_ctr(input_names=["other"], op_name="output"),
                ])

        input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
        for graph in graphs:
            graph = new_graph(["input"], ["output"], graph)
            state = Model(graph).init(random.PRNGKey(1), input_tensor)

            # Make all the kernels positive, otherwise, the ReLU might zero out the
            # entire tensor.
            state = jax.tree_util.tree_map(abs, state)

            subg_model = SubgraphModel(graph, None, state, input_tensor)
            lp_abstract = linear.LinopProperty().infer(subg_model,
                                                       abstract=True)
            lp_concrete = linear.LinopProperty().infer(subg_model,
                                                       abstract=False)
            pairings_concrete = lp_concrete.pairings["output"][
                "input"].mappings
            pairings_abstract = lp_abstract.pairings["output"][
                "input"].mappings

            print("concrete:", pairings_concrete)
            print("abstract:", pairings_abstract)
            self.assertTrue(
                ((pairings_abstract - pairings_concrete) == 0).all())
Example 10
  def test_satisfy(self):
    # This test removes the last dense layer, so the old graph should be
    # deeper.
    ops = self.graph.ops[:-1]
    ops[-1].name = "fc/logits"
    graph = new_graph(input_names=["input"],
                      output_names=["fc/logits"],
                      ops=ops)
    subgraph_model = SubgraphModel(graph, self.constants, {}, {})
    dp = depth.DepthProperty().infer(subgraph_model)
    self.assertTrue(dp.verify(self.subgraph_model))
Example 11
  def test_unsatisfy(self):
    # This test removes the last dense layer, so the new graph should have a
    # different shape (and therefore not satisfy the inferred shape property).
    graph = new_graph(input_names=["input"],
                      output_names=["fc/relu"],
                      ops=self.graph.ops)
    state = Model(graph,
                  self.constants).init(random.PRNGKey(0),
                                       {"input": jnp.ones((5, 32, 32, 3))})
    subgraph_model = SubgraphModel(graph, self.constants, state,
                                   {"input": jnp.ones((5, 32, 32, 3))})
    self.assertFalse(self.sp.verify(subgraph_model))
Example 12
    def test_not_equal(self):
        """Tests whether the fingerprint is different for non-equivalent graphs."""
        ops1 = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense0", "dense1"]),
        ]
        graph1 = new_graph(["input"], ["output"], ops1)

        ops2 = [
            new_op(op_name="conv2",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense2", "conv2"]),
        ]
        graph2 = new_graph(["input"], ["output"], ops2)

        input_dict = {"input": jnp.ones((5, 5, 5))}
        fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
        fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
        self.assertNotEqual(fingerprint1, fingerprint2)
Example 13
def conv_net(in_features, out_features, num_classes, blocks=None):
    """Graph for 3-layer CNN."""
    if not blocks:
        blocks = [block_type() for block_type in BLOCK_TYPES]

    input_name = "input"
    new_blocks = []
    ops = [
        new_op(op_name="proj",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": in_features,
                   "kernel_size": 1,
               },
               input_names=[input_name])
    ]
    constants = {}

    block_input_name = ops[-1].name
    for idx, block in enumerate(blocks):
        block = block.instantiate(input_names=[block_input_name],
                                  instance_id=idx)
        new_blocks.append(block)
        constants.update(block.constants)
        ops.extend(block.graph.ops)
        block_input_name = ops[-1].name

    constants.update({
        "out_features": out_features,
        "num_classes": num_classes
    })
    ops.extend([
        new_op(op_name="flatten",
               op_type=OpType.FLATTEN,
               input_names=[ops[-1].name]),
        new_op(op_name="fc/dense",
               op_type=OpType.DENSE,
               op_kwargs={"features": "K:out_features"},
               input_names=["flatten"]),
        new_op(op_name="fc/relu",
               op_type=OpType.RELU,
               input_names=["fc/dense"]),
        new_op(op_name="fc/logits",
               op_type=OpType.DENSE,
               op_kwargs={"features": "K:num_classes"},
               input_names=["fc/relu"])
    ])
    graph = new_graph(input_names=[input_name],
                      output_names=["fc/logits"],
                      ops=ops)
    return graph, constants, new_blocks
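
A minimal sketch of consuming the returned graph, following the Model init pattern used in the tests above; the argument values, input shape, and RNG seed here are arbitrary assumptions:

graph, constants, blocks = conv_net(in_features=16, out_features=64,
                                    num_classes=10)
model = Model(graph, constants)
state = model.init(random.PRNGKey(0), {"input": jnp.ones((1, 32, 32, 3))})
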
Example 14
  def test_identical(self):
    """Tests whether the fingerprint is the same for identical graphs."""
    ops = [
        new_op(op_name="dense0",
               op_type=OpType.DENSE,
               op_kwargs={"features": 32},
               input_names=["input"]),
        new_op(op_name="dense1",
               op_type=OpType.DENSE,
               op_kwargs={"features": 32},
               input_names=["input"]),
        new_op(op_name="output",
               op_type=OpType.ADD,
               input_names=["dense0", "dense1"]),
    ]
    graph = new_graph(["input"], ["output"], ops)
    input_dict = {"input": jnp.ones((5, 5, 5))}
    fingerprint1 = fingerprint.fingerprint_graph(graph, {}, input_dict)
    fingerprint2 = fingerprint.fingerprint_graph(graph, {}, input_dict)
    self.assertEqual(fingerprint1, fingerprint2)
Example 15
def mbconv_layer(block, input_name, block_id, output_filters, stride,
                 layer_drop_rate):
    """Returns a MBConv layer."""

    prefix = f"mbconv{block_id}"

    block = block.instantiate(input_names=[input_name], instance_id=block_id)
    constants = block.constants
    ops = list(block.graph.ops)

    if not output_filters and stride == 1:
        append_op(ops,
                  op_name=f"{prefix}/skip",
                  op_type=OpType.ADD,
                  input_kwargs={"layer_drop_rate": layer_drop_rate},
                  input_names=[ops[-1].name, input_name])

    output_name = ops[-1].name
    graph = new_graph(input_names=[input_name],
                      output_names=[output_name],
                      ops=ops)
    return graph, constants, block
Example 16
def encoder_block(
    input_name,
    block_id,
    block,
    dropout,
):
  """Returns an encoder block."""

  prefix = f"encoder{block_id}"

  ops = []
  append = functools.partial(append_op, ops)
  use_dropout = dropout > 1e-3

  res_input = input_name

  append(op_name=f"{prefix}/layernorm0",
         op_type=OpType.LAYER_NORM,
         input_names=[input_name])

  block = block.instantiate(
      input_names=[ops[-1].name],
      instance_id=block_id)
  ops.extend(block.graph.ops)
  constants = block.constants

  if use_dropout:
    append(op_name=f"{prefix}/dropout1",
           op_type=OpType.DROPOUT,
           op_kwargs={"rate": dropout})
  append(op_name=f"{prefix}/residual1",
         op_type=OpType.ADD,
         input_names=[res_input, ops[-1].name])

  output_name = ops[-1].name
  graph = new_graph(input_names=[input_name], output_names=[output_name],
                    ops=ops)
  return graph, constants, block
Example 17
def conv_block():
    """Makes a conv block parameterized by the number of features."""
    ops = [
        new_op(op_name="conv",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": "S:-1*2",
                   "kernel_size": 3
               },
               input_names=["input"]),
        new_op(op_name="relu", op_type=OpType.RELU, input_names=["conv"]),
        new_op(op_name="avg_pool",
               op_type=OpType.AVG_POOL,
               input_names=["relu"],
               input_kwargs={
                   "window_shape": 2,
                   "strides": 2
               }),
    ]

    graph = new_graph(input_names=["input"],
                      output_names=["avg_pool"],
                      ops=ops)
    return Block(name="conv_layer", graph=graph)
Example 18
def mhdpa_block(spatial):
  """Multi headed dot product (self) attention block in encoder block."""
  input_name = "input"
  constants = {"num_heads": None, "head_dim": None}
  width = "S:input:-1"
  num_heads = "K:num_heads"
  head_dim = "K:head_dim"

  ops = []
  append = functools.partial(append_op, ops)

  append(
      op_name="value/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])
  append(
      op_name="key/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])
  append(
      op_name="query/pre",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros"
      },
      input_names=[input_name])

  # spatial:  [b, h, w, width] -> [b, h*w, num_heads, head_dim]
  # original: [b, h*w, width]  -> [b, h*w, num_heads, head_dim]
  new_shape = ["B", -1, num_heads, head_dim]
  append(
      op_name="query",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["query/pre"])
  append(
      op_name="key",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["key/pre"])
  append(
      op_name="value",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape},
      input_names=["value/pre"])

  append(op_name="query/scale",
         op_type=OpType.SCALAR_MUL,
         input_names=["query"])

  # attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  append(
      op_name="attn_weight",
      op_type=OpType.EINSUM,
      input_kwargs={"sum": "...qhd,...khd->...hqk"},
      input_names=["query/scale", "key"])

  append(
      op_name="attn_weight/softmax",
      op_type=OpType.SOFTMAX,
      input_kwargs={"axis": -1})

  # attn_values = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)
  append(
      op_name="attn_value",
      op_type=OpType.EINSUM,
      input_kwargs={"sum": "...hqk,...khd->...qhd"},
      input_names=["attn_weight/softmax", "value"])

  # back to the original inputs dimensions
  if spatial:
    # [b, h*w, num_heads, head_dim] -> [b, h, w, width]
    new_shape = ["B", "S:input:1", "S:input:2", width]
  else:
    # [b, h*w, num_heads, head_dim] -> [b, h*w, width]
    new_shape = ["B", -1, width]

  append(
      op_name="attn_value/reshape",
      op_type=OpType.RESHAPE,
      input_kwargs={"new_shape": new_shape})
  append(
      op_name="out",
      op_type=OpType.DENSE,
      op_kwargs={
          "features": "S:-1",
          "kernel_init": "I:xavier_uniform",
          "bias_init": "I:zeros",
      })

  output_name = ops[-1].name
  graph = new_graph(input_names=[input_name], output_names=[output_name],
                    ops=ops)
  return Block(name=f"mhdpa_block{'_spatial' if spatial else ''}",
               graph=graph, constants=constants)
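
Note that the block declares num_heads and head_dim as constants with no default, so values presumably have to be supplied when the block is instantiated. A hypothetical sketch using the instantiate signature shown later in this section; the particular values are assumptions:

attn = mhdpa_block(spatial=False).instantiate(
    input_names=["input"],
    instance_id=0,
    constants={"num_heads": 4, "head_dim": 64})
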
Example 19
    def instantiate(self, input_names, instance_id=None, constants=None):
        """Instantiates a version of the block with unique names.

    This method uses the names of graph and constants from the initial
    definition of the block (__init__) , so that one can instantiate from any
    derived block with same effect, e.g., if we have:
      init_block = block.__init__(name="conv_layer", ...)
      block0 = init_block.instantiate(instance_id=0, ...)
    then:
      block1 = init_block.instantiate(instance_id=1, ...)
    will have the same effect as:
      block1 = block0.instantiate(instance_id=1, ...)
    The one caveat is that the default values for unspecified constants are
    inherited from the instantiating block (instead of the initial definition).

    Args:
      input_names: The input tensor names the instantiated block will consume.
      instance_id: An id to make the names in the instantiated block unique.
        The id should be unique within a graph.
      constants: Updated parameters for the instantiated block.

    Returns:
      An instantiated block.

    Raises:
      ValueError: if the number of input names provided does not equal the
        number of inputs consumed by the graph.
    """
        if len(input_names) != len(self.base_graph.input_names):
            raise ValueError("Wrong number of inputs provided.")

        prefix = ""
        if self.name: prefix += self.name
        if instance_id is not None: prefix += str(instance_id)
        if prefix: prefix += "/"

        if not constants: constants = dict(self.base_constants)

        new_input_names = input_names
        updated_names = {
            o: n
            for o, n in zip(self.base_graph.input_names, new_input_names)
        }
        inputs_names = [
            canonicalize_tensor_name(n) for n in self.base_graph.input_names
        ]
        updated_names.update(
            {o: n
             for o, n in zip(inputs_names, new_input_names)})

        # Update ops.
        new_ops = []
        for op in self.base_graph.ops:
            # Update all input tensor names.
            # Any internal inputs (i.e., anything that is not a graph input)
            # need to be updated with the prefix.
            new_inputs = []
            for inp in op.input_names:
                try:
                    idx = inputs_names.index(inp)
                    new_inputs.append(new_input_names[idx])
                except ValueError:
                    new_inputs.append(f"{prefix}{inp}")

            # Update symbolic constant names in input_kwargs and op_kwargs.
            new_kwargs = []
            for kwargs in [op.input_kwargs, op.op_kwargs]:
                nk = {
                    k: _prefix_symbolic(v, prefix, constants, updated_names)
                    for k, v in kwargs.items()
                }
                new_kwargs.append(nk)

            new_ops.append(
                new_op(op_name=f"{prefix}{op.name}",
                       op_type=op.type,
                       input_names=new_inputs,
                       input_kwargs=new_kwargs[0],
                       op_kwargs=new_kwargs[1],
                       num_outputs=op.num_outputs))

        # Update constants and prefix symbolic constant names.
        old_constants = dict(self.base_constants)
        if constants: old_constants.update(constants)
        new_constants = {f"{prefix}{k}": v for k, v in old_constants.items()}

        # Prefix graph output names.
        new_output_names = [
            f"{prefix}{on}" for on in self.base_graph.output_names
        ]

        graph = new_graph(ops=new_ops,
                          input_names=new_input_names,
                          output_names=new_output_names)
        return Block(name=self.name,
                     graph=graph,
                     constants=new_constants,
                     base_graph=self.base_graph,
                     base_constants=old_constants)
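
As a quick usage example of the renaming behavior (the block comes from the conv_block example earlier; the expected names follow directly from the prefix logic above):

block = conv_block().instantiate(input_names=["input"], instance_id=0)
# name "conv_layer" + instance_id 0 gives the prefix "conv_layer0/", so the
# ops become "conv_layer0/conv", "conv_layer0/relu", "conv_layer0/avg_pool"
# and the graph output is "conv_layer0/avg_pool".
print([op.name for op in block.graph.ops])
print(block.graph.output_names)
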
Example 20
def vit(
    blocks,
    patch_size,
    image_size,
    width,
    dropout,
    num_classes,
    spatial,
):
  """Graph for ViT."""
  assert image_size % patch_size == 0
  ops = []
  append = functools.partial(append_op, ops)
  use_dropout = dropout > 1e-3

  append(op_name="embedding",
         op_type=OpType.CONV,
         op_kwargs={
             "features": width,
             "kernel_size": [patch_size, patch_size],
             "strides": [patch_size, patch_size],
             "padding": "VALID",
         },
         input_names=["input"])

  if spatial:
    # could also be [1, sequence_size, sequence_size, width]
    pos_embedding_shape = [
        1, f"S:{ops[-1].name}:1", f"S:{ops[-1].name}:2", f"S:{ops[-1].name}:3"
    ]
  else:
    append(
        op_name="reshape",
        op_type=OpType.RESHAPE,
        input_kwargs={"new_shape": ["B", -1, "S:-1"]})

    # could also be [1, sequence_size**2, width]
    pos_embedding_shape = [1, f"S:{ops[-1].name}:1", f"S:{ops[-1].name}:2"]

  append(op_name="transformer/pos_embedding",
         op_type=OpType.PARAM,
         input_kwargs={
             "shape": pos_embedding_shape,
             "init_fn": f"I:normal:stddev:{1/math.sqrt(width):.03f}"
         },
         input_names=[])
  append(op_name="transformer/pos_embedding/add",
         op_type=OpType.ADD,
         input_names=[ops[-2].name, ops[-1].name])
  if use_dropout:
    append(op_name="transformer/dropout",
           op_type=OpType.DROPOUT,
           op_kwargs={"rate": dropout})

  graph, constants, blocks = encoder(
      input_name=ops[-1].name,
      blocks=blocks,
      dropout=dropout)
  ops.extend(graph.ops)

  if spatial:
    append(op_name="reshape",
           op_type=OpType.RESHAPE,
           input_kwargs={"new_shape": ["B", -1, "S:-1"]})

  append(op_name="gap",
         op_type=OpType.MEAN,
         input_kwargs={"axis": 1})
  append(op_name="head",
         op_type=OpType.DENSE,
         op_kwargs={
             "features": num_classes,
             "kernel_init": "I:zeros"
         })

  graph = new_graph(input_names=["input"], output_names=["head"], ops=ops)
  return graph, constants, blocks
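
The blocks argument is consumed one per encoder block (each wrapped in a pre-layer-norm plus residual, as in encoder_block above), so a standard ViT presumably alternates attention and MLP blocks. A hypothetical assembly sketch; the depth and hyperparameter values are assumptions:

blocks = []
for _ in range(12):  # 12 attention/MLP pairs is an assumption
  blocks.append(mhdpa_block(spatial=False))
  blocks.append(mlp_block(dropout=0.1, mlp_factor=4))
graph, constants, new_blocks = vit(
    blocks=blocks,
    patch_size=16,
    image_size=224,
    width=192,
    dropout=0.1,
    num_classes=1000,
    spatial=False)
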
Example 21
    def test_multi_input(self):
        ops = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu0",
                   op_type=OpType.RELU,
                   input_names=["dense0"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu1",
                   op_type=OpType.RELU,
                   input_names=["dense1"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu2",
                   op_type=OpType.RELU,
                   input_names=["dense2"]),
            new_op(op_name="add0",
                   op_type=OpType.ADD,
                   input_names=["relu0", "relu1"]),
            new_op(op_name="add1",
                   op_type=OpType.ADD,
                   input_names=["relu1", "relu2"]),
        ]
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)
        subgraph_spec = [
            SubgraphNode(op=new_op(
                op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
            SubgraphNode(op=new_op(
                op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
            SubgraphNode(op=new_op(
                op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
            SubgraphNode(op=new_op(op_name="add0",
                                   op_type=OpType.ADD,
                                   input_names=["relu0", "relu1"]),
                         output_names=["add0"]),
            SubgraphNode(op=new_op(op_name="add1",
                                   op_type=OpType.ADD,
                                   input_names=["relu1", "relu2"]),
                         output_names=["add1"]),
        ]
        replaced_graph = replace_subgraph(graph, subgraph_spec)
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, {},
                                       subgraph_spec)
        dp = depth.DepthProperty().infer(subgraph_model)
        depth_map = dp.depth_map

        self.assertLen(depth_map, 3)
        self.assertIn("dense0:0", depth_map)
        self.assertIn("dense1:0", depth_map)
        self.assertIn("dense2:0", depth_map)
        self.assertLen(depth_map["dense0:0"], 1)
        self.assertEqual(depth_map["dense0:0"]["add0:0"], 2)
        self.assertLen(depth_map["dense1:0"], 2)
        self.assertEqual(depth_map["dense1:0"]["add0:0"], 2)
        self.assertEqual(depth_map["dense1:0"]["add1:0"], 2)
        self.assertLen(depth_map["dense2:0"], 1)
        self.assertEqual(depth_map["dense2:0"]["add1:0"], 2)
Example 22
    def test_multi_input(self):
        ops = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu0",
                   op_type=OpType.RELU,
                   input_names=["dense0"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu1",
                   op_type=OpType.RELU,
                   input_names=["dense1"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu2",
                   op_type=OpType.RELU,
                   input_names=["dense2"]),
            new_op(op_name="add0",
                   op_type=OpType.ADD,
                   input_names=["relu0", "relu1"]),
            new_op(op_name="add1",
                   op_type=OpType.ADD,
                   input_names=["relu1", "relu2"]),
        ]
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)
        subgraph_spec = [
            SubgraphNode(op=new_op(
                op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
            SubgraphNode(op=new_op(
                op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
            SubgraphNode(op=new_op(
                op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
            SubgraphNode(op=new_op(op_name="add0",
                                   op_type=OpType.ADD,
                                   input_names=["relu0", "relu1"]),
                         output_names=["add0"]),
            SubgraphNode(op=new_op(op_name="add1",
                                   op_type=OpType.ADD,
                                   input_names=["relu1", "relu2"]),
                         output_names=["add1"]),
        ]
        replaced_graph = replace_subgraph(graph, subgraph_spec)
        inp = {"input": jnp.ones((10, 32, 32, 3))}
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp,
                                       subgraph_spec)
        lp = linear.LinopProperty().infer(subgraph_model)
        pairings = lp.pairings

        self.assertLen(pairings, 2)
        self.assertIn("add0:0", pairings)
        self.assertLen(pairings["add0:0"], 2)
        self.assertIn("dense0:0", pairings["add0:0"])
        self.assertIn("dense1:0", pairings["add0:0"])
        self.assertIn("add1:0", pairings)
        self.assertLen(pairings["add1:0"], 2)
        self.assertIn("dense1:0", pairings["add1:0"])
        self.assertIn("dense2:0", pairings["add1:0"])
Example 23
def efficientnet(num_classes, config, blocks):
    """Returns a graph for ResNet V1."""

    drop_connect_rate = .2

    ops = []
    constants = {}
    new_blocks = []
    append = functools.partial(append_op, ops)

    stem_filters = round_filters(32, config)
    append(op_name="stem/conv0",
           op_type=OpType.CONV,
           op_kwargs={
               "features": stem_filters,
               "kernel_size": 3,
               "strides": 2,
           },
           input_names=["input"])
    append(op_name="stem/bn1", op_type=OpType.BATCH_NORM)
    append(op_name="stem/swish2", op_type=OpType.SWISH)

    input_name = ops[-1].name
    block_num = 0
    num_blocks_total = len(blocks)
    for block in blocks:
        drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
        _, stride, _, output_filters = _extract_block_info(block.name)

        graph, new_constants, new_block = mbconv_layer(
            block=block,
            input_name=input_name,
            block_id=block_num,
            output_filters=output_filters,
            stride=stride,
            layer_drop_rate=drop_rate)

        input_name = graph.output_names[0]
        constants.update(new_constants)
        new_blocks.append(new_block)
        ops.extend(graph.ops)

        block_num += 1

    top_filters = round_filters(1280, config)
    append(op_name="head/conv0",
           op_type=OpType.CONV,
           op_kwargs={
               "features": top_filters,
               "kernel_size": 1,
               "strides": 1,
           })
    append(op_name="head/bn1", op_type=OpType.BATCH_NORM)
    append(op_name="head/swish2", op_type=OpType.SWISH)
    append(op_name="head/pool3",
           op_type=OpType.AVG_POOL,
           input_kwargs={"window_shape": 0})
    if config.dropout_rate and config.dropout_rate > 0:
        append(op_name="head/dropout4",
               op_type=OpType.DROPOUT,
               op_kwargs={"rate": config.dropout_rate})
    append(op_name="head/dense5",
           op_type=OpType.DENSE,
           op_kwargs={"features": num_classes})
    append(op_name="head/out", op_type=OpType.FLATTEN)
    graph = new_graph(input_names=["input"],
                      output_names=["head/out"],
                      ops=ops)
    return graph, constants, new_blocks