def setUp(self):
        """Seeds `py_rand` and builds a CifarNet model with a reference output."""
        super().setUp()

        # Wall-clock seed, logged so that a failing run can be reproduced.
        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        reference_outputs = self.m.apply(self.state, self.input)
        self.out = reference_outputs["fc/logits"]
        self.max_size = 10 ** 9  # same integer as int(10e8)

        self.hard = False
Esempio n. 2
0
    def test_rewire(self):
        """Infers shapes for a conv -> gelu subgraph rewired into CifarNet."""
        conv_node = SubgraphNode(
            op=new_op(op_name="conv_layer1/conv/1",
                      op_type=OpType.CONV,
                      op_kwargs={
                          "features": 64,
                          "kernel_size": [1, 1]
                      },
                      input_names=["conv_layer0/avg_pool"]))
        gelu_node = SubgraphNode(
            op=new_op(op_name="conv_layer1/gelu/1",
                      op_type=OpType.GELU,
                      input_names=["conv_layer1/conv/1"]),
            output_names=["conv_layer1/relu"])
        subgraph_spec = [conv_node, gelu_node]

        graph = replace_subgraph(self.graph, subgraph_spec)
        inputs = {"input": jnp.ones((5, 32, 32, 3))}
        state = Model(graph, self.constants).init(random.PRNGKey(0), inputs)
        subgraph_model = SubgraphModel(graph, self.constants, state, inputs,
                                       subgraph_spec)
        sp = shape.ShapeProperty().infer(subgraph_model)

        # A single incoming edge: the tensor feeding the conv.
        self.assertLen(sp.input_shapes, 1)
        self.assertIn("conv_layer0/avg_pool:0", sp.input_shapes)

        # Two outgoing entries are expected.
        self.assertLen(sp.output_shapes, 2)
        for expected in ("conv_layer1/gelu/1:0", "conv_layer1/relu:0"):
            self.assertIn(expected, sp.output_shapes)
Esempio n. 3
0
def fingerprint_graph(graph, constants, input_values, state=None):
    """Returns a fingerprint for functional equivalence.

    Args:
      graph: The graph to fingerprint. Its `output_names` are temporarily
        replaced by `hashed/<name>` entries while the fingerprint model runs,
        and are always restored before this function returns or raises.
      constants: Optional mapping of graph constants. It is not mutated; a copy
        is augmented with hashed input names and inferred shapes.
      input_values: Input values used for shape inference and for evaluating
        the fingerprint model.
      state: Optional state forwarded to `GraphShapes.infer`.

    Returns:
      The hash of the frozenset of items returned by `_FingerprintModel.apply`.
    """
    # Get shape info for resolving ops.
    shapes = GraphShapes.infer(Model(graph, constants),
                               input_values=input_values,
                               state=state,
                               intermediates=True,
                               abstract=True)

    # Augment (a copy of) constants with hashed input names and shapes.
    constants = dict(constants) if constants else {}
    for raw_input_name in graph.input_names:
        input_name = canonicalize_tensor_name(raw_input_name)
        # NOTE(review): builtin hash() of a str is salted per process
        # (PYTHONHASHSEED), so fingerprints are presumably only comparable
        # within a single process.
        constants[f"hashed/{input_name}"] = hash(input_name)
    constants.update(shapes.input_shapes)
    constants.update(shapes.output_shapes)

    # Temporarily request the hash value of every output tensor. Restore the
    # original output names even if evaluation fails, so the caller's graph is
    # never left in the rewritten state.
    output_names = graph.output_names
    graph.output_names = [
        f"hashed/{output_name}" for output_name in output_names
    ]
    try:
        fingerprints = _FingerprintModel(graph, constants).apply(
            {}, input_values)
    finally:
        graph.output_names = output_names

    # Return hash of outputs.
    return hash(frozenset(fingerprints.items()))
    def _synthesize(self, subg, props):
        """Synthesizes a replacement for `subg` and checks the logits change.

        Args:
          subg: The subgraph model to replace.
          props: Properties the synthesized subgraph is asked to satisfy.
        """
        synthesizer = ProgressiveSequentialSynthesizer(
            [(subg, props)],
            generation=0,
            mode=ProgressiveSequentialSynthesizer.Mode.WEIGHTED,
            max_len=3)
        synthesized = synthesizer.synthesize()[0]

        # Dump the synthesized nodes for debugging.
        for node in synthesized.subgraph:
            print(node.op.name)
            print(node.output_names)

        new_model = Model(synthesized.graph, self.constants)
        new_state = new_model.init(random.PRNGKey(0), self.input)
        new_logits = new_model.apply(new_state, self.input)["fc/logits"]
        self.assertTrue((new_logits != self.out).any())
Esempio n. 5
0
  def __init__(self,
               graph,
               constants,
               state,
               inputs,
               subgraph=None):
    """Builds the models used to execute the (sub)graph.

    Args:
      graph: The full graph, with the subgraph already embedded.
      constants: Constants passed to every `Model` built here.
      state: State (variables) of the full graph; stored as-is.
      inputs: Default input values to the full graph.
      subgraph: Optional list of subgraph nodes. When falsy, the subgraph is
        the whole graph.
    """
    self.graph = graph
    self.constants = constants
    self.state = state
    self.inputs = inputs
    self.subgraph: SubgraphSpec = subgraph or []

    self.input_names = None
    self.output_names = None
    self.original_outputs = graph.output_names

    if not subgraph:
      # The subgraph is the full graph: its inputs/outputs are the graph's.
      self.input_names = list(map(canonicalize_tensor_name, graph.input_names))
      self.output_names = list(
          map(canonicalize_tensor_name, graph.output_names))

      # No separate model is needed to compute the subgraph inputs.
      self.subg_inputs_graph = None
      self.subg_inputs_model = None
      self.subg_inputs = inputs

      # graph inputs -> subg outputs is a copy of the full graph.
      self.subg_outputs_graph = copy.deepcopy(graph)
      self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
      self.subg_outputs = None

      # ...and it doubles as the subgraph model itself.
      self.subg_graph = self.subg_outputs_graph
      self.subg_model = self.subg_outputs_model
      return

    self._subgraph_to_names()

    # graph inputs -> subg inputs
    self.subg_inputs_graph = copy.deepcopy(graph)
    self.subg_inputs_graph.output_names = self.input_names
    self.subg_inputs_model = Model(self.subg_inputs_graph, self.constants)
    self.subg_inputs = None

    # graph inputs -> subg outputs
    self.subg_outputs_graph = copy.deepcopy(graph)
    self.subg_outputs_graph.output_names = self.output_names
    self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
    self.subg_outputs = None

    # subg inputs -> subg outputs, built only from the subgraph's own ops
    self.subg_graph = new_graph(self.input_names, self.output_names,
                                [node.op for node in subgraph])
    self.subg_model = Model(self.subg_graph, self.constants)
Esempio n. 6
0
    def test_inference(self):
        """ResNet18 (small inputs) initializes and runs, also with mutable stats."""
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        self.assertLen(state, 2)
        for collection in ("params", "batch_stats"):
            self.assertIn(collection, state)

        batch = {"input": jnp.ones((10, 32, 32, 3))}
        dense_out = model.apply(state, batch)["fc/dense"]
        self.assertEqual(dense_out.shape, (10, 10))

        outputs, updated_state = model.apply(state,
                                             batch,
                                             mutable=["batch_stats"])
        self.assertEqual(outputs["fc/dense"].shape, (10, 10))
        self.assertIn("batch_stats", updated_state)
Esempio n. 7
0
    def test_abstract(self):
        """Abstract and concrete LinopProperty inference yield equal pairings.

        Each candidate is a two-op chain input -> "other" -> "output" that
        combines a conv or dense op with a ReLU, softmax, layer norm or batch
        norm, in both orders.
        """
        conv_ctr = functools.partial(new_op,
                                     op_type=OpType.CONV,
                                     op_kwargs={
                                         "features": 10,
                                         "kernel_size": [3, 3]
                                     })
        dense_ctr = functools.partial(new_op,
                                      op_type=OpType.DENSE,
                                      op_kwargs={
                                          "features": 10,
                                      })

        other_types = [
            OpType.RELU, OpType.SOFTMAX, OpType.LAYER_NORM, OpType.BATCH_NORM
        ]
        op_lists = []
        for other_type in other_types:
            for layer_ctr in (conv_ctr, dense_ctr):
                # conv/dense first, then the other op
                op_lists.append([
                    layer_ctr(input_names=["input"], op_name="other"),
                    new_op(op_name="output",
                           op_type=other_type,
                           input_names=["other"]),
                ])
                # the other op first, then conv/dense
                op_lists.append([
                    new_op(op_name="other",
                           op_type=other_type,
                           input_names=["input"]),
                    layer_ctr(input_names=["other"], op_name="output"),
                ])

        input_tensor = {"input": jnp.ones((5, 5, 5, 5))}
        for ops in op_lists:
            graph = new_graph(["input"], ["output"], ops)
            state = Model(graph).init(random.PRNGKey(1), input_tensor)

            # Make all the kernels positive, otherwise, the ReLU might zero out
            # the entire tensor.
            state = jax.tree_util.tree_map(abs, state)

            subg_model = SubgraphModel(graph, None, state, input_tensor)
            lp_abstract = linear.LinopProperty().infer(subg_model,
                                                       abstract=True)
            lp_concrete = linear.LinopProperty().infer(subg_model,
                                                       abstract=False)
            concrete_map = lp_concrete.pairings["output"]["input"].mappings
            abstract_map = lp_abstract.pairings["output"]["input"].mappings

            print("concrete:", concrete_map)
            print("abstract:", abstract_map)
            self.assertTrue(((abstract_map - concrete_map) == 0).all())
Esempio n. 8
0
    def test_synthesizer_resnet_big(self):
        """Synthesizes a replacement for a multi-op slice of ResNet18."""
        self.graph, self.constants, _ = resnetv1.ResNet18(
            num_classes=10, input_resolution="small")
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        reference_outputs = self.m.apply(self.state, self.input)
        self.out = reference_outputs[self.graph.output_names[0]]
        self.max_size = 10 ** 9  # same integer as int(10e8)

        # Ops [3:5] and [8:12] form the subgraph; the last node's output_names
        # is set to its own first output tensor.
        subg_ops = self.graph.ops[3:5] + self.graph.ops[8:12]
        subg = [subgraph.SubgraphNode(op=op) for op in subg_ops]
        tail = subg[-1]
        tail.output_names = [f"{tail.op.name}:0"]

        subgraph_model = SubgraphModel(self.graph, self.constants, self.state,
                                       self.input, subg)
        props = [
            shape.ShapeProperty().infer(subgraph_model,
                                        max_size=self.max_size),
            linear.LinopProperty().infer(subgraph_model),
        ]
        self._synthesize(subgraph_model, props)
Esempio n. 9
0
    def setUp(self):
        """Infers the shape property of the full CifarNet graph."""
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        model = Model(self.graph, self.constants)
        inputs = {"input": jnp.ones((5, 32, 32, 3))}
        state = model.init(random.PRNGKey(0), inputs)
        self.subgraph_model = SubgraphModel(self.graph, self.constants, state,
                                            inputs)
        self.sp = shape.ShapeProperty().infer(self.subgraph_model)
Esempio n. 10
0
 def test_unsatisfy(self):
     """A graph truncated before the last dense layer fails the shape property.

     The truncated graph outputs "fc/relu" instead of the logits, so its output
     shape differs from the one inferred in setUp.
     """
     truncated = new_graph(input_names=["input"],
                           output_names=["fc/relu"],
                           ops=self.graph.ops)
     model = Model(truncated, self.constants)
     inputs = {"input": jnp.ones((5, 32, 32, 3))}
     state = model.init(random.PRNGKey(0), inputs)
     candidate = SubgraphModel(truncated, self.constants, state, inputs)
     self.assertFalse(self.sp.verify(candidate))
Esempio n. 11
0
    def test_inference(self):
        """EfficientNetB0 initializes and runs, also with mutable batch stats."""
        graph, constants, _ = efficientnet.EfficientNetB0(num_classes=10)

        # Dump the op wiring for debugging.
        for op in graph.ops:
            print(f"name={op.name}")
            print(f"input_names={op.input_names}")
        print()

        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        self.assertLen(state, 2)
        for collection in ("params", "batch_stats"):
            self.assertIn(collection, state)

        inp = {"input": jnp.ones((10, 32, 32, 3))}
        head = model.apply(state, inp)["head/out"]
        self.assertEqual(head.shape, (10, 10))

        outputs, updated_state = model.apply(state,
                                             inp,
                                             mutable=["batch_stats"])
        self.assertEqual(outputs["head/out"].shape, (10, 10))
        self.assertIn("batch_stats", updated_state)
Esempio n. 12
0
class ModelTest(test.TestCase):
    """Smoke tests for running CifarNet through `Model`."""

    def setUp(self):
        super().setUp()

        graph, constants, _ = cnn.CifarNet()
        self.cnn = Model(graph, constants)
        init_batch = jnp.ones((1, 32, 32, 3))
        self.cnn_state = self.cnn.init(random.PRNGKey(0), init_batch)

    def test_cnn_inference(self):
        """Applying the model to a bare array yields a (10, 10) result."""
        batch = jnp.ones((10, 32, 32, 3))
        out = self.cnn.apply(self.cnn_state, batch)
        self.assertEqual(out.shape, (10, 10))

    def test_cnn_inference_dict(self):
        """Applying the model to an input dict exposes the "fc/logits" output."""
        batch = {"input": jnp.ones((10, 32, 32, 3))}
        outputs = self.cnn.apply(self.cnn_state, batch)
        self.assertEqual(outputs["fc/logits"].shape, (10, 10))

    def test_cnn_params(self):
        """The CifarNet parameter count matches the known total."""
        unfrozen_state = flax.core.unfreeze(self.cnn_state)
        total = parameter_overview.count_parameters(unfrozen_state["params"])
        self.assertEqual(total, 2192458)
Esempio n. 13
0
    def loss_fn(params):
      """Computes the training loss for `params` merged with `other_params`.

      Returns:
        A pair `(loss, (logits, loss, new_coll))`, where `new_coll` holds the
        updated mutable collections returned by `Model.apply`. The aux tuple
        presumably serves a `has_aux=True` gradient transform; confirm at the
        call site.
      """
      model = Model(graph, constants)
      variables = flax.core.freeze({
          "params": {**params, **other_params},
          **coll
      })
      logits, new_coll = model.apply(
          variables,
          data,
          rngs={"dropout": rng_model_local},
          mutable=list(coll.keys()),
          deterministic=False,
          training=True)

      xent = bv_utils.softmax_xent(logits=logits, labels=labels)
      loss = jnp.mean(xent)
      return loss, (logits, loss, new_coll)
Esempio n. 14
0
    def test_multi_input_output(self):
        """Tests a subgraph substitution on a graph with multiple inputs / output ops.

        We use a ResNet model, which has skip connections. This test checks
        that the substitution produces the expected number of ops, and also
        that the newly produced graph is still executable.
        """
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        baseline = model.apply(state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(baseline.shape, (10, 10))

        conv_node = subgraph.SubgraphNode(
            op=new_op(op_name="subgraph/conv0",
                      op_type=OpType.CONV,
                      op_kwargs={
                          "features": 64,
                          "kernel_size": [1, 1]
                      },
                      input_names=["resnet11/skip/relu1"]))
        gelu_node = subgraph.SubgraphNode(
            op=new_op(op_name="subgraph/gelu1",
                      op_type=OpType.GELU,
                      input_names=["subgraph/conv0"]),
            output_names=["resnet_stride1_filtermul1_basic12/relu2"])
        rewired = subgraph.replace_subgraph(graph, [conv_node, gelu_node])

        # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
        self.assertLen(graph.ops, len(rewired.ops) + 1)

        rewired_model = Model(rewired, constants)
        rewired_state = rewired_model.init(random.PRNGKey(0),
                                           jnp.ones((1, 32, 32, 3)))
        out = rewired_model.apply(rewired_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(out.shape, (10, 10))
Esempio n. 15
0
  def eval_step(params, coll, data, labels, mask):
    """Evaluates one batch and returns cross-device sums.

    Returns:
      A tuple `(correct, loss, mask)`, each summed over the "batch" axis via
      `jax.lax.psum`. `correct` is the masked top-1 hit indicator per example.
    """
    # Presumably zeroes the mask for rows whose labels are all zero (e.g.
    # padding); verify against the data pipeline.
    mask *= labels.max(axis=1)

    model = Model(graph, constants)
    variables = flax.core.freeze({"params": params, **coll})
    logits = model.apply(variables, data, deterministic=True, training=False)
    loss = jnp.mean(bv_utils.softmax_xent(logits=logits, labels=labels))

    # Label value at the arg-max prediction of each row.
    predicted = jnp.argmax(logits, axis=1)
    hit = jnp.take_along_axis(labels, predicted[:, None], axis=1)[:, 0]
    correct = hit * mask

    return tuple(
        jax.lax.psum(value, axis_name="batch")
        for value in (correct, loss, mask))
Esempio n. 16
0
class SubgraphModel:
  """A concrete subgraph.

  A concrete subgraph consists of:
  - The full graph, in which the subgraph is *already* embedded, i.e., you
    should call replace_subgraph BEFORE creating a SubgraphModel!
  - An instantiation of the full graph, as specified by the state (i.e.,
    parameters). If None, the subgraph is treated as abstract and the methods
    that execute a model raise ValueError. An empty (but non-None) state is
    accepted as an instantiation.
  - An execution of the full graph, as specified by a set of inputs.
  - A specification of the subgraph, as defined by the list of subgraph nodes.
    If None (or empty), the subgraph is just the full graph.

  Three models are derived from the full graph:
  - `subg_inputs_model`: graph inputs -> subgraph inputs (None when the
    subgraph is the full graph, since the two coincide).
  - `subg_outputs_model`: graph inputs -> subgraph outputs.
  - `subg_model`: subgraph inputs -> subgraph outputs.
  """

  def __init__(self,
               graph,
               constants,
               state,
               inputs,
               subgraph = None):
    self.graph = graph
    self.constants = constants
    self.state = state
    self.inputs = inputs
    self.subgraph: SubgraphSpec = subgraph if subgraph else []

    self.input_names = None
    self.output_names = None
    self.original_outputs = graph.output_names

    if subgraph:
      self._subgraph_to_names()

      # graph for graph inputs -> subg inputs
      self.subg_inputs_graph = copy.deepcopy(graph)
      self.subg_inputs_graph.output_names = self.input_names
      self.subg_inputs_model = Model(self.subg_inputs_graph, self.constants)
      self.subg_inputs = None  # cache filled by get_default_subg_inputs

      # graph for graph inputs -> subg outputs
      self.subg_outputs_graph = copy.deepcopy(graph)
      self.subg_outputs_graph.output_names = self.output_names
      self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
      self.subg_outputs = None  # cache filled by get_default_subg_outputs

      # graph for subg inputs -> subg outputs
      subg_ops = [node.op for node in subgraph]
      self.subg_graph = new_graph(self.input_names, self.output_names, subg_ops)
      self.subg_model = Model(self.subg_graph, self.constants)
    else:
      self.input_names = [
          canonicalize_tensor_name(name) for name in graph.input_names
      ]
      self.output_names = [
          canonicalize_tensor_name(name) for name in graph.output_names
      ]

      # subg inputs = inputs to the graph
      self.subg_inputs_graph = None
      self.subg_inputs_model = None
      self.subg_inputs = inputs

      # graph for graph inputs -> subg outputs
      self.subg_outputs_graph = copy.deepcopy(graph)
      self.subg_outputs_model = Model(self.subg_outputs_graph, self.constants)
      self.subg_outputs = None

      # subg outputs = full graph outputs
      self.subg_graph = self.subg_outputs_graph
      self.subg_model = self.subg_outputs_model

  def _subgraph_to_names(self):
    """Populates the incoming and outgoing edges of the subgraph.

    Sets `self.input_names` to the tensors consumed by subgraph nodes that are
    not produced by an earlier subgraph node (in first-use order). Sets
    `self.output_names` to `"<op name>:<idx>"` for every output slot whose
    entry in the node's `output_names` is not None.
    """
    assert self.subgraph

    input_names = []
    output_names = []
    produced = set()
    for node in self.subgraph:
      # check to see which inputs are incoming edges to the subgraph
      for input_name in node.op.input_names:
        if input_name not in produced and input_name not in input_names:
          input_names.append(input_name)

      # keep track of produced tensors (internal edges in the subgraph)
      produced.update(
          f"{node.op.name}:{idx}" for idx in range(node.op.num_outputs))

      # only the rewired outputs become externally visible to the graph
      output_names.extend(
          f"{node.op.name}:{idx}"
          for idx, output_name in enumerate(node.output_names)
          if output_name is not None)

    self.input_names = input_names
    self.output_names = output_names

  def get_subg_inputs(
      self, graph_inputs,
      intermediates = False,
  ):
    """Returns the inputs to the subgraph given inputs to the full graph.

    Args:
      graph_inputs: The dictionary of input values to the full graph.
      intermediates: If True, the inputs model's `output_names` is set to []
        for this call (and restored afterwards, even on error); presumably the
        model then returns all intermediate values.

    Returns:
      The inputs to the subgraph.

    Raises:
      ValueError: If execution is necessary, but state is not provided.
    """
    # if no self.subg_inputs_model, then the subgraph is the full graph, so the
    # input to the subgraph is the same as the input to the full graph
    if self.subg_inputs_model is None:
      return graph_inputs

    # execute the subg_inputs_model
    if self.state is None:
      raise ValueError("Cannot execute subgraph without state.")
    if not intermediates:
      return self.subg_inputs_model.apply(self.state, graph_inputs)

    inputs_graph = self.subg_inputs_model.graph
    old_output_names = inputs_graph.output_names
    inputs_graph.output_names = []
    try:
      return self.subg_inputs_model.apply(self.state, graph_inputs)
    finally:
      # Never leave the shared graph with its output names cleared.
      inputs_graph.output_names = old_output_names

  def get_default_subg_inputs(self):
    """Returns (and caches) the subgraph inputs for the stored graph inputs."""
    if self.subg_inputs is not None:
      return self.subg_inputs
    self.subg_inputs = self.get_subg_inputs(self.inputs)
    return self.subg_inputs

  def get_subg_outputs(
      self, graph_inputs = None
  ):
    """Returns the output from the subgraph given inputs to the full graph.

    Args:
      graph_inputs: The dictionary of input values to the full graph. If None,
        defaults to the stored input values.

    Returns:
      The outputs of the subgraph.

    Raises:
      ValueError: If execution is necessary, but state is not provided.
    """
    if graph_inputs is None:
      graph_inputs = self.inputs
    # execute the subg_outputs_model
    if self.state is None:
      raise ValueError("Cannot execute subgraph without state.")
    return self.subg_outputs_model.apply(self.state, graph_inputs)

  def get_default_subg_outputs(self):
    """Returns (and caches) the subgraph outputs for the stored graph inputs."""
    if self.subg_outputs is not None:
      return self.subg_outputs
    subg_inputs = self.get_default_subg_inputs()
    self.subg_outputs = self.execute_subg(subg_inputs)
    return self.subg_outputs

  def execute_subg(
      self, inputs
  ):
    """Returns the output from the subgraph given inputs to the subgraph.

    Args:
      inputs: The dictionary of input values to the subgraph.

    Returns:
      The outputs of the subgraph.

    Raises:
      ValueError: If state is not provided.
    """
    if self.state is None:
      raise ValueError("Cannot execute subgraph without state.")
    return self.subg_model.apply(self.state, inputs)

  def update_subg_outputs(self, output_names):
    """Updates the outputs of the subgraph.

    The cached default outputs are discarded, because they were computed for
    the previous output names.

    Args:
      output_names: The list of new output_names.

    Raises:
      ValueError: If output_names are not produced in the subgraph. In that
        case nothing is modified.
    """
    output_names = list(output_names)
    produced = {
        f"{op.name}:{idx}"
        for op in self.subg_graph.ops
        for idx in range(op.num_outputs)
    }
    for output_name in output_names:
      if output_name not in produced:
        raise ValueError(f"Requested output {output_name} not in subgraph.")

    self.output_names = output_names
    self.subg_graph.output_names = output_names
    self.subg_model.graph.output_names = output_names
    self.subg_outputs_graph.output_names = output_names
    self.subg_outputs_model.graph.output_names = output_names
    # Invalidate outputs cached under the previous output names.
    self.subg_outputs = None
Esempio n. 17
0
class SubgraphTest(test.TestCase):
    """Tests for `subgraph.replace_subgraph` and `subgraph.inherit_params`."""

    def setUp(self):
        super().setUp()

        self.graph, self.constants, _ = cnn.CifarNet()
        self.model = Model(self.graph, self.constants)
        self.state = self.model.init(random.PRNGKey(0), jnp.ones(
            (1, 32, 32, 3)))

        # A 1x1 conv followed by a GELU, substituted into CifarNet at
        # conv_layer1; the GELU node's output_names is ["conv_layer1/relu"].
        self.subgraph = [
            subgraph.SubgraphNode(op=new_op(
                op_name="conv_layer1/conv/1",
                op_type=OpType.CONV,
                op_kwargs={
                    "features": 64,
                    "kernel_size": [1, 1]
                },
                input_names=["conv_layer0/avg_pool"]), ),
            subgraph.SubgraphNode(op=new_op(op_name="conv_layer1/gelu/1",
                                            op_type=OpType.GELU,
                                            input_names=["conv_layer1/conv/1"
                                                         ]),
                                  output_names=["conv_layer1/relu"])
        ]
        self.new_graph = subgraph.replace_subgraph(self.graph, self.subgraph)
        self.new_model = Model(self.new_graph, self.constants)
        self.new_state = self.new_model.init(random.PRNGKey(0),
                                             jnp.ones((1, 32, 32, 3)))

    def test_subgraph_inserted(self):
        """Tests whether subgraph nodes were inserted."""
        for node in self.subgraph:
            found = any(op.name == node.op.name for op in self.new_graph.ops)
            self.assertTrue(found, f"Did not find {node.op.name} in new graph")

    def test_subgraph_execution(self):
        """Tests whether new graph can be executed."""
        y = self.new_model.apply(self.new_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))

    def test_subgraph_pruning(self):
        """Tests whether new graph was pruned of old nodes."""
        new_params = flax.core.unfreeze(self.new_state)["params"]
        new_param_count = parameter_overview.count_parameters(new_params)

        params = flax.core.unfreeze(self.state)["params"]
        param_count = parameter_overview.count_parameters(params)
        self.assertLess(new_param_count, param_count)

    def test_weight_inheritance(self):
        """Tests weight inheritance."""
        old_params = flax.core.unfreeze(self.state)["params"]
        new_params = flax.core.unfreeze(self.new_state)["params"]

        frozen_params, trainable_params = subgraph.inherit_params(
            new_params, old_params)

        self.assertLen(new_params, len(trainable_params) + len(frozen_params))

        # Use assertIn rather than a bare assert, which is stripped under -O.
        for param in ["fc/dense", "fc/logits", "conv_layer0/conv"]:
            self.assertIn(param, frozen_params,
                          f"expected param {param} to be frozen")
        self.assertIn("conv_layer1/conv/1", trainable_params,
                      "expected param conv_layer1/conv/1 to be trainable")

    def test_multi_input_output(self):
        """Tests a subgraph substitution on a graph with multiple inputs / output ops.

        We use a ResNet model, which has skip connections. This test checks
        that the substitution produces the expected number of ops, and also
        that the newly produced graph is still executable.
        """
        graph, constants, _ = resnetv1.ResNet18(num_classes=10,
                                                input_resolution="small")
        model = Model(graph, constants)
        state = model.init(random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))
        y = model.apply(state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))

        subg = [
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/conv0",
                          op_type=OpType.CONV,
                          op_kwargs={
                              "features": 64,
                              "kernel_size": [1, 1]
                          },
                          input_names=["resnet11/skip/relu1"])),
            subgraph.SubgraphNode(
                op=new_op(op_name="subgraph/gelu1",
                          op_type=OpType.GELU,
                          input_names=["subgraph/conv0"]),
                output_names=["resnet_stride1_filtermul1_basic12/relu2"])
        ]
        # Named `rewired_graph` to avoid shadowing the `new_graph` helper.
        rewired_graph = subgraph.replace_subgraph(graph, subg)

        # the subgraph is 2 ops (conv / gelu) replacing 3 ops (conv / bn / relu)
        self.assertLen(graph.ops, len(rewired_graph.ops) + 1)

        rewired_model = Model(rewired_graph, constants)
        rewired_state = rewired_model.init(random.PRNGKey(0),
                                           jnp.ones((1, 32, 32, 3)))

        y = rewired_model.apply(rewired_state, jnp.ones((10, 32, 32, 3)))
        self.assertEqual(y.shape, (10, 10))
Esempio n. 18
0
class GraphSynthesizerTest(test.TestCase):
    """Exercises GraphSynthesizer on slices of CifarNet and ResNet18.

    Each test wraps a slice of ops in a SubgraphModel, infers abstract
    properties for it, synthesizes a replacement, and asserts that the
    synthesized graph's fingerprint differs from the original graph's.
    """

    def setUp(self):
        super().setUp()

        # Time-based seed so runs explore different programs; logged so a
        # failing run can be replayed.
        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        self._install_model(cnn.CifarNet())

    def _install_model(self, graph_bundle):
        """Stores a (graph, constants, _) bundle on the fixture and runs it once."""
        self.graph, self.constants, _ = graph_bundle
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        all_outputs = self.m.apply(self.state, self.input)
        self.out = all_outputs[self.graph.output_names[0]]
        self.max_size = int(10e8)

    def _subgraph_model(self, ops, output_names=None):
        """Wraps `ops` in a SubgraphModel over the fixture's graph.

        If `output_names` is omitted, the last op's first output ("<name>:0")
        is used as the subgraph output.
        """
        nodes = [subgraph.SubgraphNode(op=op) for op in ops]
        if output_names is None:
            output_names = [f"{nodes[-1].op.name}:0"]
        nodes[-1].output_names = output_names
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, nodes)

    def _cifar_slice(self, start, stop):
        """SubgraphModel over ops[start:stop], feeding the inputs of ops[stop]."""
        all_ops = self.graph.ops
        return self._subgraph_model(all_ops[start:stop],
                                    all_ops[stop].input_names)

    def _shape_of(self, subgraph_model):
        return shape.ShapeProperty().infer(subgraph_model,
                                           max_size=self.max_size)

    def _synthesize(self, subg, props):
        make_sequential = functools.partial(PSS,
                                            generation=0,
                                            max_len=3,
                                            filter_progress=True)
        synthesizer = GraphSynthesizer([(subg, props)],
                                       sequential_ctr=make_sequential,
                                       generation=0)
        synthesized = synthesizer.synthesize()[0]

        logging.info("synthesized...")
        for node in synthesized.subgraph:
            logging.info("%s", node.op.name)
            logging.info("%s", node.output_names)
        logging.info("")

        original_print = fingerprint_graph(self.graph, self.constants,
                                           self.input)
        synthesized_print = fingerprint_graph(synthesized.graph,
                                              self.constants, self.input)
        self.assertNotEqual(original_print, synthesized_print)

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)]."""
        model = self._cifar_slice(4, 5)
        properties = [
            self._shape_of(model),
            depth.DepthProperty().infer(model),
        ]
        self._synthesize(model, properties)

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)] with a fixed seed and a linop property."""
        py_rand.seed(10)
        model = self._cifar_slice(4, 5)
        properties = [
            self._shape_of(model),
            depth.DepthProperty().infer(model),
            linear.LinopProperty().infer(model),
        ]
        self._synthesize(model, properties)

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU]."""
        model = self._cifar_slice(4, 6)
        properties = [
            self._shape_of(model),
            depth.DepthProperty().infer(model),
        ]
        self._synthesize(model, properties)

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)]."""
        model = self._cifar_slice(4, 7)
        properties = [
            self._shape_of(model),
            linear.LinopProperty().infer(model),
        ]
        self._synthesize(model, properties)

    def test_synthesizer_resnet_small(self):
        """Replacing ResNet18 op 4 together with ops 9-10."""
        self._install_model(resnetv1.ResNet18(num_classes=10,
                                              input_resolution="small"))
        all_ops = self.graph.ops
        model = self._subgraph_model([all_ops[4]] + all_ops[9:11])
        properties = [
            self._shape_of(model),
            linear.LinopProperty().infer(model),
        ]
        self._synthesize(model, properties)

    def test_synthesizer_resnet_big(self):
        """Replacing ResNet18 ops 3-4 together with ops 8-11."""
        self._install_model(resnetv1.ResNet18(num_classes=10,
                                              input_resolution="small"))
        all_ops = self.graph.ops
        model = self._subgraph_model(all_ops[3:5] + all_ops[8:12])
        properties = [
            self._shape_of(model),
            linear.LinopProperty().infer(model),
        ]
        self._synthesize(model, properties)
# Esempio n. 19 (scraped snippet separator: "Example no. 19")
# 0
  def mutate(self, parent_graph, parent_constants,
             parent_blocks,
             model_fn, child_id):
    """Mutates the parent architecture.

    When `mutate_by_block` is set, a uniform draw p in [0, 1) selects the
    mutation:
      * p < block_add_prob: duplicate a random block;
      * p < block_add_prob + block_delete_prob (and more than one block exists):
        delete a random block;
      * otherwise: rewrite a random block via the mutator / synthesizer.
    When `mutate_by_block` is not set, the whole graph is wrapped in a single
    block and always goes through the rewrite path.

    Args:
      parent_graph: Graph of the parent architecture.
      parent_constants: Constants of the parent architecture.
      parent_blocks: Blocks of the parent. Replaced by a single block wrapping
        the whole graph when `mutate_by_block` is not set.
      model_fn: Callable taking `blocks=` and returning
        `(graph, constants, blocks)`.
      child_id: Identifier of the child, used in block names and log messages.

    Returns:
      `(new_graph, new_constants, new_blocks, subgraph_model)`. All four are
      None if no valid child could be produced. `subgraph_model` is only set
      on the rewrite path.
    """

    # Mutate entire graph.
    if not self.mutate_by_block:
      parent_block = Block(f"parent{child_id}", parent_graph,
                           parent_constants)
      parent_blocks = [parent_block]

    # Mutate a block.
    new_graph = None
    new_constants = None
    new_blocks = None
    subgraph_model = None

    if self.mutate_by_block:
      mutation_type_p = random.random()
      logging.info("Mutation type p %.3f", mutation_type_p)
    else:
      # We mutate the entire graph, so no sense in adding / deleting blocks.
      # The value only has to exceed block_add_prob + block_delete_prob so
      # that the rewrite branch below is taken.
      mutation_type_p = 10.

    if mutation_type_p < self.block_add_prob:
      block_id = random.randint(0, len(parent_blocks) - 1)
      logging.info("Duplicate block_id: %d.", block_id)
      new_blocks = copy.deepcopy(parent_blocks)
      new_blocks = (
          new_blocks[:block_id] + [new_blocks[block_id]] +
          new_blocks[block_id:])
      new_graph, new_constants, new_blocks = model_fn(blocks=new_blocks)
    elif (mutation_type_p < (self.block_delete_prob + self.block_add_prob)
          and len(parent_blocks) > 1):
      # Deletion is only allowed while at least one block would remain.
      block_id = random.randint(0, len(parent_blocks) - 1)
      logging.info("Delete block_id: %d.", block_id)
      new_blocks = copy.deepcopy(parent_blocks)
      new_blocks = new_blocks[:block_id] + new_blocks[block_id + 1:]
      new_graph, new_constants, new_blocks = model_fn(blocks=new_blocks)
    else:
      block_id = random.randint(0, len(parent_blocks) - 1)
      logging.info("Mutate block_id: %d.", block_id)
      parent_fingerprint = fingerprint_graph(parent_graph, parent_constants,
                                             self.inp)

      # Get the block inputs for mutations: temporarily expose every block
      # input as a graph output so a single forward pass yields them all.
      block_input_names = [block.graph.input_names for block in parent_blocks]
      output_names = parent_graph.output_names
      parent_graph.output_names = [
          bi for bis in block_input_names for bi in bis  # pylint: disable=g-complex-comprehension
      ]
      try:
        model = Model(parent_graph, parent_constants)
        output, _ = model.init_with_output(jax.random.PRNGKey(0), self.inp)
      finally:
        # Always restore the caller's graph, even if execution fails.
        parent_graph.output_names = output_names
      block_inputs = [
          {bi: output[bi] for bi in bis} for bis in block_input_names
      ]

      # Re-key each block's inputs by the names used in its base graph.
      block_inps = []
      for idx, block in enumerate(parent_blocks):
        block_inps_idx = {}
        for cur_key, old_key in zip(block.graph.input_names,
                                    block.base_graph.input_names):
          block_inps_idx[old_key] = block_inputs[idx][cur_key]
        block_inps.append(block_inps_idx)

      # Every block sharing the mutated block's name contributes a context.
      block_to_mutate = parent_blocks[block_id]
      blocks = [(block, block_inp)
                for (block, block_inp) in zip(parent_blocks, block_inps)
                if block.name == block_to_mutate.name]

      properties = []
      if "shape_property" in self.properties:
        properties.append(ShapeProperty(**self.properties.shape_property))
      if "depth_property" in self.properties:
        properties.append(DepthProperty(**self.properties.depth_property))
      if "linear_property" in self.properties:
        properties.append(LinopProperty(**self.properties.linear_property))
      mutator = self.mutator(properties)

      for attempt_idx in range(self.synthesis_retries):
        logging.info("Begin mutation attempt %d.", attempt_idx)

        contexts = []
        for block, block_inp in blocks:
          contexts.append((block.base_constants, None, block_inp))

        subg_and_props = mutator.mutate(
            block_to_mutate.base_graph, contexts, abstract=True)
        synthesizer = self.synthesizer(child_id, subg_and_props)

        try:
          subgraph_model = synthesizer.synthesize()[0]
        except (StopIteration, ValueError):
          logging.info("Mutation attempt %d failed: max_len reached.",
                       attempt_idx)
          new_graph = None
          continue

        logging.info("Synthesized:")
        for node in subgraph_model.subgraph:
          logging.info("  %s", node.op.name)
        logging.info("============")

        if self.mutate_by_block:
          new_blocks = self.mutate_block(block_to_mutate, parent_blocks,
                                         subgraph_model.subgraph, child_id)

          new_graph, new_constants, new_blocks = model_fn(blocks=new_blocks)
        else:
          new_graph = subgraph_model.graph
          new_constants = subgraph_model.constants
          new_block = Block(f"model{child_id}", new_graph, new_constants)
          new_blocks = [new_block]

        # Deliberate best-effort: any failure to execute the child just
        # consumes this attempt.
        try:
          child_fingerprint = fingerprint_graph(copy.deepcopy(new_graph),
                                                new_constants,
                                                self.inp)
        except Exception as e:  # pylint: disable=broad-except
          logging.info("Mutation attempt %d failed: model fails to execute with "
                       "error (%s)", attempt_idx, e)
          new_graph = None
          continue

        if child_fingerprint == parent_fingerprint:
          logging.info("Mutation attempt %d failed: child fingerprint "
                       "identical to parent", attempt_idx)
          new_graph = None
          continue
        else:
          logging.info("Mutation attempt %d succeeded.", attempt_idx)
          break

    if new_graph is None:
      new_constants = None
      new_blocks = None
      subgraph_model = None
      logging.info("Attempt to create child %d failed.", child_id)

    return new_graph, new_constants, new_blocks, subgraph_model
# Esempio n. 20 (scraped snippet separator: "Example no. 20")
# 0
def train_and_eval(
    config,
    eval_perf = True,
    callback = None
):
  """Training loop + eval.

  Builds the model from `config.graph` (optionally with `config.subgraph`
  substituted in), optionally inherits weights from `config.init_dir`, trains
  with periodic evaluation, and writes checkpoints to `config.output_dir` when
  it is set. If a checkpoint already exists there, training resumes from it.

  Args:
    config: The training config.
    eval_perf: Whether to collect performance metrics (parameter count, FLOPs
      and inference throughput).
    callback: A callback which accepts the current epoch number and performance
      metrics, and returns True to continue training or False to early stop.

  Returns:
    A tuple of the final metrics, number of epochs trained, and the state dict.
    The state dict is the EMA state if EMA produced one, otherwise
    {"coll": ..., "params": ...}.

  Raises:
    ValueError: If a non-host process does not see the output directory in
      time, if the batch size is not divisible by the device count, or if
      weight inheritance is requested without `config.init_dir`.
    NotImplementedError: For weight inheritance with frozen inherited weights
      and `train_subg_outputs` set.
  """
  if "train" in config.config_dict:
    train_config = config.config_dict.train
  else:
    train_config = config.config_dict
  rng = jax.random.PRNGKey(train_config.seed)

  is_host = jax.process_index() == 0

  if is_host:
    # The pool is used to perform operations such as checkpointing in async way.
    pool = multiprocessing.pool.ThreadPool(2)
  else:
    pool = None
  # Handle of the most recently queued async checkpoint write, if any.
  checkpoint_writer = None

  # set up output directory
  if config.output_dir is not None:
    if is_host:
      if gfile.exists(config.output_dir):
        logging.warning("Output directory %s already exists.",
                        config.output_dir)
      else:
        gfile.makedirs(config.output_dir)
      utils.write_to_store(config, f"{config.output_dir}/config")
    else:
      # Non-host processes wait for the host to create the directory.
      ready = False
      for _ in range(GFILE_TRIES):
        ready = gfile.exists(config.output_dir)
        if ready: break
        time.sleep(GFILE_SLEEP_SEC)
      if not ready:
        raise ValueError(f"Output directory {config.output_dir} was not "
                         f"created within {GFILE_SLEEP_SEC * GFILE_TRIES} "
                         "secs.")

  # get data
  num_devices = jax.device_count()
  batch_size = train_config.device_batch_size * num_devices
  if batch_size % num_devices != 0:
    raise ValueError(f"JAX num_devices {num_devices} does not divide "
                     f"batch_size {batch_size}.")
  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = local_batch_size * 8

  if is_host:
    logging.info(
        "Global batch size %d on %d hosts results in %d local batch size. "
        "With %d dev per host (%d dev total), that's a %d per-device batch "
        "size.",
        batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

  train_pp = preprocess.get_preprocess_fn(
      train_config.dataset_name, train_config.dataset.train_split,
      **train_config.dataset.get("preprocess_kwargs", {}))
  train_ds = input_pipeline.make_for_train(
      dataset=train_config.dataset_name,
      split=train_config.dataset.train_split,
      preprocess_fn=train_pp,
      batch_size=local_batch_size,
      shuffle_buffer_size=250_000,
      prefetch=2,
      cache_raw=False)

  train_iter = input_pipeline.start_input_pipeline(
      train_ds, n_prefetch=1)

  ntrain_img = input_pipeline.get_num_examples(train_config.dataset_name,
                                               train_config.dataset.train_split)
  steps_per_epoch = ntrain_img / batch_size
  total_steps = int(steps_per_epoch * train_config.epochs)

  eval_pp = preprocess.get_preprocess_fn(train_config.dataset_name,
                                         train_config.dataset.val_split)
  eval_ds, eval_steps = input_pipeline.make_for_inference(
      dataset=train_config.dataset_name,
      split=train_config.dataset.val_split,
      preprocess_fn=eval_pp,
      batch_size=local_batch_size_eval,
      cache_final=True,
      cache_raw=False,
      data_dir=None)
  eval_it = input_pipeline.start_input_pipeline(eval_ds, n_prefetch=1)

  # set up model
  graph = config.graph
  if isinstance(graph, tuple):
    graph, constants = graph[0], graph[1]
  else:
    constants = None
  if config.subgraph is not None:
    graph = replace_subgraph(graph, config.subgraph)
    if (config.inherit_weights and
        config.freeze_inherited and
        config.train_subg_outputs):
      # TODO(charlesjin) finish training with weight inheritance
      output_names = sum([node.output_names for node in config.subgraph], [])
      graph.output_names = output_names
      raise NotImplementedError
  model = Model(graph, constants)

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((1,) + image_size, jnp.float32)
    return flax.core.unfreeze(model.init(rng, dummy_input))

  rng, rng_init = jax.random.split(rng)
  state_cpu = init(rng_init)
  params_cpu = state_cpu["params"]
  if "batch_stats" in state_cpu:
    # Non-param variable collections. Currently we only support the additional
    # collection batch_stats, which is Flax's convention for batchnorm.
    coll_cpu = {"batch_stats": state_cpu["batch_stats"]}
  else:
    coll_cpu = {}

  # weight inheritance
  if config.inherit_weights:
    if config.init_dir is None:
      raise ValueError("Cannot inherit weights without parent directory.")

    parent_state = bv_utils.load_checkpoint(None, f"{config.init_dir}/state")
    parent_params = parent_state["params"]
    old_params, new_params = inherit_params(params_cpu, parent_params)

    if config.freeze_inherited:
      trainable_params = new_params
      frozen_params = old_params
    else:
      trainable_params = {**old_params, **new_params}
      frozen_params = {}
  else:
    trainable_params = params_cpu
    frozen_params = {}

  if is_host:
    if trainable_params:
      logging.info("trainable params:")
      for key in trainable_params.keys():
        logging.info("  %s", key)
    else:
      logging.warning("WARNING: no trainable params!")
    if frozen_params:
      logging.info("frozen params:")
      for key in frozen_params.keys():
        logging.info("  %s", key)

  # training step
  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1, 3,))
  def train_step(opt, params, other_params, coll, data, labels, rng):
    """Trains for a single step; only `params` receive gradient updates."""
    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))
    def loss_fn(params):
      all_params = {**params, **other_params}
      logits, new_coll = Model(graph, constants).apply(
          flax.core.freeze({
              "params": all_params,
              **coll
          }),
          data,
          rngs={"dropout": rng_model_local},
          mutable=list(coll.keys()),
          deterministic=False,
          training=True)

      loss = jnp.mean(
          bv_utils.softmax_xent(
              logits=logits, labels=labels))
      return loss, (logits, loss, new_coll)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # value_and_grad(has_aux=True) returns ((loss, aux), grads).
    aux, grads = grad_fn(params)
    _, loss, new_coll = aux[1]
    grads = jax.lax.pmean(grads, axis_name="batch")

    # `tx` is bound below, before train_step is first called.
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    return opt, params, new_coll, jax.lax.psum(loss, axis_name="batch"), rng

  cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="batch"),
                                axis_name="batch")

  # eval step
  @partial(jax.pmap, axis_name="batch")
  def eval_step(params, coll, data, labels, mask):
    mask *= labels.max(axis=1)
    logits = Model(graph, constants).apply(
        flax.core.freeze({
            "params": params,
            **coll
        }),
        data,
        deterministic=True,
        training=False)
    loss = jnp.mean(
        bv_utils.softmax_xent(
            logits=logits, labels=labels))

    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    correct = top1_correct * mask
    return (jax.lax.psum(correct, axis_name="batch"),
            jax.lax.psum(loss, axis_name="batch"),
            jax.lax.psum(mask, axis_name="batch"))

  def eval_model(params, coll, eval_it):
    total_correct = 0
    total_loss = 0
    total = 0
    eval_time = 0
    eval_start = time.time()
    for _, batch in zip(range(eval_steps), eval_it):
      correct, loss, neval = eval_step(params, coll, batch["image"],
                                       batch["labels"], batch["_mask"])
      total_correct += jnp.sum(correct[0])
      total_loss += jnp.sum(loss[0])
      total += jnp.sum(neval[0])
    # Wait for async dispatch to finish so the timing is meaningful.
    if total: total.block_until_ready()
    eval_time += time.time() - eval_start
    return total_correct, total_loss, total, eval_time

  if eval_perf and is_host:
    num_params = perf_tools.compute_num_params(params_cpu)
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((1,) + image_size, jnp.float32)
    apply_fn = lambda v, inp: model.apply(  # pylint: disable=g-long-lambda
        v, inp, deterministic=True, training=False)
    flops = perf_tools.compute_num_flops(
        apply_fn,
        True,  # optimize
        flax.core.freeze({
            "params": params_cpu,
            **coll_cpu
        }), dummy_input)
    print(f"num_params: {num_params} | flops: {flops}")
  else:
    num_params = 0
    flops = 0

  im_sec_core_eval_measurements = np.array([])
  im_sec_core_train_measurements = np.array([])
  last_step = 0
  checkpoint_extra = dict(
      im_sec_core_eval_measurements=im_sec_core_eval_measurements,
      im_sec_core_train_measurements=im_sec_core_train_measurements,
      step=last_step)

  if config.output_dir is not None:
    checkpoint_path = f"{config.output_dir}/checkpoint.npz"
  else:
    checkpoint_path = None

  if trainable_params:
    # The optimizer only ever sees `trainable_params` (see train_step), so it
    # must be built over that same tree.
    tx, _ = bv_optax.make(train_config.optim, trainable_params, sched_kw=dict(
        global_batch_size=batch_size,
        total_steps=total_steps,
        steps_per_epoch=steps_per_epoch))
    opt_cpu = jax.jit(tx.init, backend="cpu")(trainable_params)

    # EMA
    ema_decay = train_config.get("ema_decay", 0)
    if ema_decay:
      end_warmup_step = train_config.get("ema_warmup_steps", 1560)
      ema_state_cpu = {"params": params_cpu, "coll": coll_cpu}
      ema_manager = train_utils.ExponentialMovingAverage(ema_state_cpu,
                                                         ema_decay,
                                                         end_warmup_step)

      @partial(jax.pmap, axis_name="batch")
      def update_ema(step, params, collection, ema):
        ema_state = {"params": params, "coll": collection}
        return ema.update_moving_average(ema_state, step)

    else:
      update_ema = ema_manager = ema_state_cpu = None

    # Load checkpoint if already exists
    if checkpoint_path and gfile.exists(checkpoint_path):
      checkpoint = {
          "opt": opt_cpu,
          "coll": coll_cpu,
          "params": params_cpu,
          "ema_state": ema_state_cpu,
          "extra": checkpoint_extra
      }
      checkpoint_tree = jax.tree_structure(checkpoint)
      loaded = bv_utils.load_checkpoint(checkpoint_tree, checkpoint_path)
      # bfloat16 type gets lost when data is saved to disk, so we recover it.
      checkpoint = jax.tree_map(bv_utils.recover_dtype, loaded)
      opt_cpu, coll_cpu, params_cpu, ema_state_cpu, checkpoint_extra = (
          checkpoint["opt"], checkpoint["coll"], checkpoint["params"],
          checkpoint["ema_state"], checkpoint["extra"])
      im_sec_core_eval_measurements = checkpoint_extra[
          "im_sec_core_eval_measurements"]
      im_sec_core_train_measurements = checkpoint_extra[
          "im_sec_core_train_measurements"]
      last_step = checkpoint_extra["step"]
      if ema_manager and ema_state_cpu:
        ema_manager = ema_manager.replace(state=ema_state_cpu)
      # Resume from the restored weights rather than the fresh initialization.
      trainable_params = {k: params_cpu[k] for k in trainable_params}

      logging.info("Loaded checkpoint at step %d (%d total).", last_step,
                   total_steps)
  else:
    opt_cpu = None
    update_ema = ema_manager = None

  do_last_eval = True
  eval_is_compiled = False
  # Without trainable params there is no optimizer state to read a count from.
  last_step = bv_optax.get_count(opt_cpu) if opt_cpu is not None else 0
  if trainable_params and last_step < total_steps:
    trainable_params_repl = flax_utils.replicate(trainable_params)
    opt_repl = flax_utils.replicate(opt_cpu)
    coll_repl = flax_utils.replicate(coll_cpu)
    rng, rng_loop = jax.random.split(rng, 2)
    rngs_loop = flax_utils.replicate(rng_loop)
    frozen_repl = flax_utils.replicate(frozen_params)

    if ema_manager:
      ema_manager_repl = flax_utils.replicate(ema_manager)
    else:
      ema_manager_repl = None

    def ema_repl_to_state_cpu(ema_manager_repl):
      """Copies device 0's EMA state to host memory (None if no EMA)."""
      if ema_manager_repl is None:
        return None
      ema_trainable_params_repl = ema_manager_repl.state["params"]
      ema_trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                              ema_trainable_params_repl)
      ema_coll_repl = ema_manager_repl.state["coll"]
      ema_coll_cpu = jax.tree_map(lambda x: np.array(x[0]), ema_coll_repl)
      ema_state_cpu = {"params": ema_trainable_params_cpu,
                       "coll": ema_coll_cpu}
      return ema_state_cpu

    write_checkpoints = (
        is_host and checkpoint_path is not None and config.checkpoint_steps)

    step = last_step
    epoch = int(last_step / steps_per_epoch) + 1
    loss = 0
    train_time = 0

    if is_host:
      logging.info(
          "Training on dataset %s for %d total epochs (starting from %d).",
          train_config.dataset_name, train_config.epochs, epoch)

    for step, train_batch in zip(
        range(last_step + 1, total_steps + 1), train_iter):
      step_start = time.time()
      do_last_eval = True
      (opt_repl, trainable_params_repl, coll_repl, loss_repl,
       rngs_loop) = train_step(opt_repl, trainable_params_repl, frozen_repl,
                               coll_repl, train_batch["image"],
                               train_batch["labels"], rngs_loop)

      if update_ema is not None:
        step_repl = flax_utils.replicate(step)
        ema_manager_repl = update_ema(step_repl, trainable_params_repl,
                                      coll_repl, ema_manager_repl)

      loss += loss_repl[0]
      # Count only the training work of this step. Eval/logging at an epoch
      # boundary must not leak into the next epoch's training throughput.
      train_time += time.time() - step_start
      if step > steps_per_epoch * epoch or step == total_steps:
        line = (f"epoch {epoch:d}"
                f" | train loss {loss / steps_per_epoch:.1f}")
        if coll_cpu:
          coll_repl = cross_replica_mean(coll_repl)
        if epoch > 1:
          train_im = steps_per_epoch * batch_size
          im_sec_core_train = train_im / num_devices / train_time
          im_sec_core_train_measurements = np.append(
              im_sec_core_train_measurements, im_sec_core_train)

        if epoch % train_config.log_epochs == 0:
          if ema_manager_repl is not None:
            trainable_params_repl_eval = ema_manager_repl.state["params"]
            coll_repl_eval = ema_manager_repl.state["coll"]
          else:
            trainable_params_repl_eval = trainable_params_repl
            coll_repl_eval = coll_repl
          params_repl = {**trainable_params_repl_eval, **frozen_repl}
          correct, loss, n_eval, eval_time = eval_model(params_repl,
                                                        coll_repl_eval,
                                                        eval_it)
          # The first eval includes compilation, so it is not timed.
          if eval_is_compiled:
            eval_im = int(n_eval)
            im_sec_core_eval = eval_im / num_devices / eval_time
            im_sec_core_eval_measurements = np.append(
                im_sec_core_eval_measurements, im_sec_core_eval)
          eval_is_compiled = True
          line += (f" | val loss {loss:.2f}"
                   f" | val acc {correct / n_eval * 100:.3f}%"
                   f" ({int(correct)} / {int(n_eval)})")
          do_last_eval = False
          if step < total_steps and callback:
            metrics = Metrics(
                loss=loss,
                acc=correct / n_eval,
                num_params=num_params,
                flops=flops,
                im_sec_core_infer=(np.median(im_sec_core_eval_measurements) if
                                   len(im_sec_core_eval_measurements) else 0),
                im_sec_core_train=(np.median(im_sec_core_train_measurements) if
                                   len(im_sec_core_train_measurements) else 0))
            if not callback(epoch, metrics):
              line += " | EARLY STOPPED"
              if is_host:
                logging.info(line)
              break
        if is_host:
          logging.info(line)
          logging.info("Train measurements stddev: %.2f",
                       np.std(im_sec_core_train_measurements))
          logging.info("Eval measurements stddev: %.2f",
                       np.std(im_sec_core_eval_measurements))
        loss = 0
        epoch += 1
        train_time = 0
      if write_checkpoints and pool and step % config.checkpoint_steps == 0:
        assert pool is not None
        bv_utils.checkpointing_timeout(checkpoint_writer, 10)
        checkpoint_extra[
            "im_sec_core_eval_measurements"] = im_sec_core_eval_measurements
        checkpoint_extra[
            "im_sec_core_train_measurements"] = im_sec_core_train_measurements
        checkpoint_extra["step"] = step
        # We need to transfer the weights over now or else we risk keeping them
        # alive while they'll be updated in a future step, creating hard to
        # debug memory errors (see b/160593526). Also, takes device 0's params
        # only.
        opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
        coll_cpu = jax.tree_map(lambda x: np.array(x[0]), coll_repl)
        trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                            trainable_params_repl)
        params_cpu = {**trainable_params_cpu, **frozen_params}
        ema_state_cpu = ema_repl_to_state_cpu(ema_manager_repl)

        # Checkpoint should be a nested dictionary or FLAX datataclasses from
        # `flax.struct`. Both can be present in a checkpoint.
        checkpoint = {
            "opt": opt_cpu,
            "coll": coll_cpu,
            "params": params_cpu,
            "ema_state": ema_state_cpu,
            "extra": checkpoint_extra
        }
        checkpoint_writer = pool.apply_async(bv_utils.save_checkpoint,
                                             (checkpoint, checkpoint_path))
    coll_cpu = jax.tree_map(lambda x: np.array(x[0]), coll_repl)
    opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
    trainable_params_cpu = jax.tree_map(lambda x: np.array(x[0]),
                                        trainable_params_repl)
    params_cpu = {**trainable_params_cpu, **frozen_params}
    params_repl = {**trainable_params_repl, **frozen_repl}
    ema_state_cpu = ema_repl_to_state_cpu(ema_manager_repl)
    if ema_manager:
      coll_repl_eval = ema_manager_repl.state["coll"]
      params_repl_eval = {**ema_manager_repl.state["params"], **frozen_repl}
    else:
      coll_repl_eval = coll_repl
      params_repl_eval = params_repl
  else:
    # No training to do: either nothing is trainable or a finished checkpoint
    # was restored.
    epoch = 0
    step = last_step
    coll_repl = flax_utils.replicate(coll_cpu)
    # Keep the trained (possibly restored) weights alongside the frozen ones.
    params_cpu = {**trainable_params, **frozen_params}
    params_repl = flax_utils.replicate(params_cpu)
    ema_state_cpu = None
    coll_repl_eval = coll_repl
    params_repl_eval = params_repl

  if do_last_eval:
    correct, loss, n_eval, eval_time = eval_model(params_repl_eval,
                                                  coll_repl_eval,
                                                  eval_it)
    if eval_is_compiled:
      eval_im = int(n_eval)
      im_sec_core_eval = eval_im / num_devices / eval_time
      im_sec_core_eval_measurements = np.append(im_sec_core_eval_measurements,
                                                im_sec_core_eval)
    eval_is_compiled = True

  # Guarantee at least one timed (post-compilation) inference measurement.
  if eval_perf and not len(im_sec_core_eval_measurements):  # pylint: disable=g-explicit-length-test (can't check len on numpy arrays)
    assert eval_is_compiled
    correct, loss, n_eval, eval_time = eval_model(params_repl_eval,
                                                  coll_repl_eval,
                                                  eval_it)
    eval_im = int(n_eval)
    im_sec_core_eval = eval_im / num_devices / eval_time
    im_sec_core_eval_measurements = np.append(im_sec_core_eval_measurements,
                                              im_sec_core_eval)

  checkpoint_extra[
      "im_sec_core_eval_measurements"] = im_sec_core_eval_measurements
  checkpoint_extra[
      "im_sec_core_train_measurements"] = im_sec_core_train_measurements
  checkpoint_extra["step"] = step
  checkpoint = {
      "opt": opt_cpu,
      "coll": coll_cpu,
      "params": params_cpu,
      "ema_state": ema_state_cpu,
      "extra": checkpoint_extra
  }
  if checkpoint_path is not None and is_host and pool:
    # Let any in-flight periodic checkpoint finish so two writers never race
    # on the same file.
    if checkpoint_writer is not None:
      checkpoint_writer.wait()
    checkpoint_writer = pool.apply_async(bv_utils.save_checkpoint,
                                         (checkpoint, checkpoint_path))

  if pool is not None:
    # Flush queued checkpoint writes before returning and release the threads.
    pool.close()
    pool.join()

  metrics = Metrics(
      loss=loss,
      acc=correct / n_eval,
      num_params=num_params,
      flops=flops,
      im_sec_core_infer=(np.median(im_sec_core_eval_measurements)
                         if len(im_sec_core_eval_measurements) else 0),
      im_sec_core_train=(np.median(im_sec_core_train_measurements)
                         if len(im_sec_core_train_measurements) else 0))

  if ema_state_cpu:
    state = ema_state_cpu
  else:
    state = {"coll": coll_cpu, "params": params_cpu}
  return metrics, epoch, state
class EnumSequentialTest(test.TestCase):
    """Tests for sequence_generator and EnumerativeSequentialSynthesizer."""

    def setUp(self):
        super().setUp()

        # Reference CifarNet model and its logits on an all-ones batch; the
        # synthesis tests compare rewritten models' logits against self.out.
        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        self.out = self.m.apply(self.state, self.input)["fc/logits"]
        self.max_size = int(10e8)

        # Set to True to run test_synthesizer_hard (skipped otherwise).
        self.hard = False

    def _subgraph_model(self, start, end):
        """Returns a SubgraphModel over ``self.graph.ops[start:end]``.

        The output names of the last subgraph node are set to the input names
        of ``self.graph.ops[end]``.
        """
        subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[start:end]]
        subg[-1].output_names = self.graph.ops[end].input_names
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, subg)

    def test_sequence_generator(self):
        """Test the sequence_generator function for correctness.

        sequence_generator should generate
          [[0], [1], [2],
           [0,0], [0,1], [0,2], [1,0], [1,1], [1,2], ...,
           [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ...,]
        """
        def el_generator():
            yield from range(3)

        seqs = list(sequence_generator(el_generator, 3))
        self.assertLen(seqs, 3 + 3**2 + 3**3)
        # Offsets 3 and 12 count the shorter sequences preceding each length
        # group; both are multiples of 3, so i % 3 is always the last digit.
        for i, seq in enumerate(seqs):
            if i < 3:
                self.assertEqual(seq, [i])
            elif i < 3 + 3**2:
                self.assertEqual(seq, [(i - 3) // 3, i % 3])
            else:
                self.assertEqual(seq, [(i - 12) // 3 // 3,
                                       (i - 12) // 3 % 3, i % 3])

    def test_kwargs_for_op_to_product(self):
        """kwargs_for_op_to_product yields the full cartesian product."""
        op_kwargs = {"a": [1, 2, 3], "b": [1, 2], "c": [5, 6]}
        input_kwargs = {"d": [1], "e": [1, 2], "f": [3, 4]}
        product = EnumerativeSequentialSynthesizer.kwargs_for_op_to_product(
            op_kwargs, input_kwargs)
        expected_length = 1
        for values in op_kwargs.values():
            expected_length *= len(values)
        for values in input_kwargs.values():
            expected_length *= len(values)
        self.assertLen(product, expected_length)

        op_setting = {"a": 2, "b": 2, "c": 5}
        input_setting = {"d": 1, "e": 2, "f": 3}
        self.assertIn((op_setting, input_setting), product)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement for ``subg`` and checks the logits change.

        Args:
          subg: the SubgraphModel to replace.
          props: inferred properties passed to the synthesizer with ``subg``.
        """
        synthesizer = EnumerativeSequentialSynthesizer([(subg, props)],
                                                       0,
                                                       max_len=3)

        subg = synthesizer.synthesize()[0]

        m = Model(subg.graph, self.constants)
        state = m.init(random.PRNGKey(0), self.input)
        out = m.apply(state, self.input)["fc/logits"]
        self.assertTrue((out != self.out).any())

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)].

        Because we do not test linear, this is replaced by dense3x3(features = 64)
        due to the enumeration order.
        """
        subgraph_model = self._subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)].

        Because we test all three props, this is replaced by conv3x3(features = 64)
        (i.e., an identical op) due to the enumeration order.
        """
        subgraph_model = self._subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU].

        Because we do not check for the linear property, [dense(features = 64),
        ReLU] works as well (which is what is synthesized due to the enumeration
        order).
        """
        subgraph_model = self._subgraph_model(4, 6)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        # The linear property is deliberately not checked here (see docstring).
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Because we do not check for the depth property, [dense(features = 64),
        avgpool2x2(strides=2x2)] works as well (which is what is synthesized due to
        the enumeration order).
        """
        subgraph_model = self._subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, lp])

    def test_synthesizer_hard(self):
        """Replacing [conv3x3, ReLU, avgpool2x2] while checking all three props.

        Disabled unless ``self.hard`` is True.
        """
        if not self.hard:
            # Report as skipped rather than silently passing.
            self.skipTest("Hard synthesis test disabled (self.hard is False).")
        subgraph_model = self._subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])
class ProgSequentialTest(test.TestCase):
    """Tests for ProgressiveSequentialSynthesizer on CifarNet subgraphs."""

    def setUp(self):
        super().setUp()

        # The seed is time-based; it is logged so a failing run can be
        # reproduced.
        seed = int(time.time())
        logging.info("Seed: %d", seed)
        py_rand.seed(seed)

        # Reference CifarNet model and its logits on an all-ones batch; the
        # synthesis tests compare rewritten models' logits against self.out.
        self.graph, self.constants, _ = cnn.CifarNet()
        self.m = Model(self.graph, self.constants)
        self.input = {"input": jnp.ones((5, 32, 32, 3))}
        self.state = self.m.init(random.PRNGKey(0), self.input)
        self.out = self.m.apply(self.state, self.input)["fc/logits"]
        self.max_size = int(10e8)

        # Set to True to run test_synthesizer_hard (skipped otherwise).
        self.hard = False

    def _subgraph_model(self, start, end):
        """Returns a SubgraphModel over ``self.graph.ops[start:end]``.

        The output names of the last subgraph node are set to the input names
        of ``self.graph.ops[end]``.
        """
        subg = [subgraph.SubgraphNode(op=o) for o in self.graph.ops[start:end]]
        subg[-1].output_names = self.graph.ops[end].input_names
        return SubgraphModel(self.graph, self.constants, self.state,
                             self.input, subg)

    def _synthesize(self, subg, props):
        """Synthesizes a replacement for ``subg`` and checks the logits change.

        Uses the WEIGHTED mode of ProgressiveSequentialSynthesizer with
        max_len=3, and logs the synthesized ops for debugging.

        Args:
          subg: the SubgraphModel to replace.
          props: inferred properties passed to the synthesizer with ``subg``.
        """
        synthesizer = ProgressiveSequentialSynthesizer(
            [(subg, props)],
            generation=0,
            mode=ProgressiveSequentialSynthesizer.Mode.WEIGHTED,
            max_len=3)

        subg = synthesizer.synthesize()[0]
        for node in subg.subgraph:
            logging.info("Synthesized op %s with outputs %s", node.op.name,
                         node.output_names)

        m = Model(subg.graph, self.constants)
        state = m.init(random.PRNGKey(0), self.input)
        out = m.apply(state, self.input)["fc/logits"]
        self.assertTrue((out != self.out).any())

    def test_synthesizer_easy_one(self):
        """Replacing [conv3x3(features = 64)], checking shape and depth."""
        subgraph_model = self._subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        # The linear property is deliberately not checked in this test.
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_easy_two(self):
        """Replacing [conv3x3(features = 64)], checking all three props."""
        subgraph_model = self._subgraph_model(4, 5)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])

    def test_synthesizer_one(self):
        """Replacing [conv3x3(features = 64), ReLU], checking shape and depth."""
        subgraph_model = self._subgraph_model(4, 6)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        # The linear property is deliberately not checked in this test.
        self._synthesize(subgraph_model, [sp, dp])

    def test_synthesizer_two(self):
        """Replacing [conv3x3(features = 64), ReLU, avgpool2x2(strides=2x2)].

        Only the shape and linear properties are checked (not depth).
        """
        subgraph_model = self._subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, lp])

    def test_synthesizer_hard(self):
        """Replacing [conv3x3, ReLU, avgpool2x2] while checking all three props.

        Disabled unless ``self.hard`` is True.
        """
        if not self.hard:
            # Report as skipped rather than silently passing.
            self.skipTest("Hard synthesis test disabled (self.hard is False).")
        subgraph_model = self._subgraph_model(4, 7)
        sp = shape.ShapeProperty().infer(subgraph_model,
                                         max_size=self.max_size)
        dp = depth.DepthProperty().infer(subgraph_model)
        lp = linear.LinopProperty().infer(subgraph_model)
        self._synthesize(subgraph_model, [sp, dp, lp])