Example #1
    def test_rewire(self):
        subgraph_spec = [
            SubgraphNode(op=new_op(op_name="conv_layer1/conv/1",
                                   op_type=OpType.CONV,
                                   op_kwargs={
                                       "features": 64,
                                       "kernel_size": [1, 1]
                                   },
                                   input_names=["conv_layer0/avg_pool"]), ),
            SubgraphNode(op=new_op(op_name="conv_layer1/gelu/1",
                                   op_type=OpType.GELU,
                                   input_names=["conv_layer1/conv/1"]),
                         output_names=["conv_layer1/relu"])
        ]

        graph = replace_subgraph(self.graph, subgraph_spec)
        state = Model(graph,
                      self.constants).init(random.PRNGKey(0),
                                           {"input": jnp.ones((5, 32, 32, 3))})
        subgraph_model = SubgraphModel(graph, self.constants, state,
                                       {"input": jnp.ones(
                                           (5, 32, 32, 3))}, subgraph_spec)
        sp = shape.ShapeProperty().infer(subgraph_model)

        self.assertLen(sp.input_shapes, 1)
        self.assertIn("conv_layer0/avg_pool:0", sp.input_shapes)
        self.assertLen(sp.output_shapes, 2)
        self.assertIn("conv_layer1/gelu/1:0", sp.output_shapes)
        self.assertIn("conv_layer1/relu:0", sp.output_shapes)
Example #2
    def setUp(self):
        super().setUp()

        self.graph_dense = new_graph(
            ["input"],
            ["output"],
            [
                new_op(
                    op_name="output",
                    op_type=OpType.SOFTMAX,
                    # op_kwargs={"features": 10},
                    input_names=["input"])
            ])
        state_dense = Model(self.graph_dense).init(
            random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
        self.subgraph_dense = SubgraphModel(self.graph_dense, None,
                                            state_dense,
                                            {"input": jnp.ones((5, 5, 5))})
        self.lp_dense = linear.LinopProperty().infer(self.subgraph_dense)

        self.graph_conv = new_graph(["input"], ["output"], [
            new_op(op_name="output",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 10,
                       "kernel_size": [3, 3]
                   },
                   input_names=["input"])
        ])
        state_conv = Model(self.graph_conv).init(
            random.PRNGKey(0), {"input": jnp.ones((5, 5, 5))})
        self.subgraph_conv = SubgraphModel(self.graph_conv, None, state_conv,
                                           {"input": jnp.ones((5, 5, 5))})
        self.lp_conv = linear.LinopProperty().infer(self.subgraph_conv)
Example #3
    def setUp(self):
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        self.model = Model(self.graph, self.constants)
        self.state = self.model.init(random.PRNGKey(0), jnp.ones(
            (1, 32, 32, 3)))

        self.subgraph = [
            subgraph.SubgraphNode(op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1]
                },
                input_names=["conv_layer0/avg_pool"]), ),
            subgraph.SubgraphNode(op=new_op(op_name="conv_layer1/gelu/1",
                                            op_type=OpType.GELU,
                                            input_names=["conv_layer1/conv/1"
                                                         ]),
                                  output_names=["conv_layer1/relu"])
        ]
        self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
        self.new_model = Model(self.new_graph, self.constants)
        self.new_state = self.new_model.init(random.PRNGKey(0),
                                             jnp.ones((1, 32, 32, 3)))
Example #4
 def test_abstract_sequential_synthesizer_output_features(self):
   graph, constants, _ = cnn.CifarNet()
   subgraph_spec = [
       SubgraphNode(
           op=new_op(
               op_name="conv_layer1/conv",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": "S:-1*2",
                   "kernel_size": [1, 1]
               },
               input_names=["conv_layer0/avg_pool"]),),
       SubgraphNode(
           op=new_op(
               op_name="conv_layer1/relu",
               op_type=OpType.RELU,
               input_names=["conv_layer1/conv"]),
           output_names=["conv_layer1/relu"])
   ]
   subgraph = replace_subgraph(graph, subgraph_spec)
   subgraph_model = SubgraphModel(subgraph, constants, None,
                                  {"input": jnp.zeros((5, 32, 32, 10))},
                                  subgraph_spec)
   sp = shape.ShapeProperty().infer(subgraph_model)
   syn = TestSequentialSynthesizer([(subgraph_model, [sp])], 0)
   self.assertEqual(syn.output_features_mul, 2)
   self.assertEqual(syn.output_features_div, 1)
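
The "S:-1*2" value is a symbolic features specification: per the assertions above, it resolves to twice the channel count of the op's input (output_features_mul == 2, output_features_div == 1). A minimal concrete equivalent under that reading; the channel count used below is illustrative, not taken from the source:

# Hypothetical: if the incoming "conv_layer0/avg_pool" tensor has C channels,
# "S:-1*2" behaves like a concrete features value of 2 * C (here C = 64).
concrete_op = new_op(op_name="conv_layer1/conv",
                     op_type=OpType.CONV,
                     op_kwargs={"features": 128, "kernel_size": [1, 1]},
                     input_names=["conv_layer0/avg_pool"])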
Example #5
 def test_abstract_sequential_synthesizer_fail(self):
   graph, constants, _ = cnn.CifarNet()
   subgraph_spec = [
       SubgraphNode(
           op=new_op(
               op_name="conv_layer1/conv/1",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": 64,
                   "kernel_size": [1, 1]
               },
               input_names=["conv_layer0/avg_pool"]),
           output_names=["conv_layer1/conv"]),
       SubgraphNode(
           op=new_op(
               op_name="conv_layer1/gelu/1",
               op_type=OpType.GELU,
               input_names=["conv_layer1/conv"]),
           output_names=["conv_layer1/relu"])
   ]
   subgraph = SubgraphModel(graph, constants, None,
                            {"input": jnp.zeros((5, 32, 32, 10))},
                            subgraph_spec)
   self.assertRaisesRegex(ValueError, ".*exactly one input.*",
                          TestSequentialSynthesizer, [(subgraph, [])], 0)
Example #6
    def test_rewire(self):
        # orig: conv, relu, pool, conv, relu, pool, flatten, dense, relu, dense
        # new:  conv, relu, pool, conv, gelu, pool, flatten, dense, relu, dense
        subgraph_spec = [
            SubgraphNode(op=new_op(op_name="conv_layer1/conv/1",
                                   op_type=OpType.CONV,
                                   op_kwargs={
                                       "features": 64,
                                       "kernel_size": [1, 1]
                                   },
                                   input_names=["conv_layer0/avg_pool"]), ),
            SubgraphNode(op=new_op(op_name="conv_layer1/gelu/1",
                                   op_type=OpType.GELU,
                                   input_names=["conv_layer1/conv/1"]),
                         output_names=["conv_layer1/relu"])
        ]

        graph = replace_subgraph(self.graph, subgraph_spec)
        subgraph_model = SubgraphModel(graph, self.constants, {}, {},
                                       subgraph_spec)
        dp = depth.DepthProperty().infer(subgraph_model)

        depth_map = dp.depth_map
        self.assertLen(depth_map, 1)
        self.assertIn("conv_layer0/avg_pool:0", depth_map)
        self.assertLen(depth_map["conv_layer0/avg_pool:0"], 2)
        self.assertIn("conv_layer1/relu:0",
                      depth_map["conv_layer0/avg_pool:0"])
        self.assertEqual(
            depth_map["conv_layer0/avg_pool:0"]["conv_layer1/relu:0"], 1)
        self.assertIn("conv_layer1/gelu/1:0",
                      depth_map["conv_layer0/avg_pool:0"])
        self.assertEqual(
            depth_map["conv_layer0/avg_pool:0"]["conv_layer1/gelu/1:0"], 1)
Example #7
    def test_abstract(self):
        graphs = []

        conv_op = functools.partial(new_op,
                                    op_type=OpType.CONV,
                                    op_kwargs={
                                        "features": 10,
                                        "kernel_size": [3, 3]
                                    })
        dense_op = functools.partial(new_op,
                                     op_type=OpType.DENSE,
                                     op_kwargs={
                                         "features": 10,
                                     })

        for op_type in [
                OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM,
                OpType.BATCH_NORM
        ]:
            for op_ctr in [conv_op, dense_op]:
                graphs.append([
                    op_ctr(input_names=["input"], op_name="other"),
                    new_op(op_name="output",
                           op_type=op_type,
                           input_names=["other"])
                ])
                graphs.append([
                    new_op(op_name="other",
                           op_type=op_type,
                           input_names=["input"]),
                    op_ctr(input_names=["other"], op_name="output"),
                ])

        input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
        for graph in graphs:
            graph = new_graph(["input"], ["output"], graph)
            state = Model(graph).init(random.PRNGKey(1), input_tensor)

            # Make all the kernels positive; otherwise the ReLU might zero out
            # the entire tensor.
            state = jax.tree_util.tree_map(abs, state)

            subg_model = SubgraphModel(graph, None, state, input_tensor)
            lp_abstract = linear.LinopProperty().infer(subg_model,
                                                       abstract=True)
            lp_concrete = linear.LinopProperty().infer(subg_model,
                                                       abstract=False)
            pairings_concerete = lp_concrete.pairings["output"][
                "input"].mappings
            pairings_abstract = lp_abstract.pairings["output"][
                "input"].mappings

            print("concrete:", pairings_concerete)
            print("abstract:", pairings_abstract)
            self.assertTrue(
                ((pairings_abstract - pairings_concerete) == 0).all())
Example #8
def conv_net(in_features, out_features, num_classes, blocks=None):
    """Graph for 3-layer CNN."""
    if not blocks:
        blocks = [block_type() for block_type in BLOCK_TYPES]

    input_name = "input"
    new_blocks = []
    ops = [
        new_op(op_name="proj",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": in_features,
                   "kernel_size": 1,
               },
               input_names=[input_name])
    ]
    constants = {}

    block_input_name = ops[-1].name
    for idx, block in enumerate(blocks):
        block = block.instantiate(input_names=[block_input_name],
                                  instance_id=idx)
        new_blocks.append(block)
        constants.update(block.constants)
        ops.extend(block.graph.ops)
        block_input_name = ops[-1].name

    constants.update({
        "out_features": out_features,
        "num_classes": num_classes
    })
    ops.extend([
        new_op(op_name="flatten",
               op_type=OpType.FLATTEN,
               input_names=[ops[-1].name]),
        new_op(op_name="fc/dense",
               op_type=OpType.DENSE,
               op_kwargs={"features": "K:out_features"},
               input_names=["flatten"]),
        new_op(op_name="fc/relu",
               op_type=OpType.RELU,
               input_names=["fc/dense"]),
        new_op(op_name="fc/logits",
               op_type=OpType.DENSE,
               op_kwargs={"features": "K:num_classes"},
               input_names=["fc/relu"])
    ])
    graph = new_graph(input_names=[input_name],
                      output_names=["fc/logits"],
                      ops=ops)
    return graph, constants, new_blocks
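
A minimal usage sketch for conv_net, reusing the Model, random, and jnp imports that appear in the other examples; the argument values and input shape are illustrative, not taken from the source:

graph, constants, blocks = conv_net(in_features=16, out_features=32,
                                    num_classes=10)
model = Model(graph, constants)
state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
logits = model.apply(state, jnp.ones((1, 32, 32, 3)))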
Example #9
  def synthesize(self):
    """Returns a new subgraph."""
    subgraph_spec = self.subgraphs_and_props[0][0].subgraph
    subg_ops = [copy.deepcopy(node.op) for node in subgraph_spec]

    mutations = [
        self.delete,
        self.insert,
        self.mutate_field,
        lambda x: self.insert(self.delete(x)),
        self.swap]

    if self.use_automl_zero:
      mutations.append(lambda _: self.randomize())

    # Certain mutations may not be applicable to the selected subgraph, in
    # which case they return None (e.g., if the subgraph is of size 1, we
    # cannot swap). So loop through the mutations in a random order until we
    # find one that is applicable.
    random.shuffle(mutations)
    mutated_subg_ops = None
    while mutations and mutated_subg_ops is None:
      mutation = mutations.pop()
      mutated_subg_ops = mutation(subg_ops)
    if mutated_subg_ops is None:
      raise ValueError("Synthesis failed.")
    subg_ops = mutated_subg_ops

    prefix = f"gen{self.generation}/"
    if not subg_ops:
      subg_ops.append(new_op("dummy", OpType.IDENTITY, [self.input_name]))
    for op in subg_ops:
      op.name = f"{prefix}{op.type.name.lower()}"
    subgraph_spec = self.make_subgraph_spec(subg_ops)
    return self.make_subgraph_models(subgraph_spec)
Example #10
 def test_identical(self):
     """Tests whether the fingerprint is the same for identical graphs."""
     ops = [
         new_op(op_name="dense0",
                op_type=OpType.DENSE,
                op_kwargs={"features": 32},
                input_names=["input"]),
         new_op(op_name="dense1",
                op_type=OpType.DENSE,
                op_kwargs={"features": 32},
                input_names=["input"]),
         new_op(op_name="output",
                op_type=OpType.ADD,
                input_names=["dense0", "dense1"]),
     ]
     graph = new_graph(["input"], ["output"], ops)
     input_dict = {"input": jnp.ones((5, 5, 5))}
     fingerprint1 = fingerprint.fingerprint_graph(graph, {}, input_dict)
     fingerprint2 = fingerprint.fingerprint_graph(graph, {}, input_dict)
     self.assertEqual(fingerprint1, fingerprint2)
Example #11
    def test_multi_input_output(self):
        """Tests a subgraph substitution on a graph with multiple inputs / output ops.

    We use a ResNet model, which has skip connections. This test checks that the
    substitution produces the expected number of ops, and also that the newly
    produced graph is still executable.
    """

        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        y = model.apply(state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))

        subg = [
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/conv0",
                          op_type=OpType.CONV,
                          op_kwargs={
                              "features": 64,
                              "kernel_size": [1, 1]
                          },
                          input_names=["resnet11/skip/relu1"])),
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/gelu1",
                          op_type=OpType.GELU,
                          input_names=["subgraph/conv0"]),
                output_names=["resnet_stride1_filtermul1_basic12/relu2"])
        ]
        new_graph = subgraph.replace_subgraph(graph, subg)

        # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
        self.assertLen(graph.ops, len(new_graph.ops) + 1)

        new_model = Model(new_graph, constants)
        new_state = new_model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

        y = new_model.apply(new_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))
Example #12
def conv_block():
    """Makes a conv block parameterized by the number of features."""
    ops = [
        new_op(op_name="conv",
               op_type=OpType.CONV,
               op_kwargs={
                   "features": "S:-1*2",
                   "kernel_size": 3
               },
               input_names=["input"]),
        new_op(op_name="relu", op_type=OpType.RELU, input_names=["conv"]),
        new_op(op_name="avg_pool",
               op_type=OpType.AVG_POOL,
               input_names=["relu"],
               input_kwargs={
                   "window_shape": 2,
                   "strides": 2
               }),
    ]

    graph = new_graph(input_names=["input"],
                      output_names=["avg_pool"],
                      ops=ops)
    return Block(name="conv_layer", graph=graph)
Example #13
def append_op(ops,
              op_name,
              op_type,
              input_names=None,
              input_kwargs=None,
              op_kwargs=None,
              num_outputs=1):
  """Convenience function for appending to a sequence of ops."""
  if input_names is None:
    input_names = [ops[-1].name]
  ops.append(
      new_op(op_name=op_name,
             op_type=op_type,
             input_names=input_names,
             input_kwargs=input_kwargs,
             op_kwargs=op_kwargs if op_kwargs else {},
             num_outputs=num_outputs))
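
A brief usage sketch: when input_names is omitted, the appended op consumes the output of the last op already in the list.

ops = [
    new_op(op_name="dense",
           op_type=OpType.DENSE,
           op_kwargs={"features": 32},
           input_names=["input"])
]
append_op(ops, op_name="relu", op_type=OpType.RELU)  # wired to "dense"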
Example #14
def append_op(
        ops,
        op_name,
        op_type,
        input_names=None,
        input_kwargs=None,
        op_kwargs=None,
        num_outputs=1):
    """Convenience function for appending to a sequence of ops."""
    if not input_names:
        input_names = [ops[-1].name]
    default_op_kwargs = DEFAULT_OP_KWARGS.get(op_type, {})
    ops.append(
        new_op(op_name=op_name,
               op_type=op_type,
               input_names=input_names,
               input_kwargs=input_kwargs,
               op_kwargs={
                   **default_op_kwargs,
                   **(op_kwargs or {})
               },
               num_outputs=num_outputs))
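
This variant layers DEFAULT_OP_KWARGS (defined elsewhere in the module) under the caller's op_kwargs, so explicit arguments override the per-type defaults. A hedged sketch, reusing the ops list from the sketch above and assuming DEFAULT_OP_KWARGS carries an entry for OpType.CONV:

# Hypothetical: "kernel_size" is taken as given, while any remaining CONV
# defaults (e.g., "features") are filled in from DEFAULT_OP_KWARGS.
append_op(ops, op_name="conv", op_type=OpType.CONV,
          op_kwargs={"kernel_size": [1, 1]})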
Example #15
    def test_equal(self):
        """Tests whether the fingerprint is the same for equivalent graphs.

    The ops have different names and a different topological order.
    """
        ops1 = [
            new_op(op_name="dense",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="conv",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense", "conv"]),
        ]
        graph1 = new_graph(["input"], ["output"], ops1)

        ops2 = [
            new_op(op_name="conv2",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense2", "conv2"]),
        ]
        graph2 = new_graph(["input"], ["output"], ops2)

        input_dict = {"input": jnp.ones((5, 5, 5))}
        fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
        fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
        self.assertEqual(fingerprint1, fingerprint2)
Example #16
    def test_not_equal(self):
        """Tests whether the fingerprint is different for non-equivalent graphs."""
        ops1 = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense0", "dense1"]),
        ]
        graph1 = new_graph(["input"], ["output"], ops1)

        ops2 = [
            new_op(op_name="conv2",
                   op_type=OpType.CONV,
                   op_kwargs={
                       "features": 32,
                       "kernel_size": [3]
                   },
                   input_names=["input"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="output",
                   op_type=OpType.ADD,
                   input_names=["dense2", "conv2"]),
        ]
        graph2 = new_graph(["input"], ["output"], ops2)

        input_dict = {"input": jnp.ones((5, 5, 5))}
        fingerprint1 = fingerprint.fingerprint_graph(graph1, {}, input_dict)
        fingerprint2 = fingerprint.fingerprint_graph(graph2, {}, input_dict)
        self.assertNotEqual(fingerprint1, fingerprint2)
Example #17
    def test_multi_input(self):
        ops = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu0",
                   op_type=OpType.RELU,
                   input_names=["dense0"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu1",
                   op_type=OpType.RELU,
                   input_names=["dense1"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu2",
                   op_type=OpType.RELU,
                   input_names=["dense2"]),
            new_op(op_name="add0",
                   op_type=OpType.ADD,
                   input_names=["relu0", "relu1"]),
            new_op(op_name="add1",
                   op_type=OpType.ADD,
                   input_names=["relu1", "relu2"]),
        ]
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)
        subgraph_spec = [
            SubgraphNode(op=new_op(
                op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
            SubgraphNode(op=new_op(
                op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
            SubgraphNode(op=new_op(
                op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
            SubgraphNode(op=new_op(op_name="add0",
                                   op_type=OpType.ADD,
                                   input_names=["relu0", "relu1"]),
                         output_names=["add0"]),
            SubgraphNode(op=new_op(op_name="add1",
                                   op_type=OpType.ADD,
                                   input_names=["relu1", "relu2"]),
                         output_names=["add1"]),
        ]
        replaced_graph = replace_subgraph(graph, subgraph_spec)
        inp = {"input": jnp.ones((10, 32, 32, 3))}
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, inp,
                                       subgraph_spec)
        lp = linear.LinopProperty().infer(subgraph_model)
        pairings = lp.pairings

        self.assertLen(pairings, 2)
        self.assertIn("add0:0", pairings)
        self.assertLen(pairings["add0:0"], 2)
        self.assertIn("dense0:0", pairings["add0:0"])
        self.assertIn("dense1:0", pairings["add0:0"])
        self.assertIn("add1:0", pairings)
        self.assertLen(pairings["add1:0"], 2)
        self.assertIn("dense1:0", pairings["add1:0"])
        self.assertIn("dense2:0", pairings["add1:0"])
Example #18
    def instantiate(self, input_names, instance_id=None, constants=None):
        """Instantiates a version of the block with unique names.

    This method uses the names of the graph and constants from the initial
    definition of the block (__init__), so that one can instantiate from any
    derived block with the same effect, e.g., if we have:
      init_block = block.__init__(name="conv_layer", ...)
      block0 = init_block.instantiate(instance_id=0, ...)
    then:
      block1 = init_block.instantiate(instance_id=1, ...)
    will have the same effect as:
      block1 = block0.instantiate(instance_id=1, ...)
    The one caveat is that the default values for unspecified constants are
    inherited from the instantiating block (instead of the initial definition).

    Args:
      input_names: The input tensor names the instantiated block will consume.
      instance_id: An id to make the names in the instantiated block unique.
        The id should be unique within a graph.
      constants: Updated parameters for the instantiated block.

    Returns:
      An instantiated block.

    Raises:
      ValueError: if the number of input names provided does not equal the
        number of inputs consumed by the graph.
    """
        if len(input_names) != len(self.base_graph.input_names):
            raise ValueError("Wrong number of inputs provided.")

        prefix = ""
        if self.name: prefix += self.name
        if instance_id is not None: prefix += str(instance_id)
        if prefix: prefix += "/"

        if not constants: constants = dict(self.base_constants)

        new_input_names = input_names
        updated_names = {
            o: n
            for o, n in zip(self.base_graph.input_names, new_input_names)
        }
        inputs_names = [
            canonicalize_tensor_name(n) for n in self.base_graph.input_names
        ]
        updated_names.update(
            {o: n
             for o, n in zip(inputs_names, new_input_names)})

        # Update ops.
        new_ops = []
        for op in self.base_graph.ops:
            # Update all input tensor names.
            # Any internal input (i.e., anything that is not a graph input)
            # needs to be updated with the prefix.
            new_inputs = []
            for inp in op.input_names:
                try:
                    idx = inputs_names.index(inp)
                    new_inputs.append(new_input_names[idx])
                except ValueError:
                    new_inputs.append(f"{prefix}{inp}")

            # Update symbolic constant names in input_kwargs and op_kwargs.
            new_kwargs = []
            for kwargs in [op.input_kwargs, op.op_kwargs]:
                nk = {
                    k: _prefix_symbolic(v, prefix, constants, updated_names)
                    for k, v in kwargs.items()
                }
                new_kwargs.append(nk)

            new_ops.append(
                new_op(op_name=f"{prefix}{op.name}",
                       op_type=op.type,
                       input_names=new_inputs,
                       input_kwargs=new_kwargs[0],
                       op_kwargs=new_kwargs[1],
                       num_outputs=op.num_outputs))

        # Update constants and prefix symbolic constant names.
        old_constants = dict(self.base_constants)
        if constants: old_constants.update(constants)
        new_constants = {f"{prefix}{k}": v for k, v in old_constants.items()}

        # Prefix graph output names.
        new_output_names = [
            f"{prefix}{on}" for on in self.base_graph.output_names
        ]

        graph = new_graph(ops=new_ops,
                          input_names=new_input_names,
                          output_names=new_output_names)
        return Block(name=self.name,
                     graph=graph,
                     constants=new_constants,
                     base_graph=self.base_graph,
                     base_constants=old_constants)
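
A usage sketch mirroring the chaining in conv_net (Example #8), where each new instance consumes the last op of the previous one:

block = conv_block()  # from Example #12
block0 = block.instantiate(input_names=["input"], instance_id=0)
block1 = block.instantiate(input_names=[block0.graph.ops[-1].name],
                           instance_id=1)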
Example #19
    def op_enumerator(
        cls,
        prefix=None,
        kwarg_defaults=None,
        full=True,
        op_types=None,
    ):
        if not prefix:
            prefix = ""
        elif not prefix.endswith("/"):
            prefix = f"{prefix}/"

        kwarg_defaults = cls.make_default_kwargs(kwarg_defaults, full)

        if op_types is None:
            op_types = OpType

        for op_type in op_types:

            name = f"{prefix}{op_type.name.lower()}"
            inputs = ["inputs"]

            if op_type in [
                    OpType.IDENTITY,
                    OpType.NONE,
                    OpType.DENSE_GENERAL,
                    OpType.ADD,
                    OpType.MUL,
                    OpType.SCALAR_ADD,
                    OpType.DOT_GENERAL,
                    OpType.EINSUM,
                    OpType.FLATTEN,
                    OpType.RESHAPE,
                    OpType.TRANSPOSE,
                    OpType.PARAM,
                    OpType.SELF_ATTENTION,
                    OpType.STOCH_DEPTH,
                    OpType.MEAN,
            ]:
                # Not supported for synthesis.
                pass
            elif op_type in [
                    OpType.SCALAR_MUL,
                    OpType.BATCH_NORM,
                    OpType.LAYER_NORM,
                    OpType.RELU,
                    OpType.GELU,
                    OpType.SWISH,
                    OpType.SIGMOID,
                    OpType.SOFTMAX,
            ]:
                # No kwargs.
                yield new_op(name, op_type, inputs)
            elif op_type in [
                    OpType.DENSE,
                    OpType.CONV,
                    OpType.GROUP_NORM,
                    OpType.AVG_POOL,
                    OpType.MAX_POOL,
                    OpType.DROPOUT,
            ]:
                op_kwargs_dict, input_kwargs_dict = cls.all_kwargs_for_op_type(
                    kwarg_defaults, full, op_type)
                for op_kwargs, input_kwargs in cls.kwargs_for_op_to_product(
                        op_kwargs_dict, input_kwargs_dict):
                    yield new_op(name,
                                 op_type,
                                 inputs,
                                 op_kwargs=op_kwargs,
                                 input_kwargs=input_kwargs)
            else:
                assert False, f"op_type {op_type} not supported"
        return
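
A hedged enumeration sketch; the owning class name (Synthesizer) is hypothetical, and it assumes op_enumerator is a classmethod, as the cls parameter above suggests:

# Yields a single op for kwarg-free types (e.g., RELU) and one op per kwarg
# combination for parameterized types (e.g., CONV).
for op in Synthesizer.op_enumerator(prefix="gen0",
                                    op_types=[OpType.RELU, OpType.CONV]):
    print(op.name, op.type)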
Example #20
    def test_multi_input(self):
        ops = [
            new_op(op_name="dense0",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu0",
                   op_type=OpType.RELU,
                   input_names=["dense0"]),
            new_op(op_name="dense1",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu1",
                   op_type=OpType.RELU,
                   input_names=["dense1"]),
            new_op(op_name="dense2",
                   op_type=OpType.DENSE,
                   op_kwargs={"features": 32},
                   input_names=["input"]),
            new_op(op_name="relu2",
                   op_type=OpType.RELU,
                   input_names=["dense2"]),
            new_op(op_name="add0",
                   op_type=OpType.ADD,
                   input_names=["relu0", "relu1"]),
            new_op(op_name="add1",
                   op_type=OpType.ADD,
                   input_names=["relu1", "relu2"]),
        ]
        graph = new_graph(input_names=["input"],
                          output_names=["add0", "add1"],
                          ops=ops)
        subgraph_spec = [
            SubgraphNode(op=new_op(
                op_name="relu0", op_type=OpType.RELU, input_names=["dense0"])),
            SubgraphNode(op=new_op(
                op_name="relu1", op_type=OpType.RELU, input_names=["dense1"])),
            SubgraphNode(op=new_op(
                op_name="relu2", op_type=OpType.RELU, input_names=["dense2"])),
            SubgraphNode(op=new_op(op_name="add0",
                                   op_type=OpType.ADD,
                                   input_names=["relu0", "relu1"]),
                         output_names=["add0"]),
            SubgraphNode(op=new_op(op_name="add1",
                                   op_type=OpType.ADD,
                                   input_names=["relu1", "relu2"]),
                         output_names=["add1"]),
        ]
        replaced_graph = replace_subgraph(graph, subgraph_spec)
        subgraph_model = SubgraphModel(replaced_graph, {}, {}, {},
                                       subgraph_spec)
        dp = depth.DepthProperty().infer(subgraph_model)
        depth_map = dp.depth_map

        self.assertLen(depth_map, 3)
        self.assertIn("dense0:0", depth_map)
        self.assertIn("dense1:0", depth_map)
        self.assertIn("dense2:0", depth_map)
        self.assertLen(depth_map["dense0:0"], 1)
        self.assertEqual(depth_map["dense0:0"]["add0:0"], 2)
        self.assertLen(depth_map["dense1:0"], 2)
        self.assertEqual(depth_map["dense1:0"]["add0:0"], 2)
        self.assertEqual(depth_map["dense1:0"]["add1:0"], 2)
        self.assertLen(depth_map["dense2:0"], 1)
        self.assertEqual(depth_map["dense2:0"]["add1:0"], 2)
Example #21
    def resolve_op(self, op, intermediate_values, **_):
        """Resolves an op with possibly symbolic arguments to a concrete op."""
        op_name = op.name.lower()
        op_type = op.type

        input_names = op.input_names
        input_values = [
            intermediate_values[key.lower()] for key in input_names
        ]

        input_kwargs: Dict[str, Any] = op.input_kwargs
        op_kwargs: Dict[str, Any] = op.op_kwargs
        op_kwargs["name"] = op_name

        if op_type == OpType.NONE:
            pass

        elif op_type == OpType.IDENTITY:
            pass

        # nn.linear

        elif op_type == OpType.DENSE:
            _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
            _kv_resolve_symbolic(op_kwargs, ["features"], input_values,
                                 intermediate_values)

        elif op_type == OpType.DENSE_GENERAL:
            _kv_to_int(op_kwargs, ["axis", "batch_dims"])
            _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
            _kv_resolve_symbolic(op_kwargs, ["features"], input_values,
                                 intermediate_values)

        elif op_type == OpType.CONV:
            _kv_to_int(op_kwargs, [
                "kernel_size",
                "strides",
                "input_dilation",
                "kernel_dilation",
                "padding",
            ])
            _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
            _kv_resolve_symbolic(op_kwargs,
                                 ["features", "feature_group_count"],
                                 input_values, intermediate_values)

        # others

        elif op_type == OpType.ADD:
            _kv_to_float(op_kwargs, ["layer_drop_rate"])

        elif op_type == OpType.SCALAR_ADD:
            _kv_to_float(input_kwargs, ["const"])

        elif op_type == OpType.MUL:
            pass

        elif op_type == OpType.SCALAR_MUL:
            _kv_to_float(input_kwargs, ["const"])

        elif op_type == OpType.DOT_GENERAL:
            _kv_to_int(input_kwargs, ["dimension_numbers"])

        elif op_type == OpType.EINSUM:
            pass

        # nn.attention

        elif op_type == OpType.SELF_ATTENTION:
            _kv_resolve_symbolic(op_kwargs, ["kernel_init", "bias_init"])
            _kv_resolve_symbolic(op_kwargs,
                                 ["num_heads", "qkv_features", "out_features"],
                                 input_values, intermediate_values)

        # nn.activation

        elif op_type in [
                OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
        ]:
            pass

        elif op_type == OpType.SOFTMAX:
            _kv_to_int(input_kwargs, ["axis"])

        # nn.normalization

        elif op_type == OpType.BATCH_NORM:
            _kv_to_int(op_kwargs, ["axis"])
            _kv_resolve_symbolic(op_kwargs, ["scale_init", "bias_init"])

        elif op_type == OpType.LAYER_NORM:
            pass

        elif op_type == OpType.GROUP_NORM:
            _kv_resolve_symbolic(op_kwargs, ["num_groups", "group_size"],
                                 input_values, intermediate_values)

        # reshape operators

        elif op_type == OpType.RESHAPE:
            _kv_resolve_symbolic(input_kwargs, ["new_shape"], input_values,
                                 intermediate_values)
            _kv_to_int(input_kwargs, ["new_shape"])

        elif op_type == OpType.FLATTEN:
            pass

        elif op_type == OpType.TRANSPOSE:
            _kv_to_int(input_kwargs, ["axes"])

        # nn.stochastic

        elif op_type == OpType.DROPOUT:
            _kv_to_int(op_kwargs, ["broadcast_dims"])
            _kv_to_float(op_kwargs, ["rate"])

        elif op_type == OpType.STOCH_DEPTH:
            _kv_to_float(op_kwargs, ["layer_drop_rate"])

        # nn.pooling

        elif op_type == OpType.AVG_POOL:
            _kv_to_int(input_kwargs, ["window_shape", "strides"])

        elif op_type == OpType.MAX_POOL:
            _kv_to_int(input_kwargs, ["window_shape", "strides"])

        elif op_type == OpType.MEAN:
            _kv_to_int(input_kwargs, ["axis"])

        # new param

        elif op_type == OpType.PARAM:
            _kv_to_int(input_kwargs, ["shape"])
            _kv_resolve_symbolic(input_kwargs, ["shape", "init_fn"],
                                 input_values, intermediate_values)

        else:
            raise ValueError(f"op_type {op_type} not supported...")

        return new_op(op_name,
                      op_type,
                      input_names,
                      input_kwargs,
                      op_kwargs,
                      num_outputs=op.num_outputs)
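
The _kv_to_int / _kv_to_float helpers are not shown in these examples. A plausible minimal sketch of _kv_to_int, assuming it coerces possibly-string kwargs in place and skips absent keys; this is a guess at the contract, not the library's actual code:

def _kv_to_int(kwargs, keys):
    # Hypothetical sketch: coerce the named entries to ints in place, handling
    # scalars and sequences; values that do not parse (e.g., "SAME" padding)
    # are left as-is.
    for k in keys:
        if k not in kwargs or kwargs[k] is None:
            continue
        v = kwargs[k]
        try:
            if isinstance(v, (list, tuple)):
                kwargs[k] = type(v)(int(e) for e in v)
            else:
                kwargs[k] = int(v)
        except (TypeError, ValueError):
            pass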