Example #1
def param_reduce(model, log=False):
  """Return a dict containing param counts per submodule."""
  params = get_params(model)
  sizes = collections.defaultdict(int)
  for path, x in tree.flatten_with_path(params):
    size = x.size
    for i in range(len(path)):
      k = path[:i]
      sizes[k] += size
  if log:
    for k in sorted(sizes):
      logging.info('%s: %s', k, sizes[k])
  return sizes
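For context, a minimal sketch (hypothetical `params`, with numpy arrays standing in for real parameters) of the `(path, leaf)` pairs the loop above consumes. Since `path[:i]` enumerates every proper prefix of a leaf's path, each leaf's size is credited to the root `()` and to every enclosing submodule:

import numpy as np
import tree  # dm-tree

params = {'linear': {'b': np.zeros(3), 'w': np.zeros((2, 3))}}
for path, leaf in tree.flatten_with_path(params):
    print(path, leaf.size)
# ('linear', 'b') 3
# ('linear', 'w') 6
# param_reduce would record sizes[()] == 9 and sizes[('linear',)] == 9.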
Example #2
def fast_map_structure_with_path(func, *structure):
    """Faster map_structure_with_path implementation."""
    head_entries_with_path = tree.flatten_with_path(structure[0])
    if len(structure) > 1:
        tail_entries = (tree.flatten(s) for s in structure[1:])
        entries_with_path = [
            e[0] + e[1:] for e in zip(head_entries_with_path, *tail_entries)
        ]
    else:
        entries_with_path = head_entries_with_path
    # Arbitrarily choose one of the structures of the original sequence (the last)
    # to match the structure for the flattened sequence.
    return tree.unflatten_as(structure[-1],
                             [func(*x) for x in entries_with_path])
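A small usage sketch with hypothetical inputs: `func` receives the path followed by the corresponding leaf from each structure.

import tree  # dm-tree

a = {'x': 1, 'y': {'z': 2}}
b = {'x': 10, 'y': {'z': 20}}
out = fast_map_structure_with_path(lambda path, u, v: u + v, a, b)
# out == {'x': 11, 'y': {'z': 22}}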
Example #3
def loss_fn(forward, params, state, batch, l2=True):
    """Computes a regularized loss for the given batch."""
    logits, state = forward.apply(
        params, state, None, batch, is_training=True)
    labels = jax.nn.one_hot(batch[1], CLASS_NUM)
    logits = logits.reshape(len(labels), CLASS_NUM)  # match labels shape
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    acc = (labels.argmax(1) == logits.argmax(1)).mean()

    if l2:
        l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
                     if 'batchnorm' not in mod_name]
        loss = loss + 5e-4 * l2_loss(l2_params)
    return loss, (loss, state, acc)
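The `l2_loss` helper is not shown in these snippets; a typical definition consistent with how it is called here (an assumption, not taken from the source) is:

import jax.numpy as jnp

def l2_loss(params):
    # Assumed helper: half the sum of squared parameter values.
    return 0.5 * sum(jnp.sum(jnp.square(p)) for p in params)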
Example #4
def loss_fn(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, hk.State]:
  """Computes a regularized loss for the given batch."""
  logits, state = forward.apply(params, state, None, batch, is_training=True)
  labels = jax.nn.one_hot(batch['labels'], 1000)
  cat_loss = jnp.mean(softmax_cross_entropy(logits=logits, labels=labels))
  l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
               if 'batchnorm' not in mod_name]
  reg_loss = FLAGS.train_weight_decay * l2_loss(l2_params)
  loss = cat_loss + reg_loss
  return loss, state
Example #5
    def update_table(self, debug_namedtuple):
        nests = {
            'obs': debug_namedtuple.observations,
            'sample': debug_namedtuple.samples,
            'gamestate': debug_namedtuple.gamestate,
            'dist': debug_namedtuple.distances
        }
        path_delim = '.'

        df = None
        for nest_name, nest in nests.items():

            # E.g., skip dist for controller_heads that do not yet provide controller component-wise distance
            if nest_name == 'dist' and not tf.nest.is_nested(nest):
                continue

            for path, array in tree.flatten_with_path(nest):
                leaf_name = path_delim.join([str(subpath) for subpath in path])

                wide_single_df = pd.DataFrame(
                    tf.squeeze(array).numpy().astype(float))
                wide_single_df['Time'] = wide_single_df.index
                col_nm = f'{nest_name} {leaf_name}'
                long_df = pd.melt(wide_single_df,
                                  id_vars=['Time'],
                                  var_name='Batch',
                                  value_name=col_nm)

                if df is not None:
                    df[col_nm] = long_df[col_nm]
                else:
                    df = long_df

        df['Time'] += self.time
        self.time += len(wide_single_df)
        df['Loss'] = df[[col for col in df.columns
                         if 'dist ' in col]].agg(np.nansum, axis='columns')

        self.mdf = self.mdf.append(df)

        done = False
        if len(self.mdf) > self.table_length:
            out_fn = self.saved_model_path + f'{self.table_name}.csv'
            self.mdf.to_csv(out_fn)
            print(f'Wrote {len(self.mdf)} frames to {out_fn}')
            done = True
        return done
Example #6
def flatten_with_name(structure: Any) -> List[Tuple[str, Any]]:
  """Creates a flattened representation of the `structure` with names.

  Args:
    structure: A potentially nested structure.

  Returns:
    A `list` of `(name, value)` `tuples` representing the flattened `structure`,
    where `name` uniquely identifies the position of the `value` in the
    `structure`.
  """
  flattened = tree.flatten_with_path(structure)

  def _name(path: Iterable[Any]) -> str:
    return '/'.join(map(str, path))

  return [(_name(path), value) for path, value in flattened]
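A short usage sketch with a hypothetical structure (dm-tree sorts dict keys and indexes list elements):

flatten_with_name({'a': {'b': 1}, 'c': [2, 3]})
# [('a/b', 1), ('c/0', 2), ('c/1', 3)]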
Example #7
def _flatten_nested_dict(struct: Dict[str, Any]) -> Dict[str, Any]:
  """Flattens a given nested structure of tensors, sorting by flattened keys.

  For example, the nested dictionary {'d': 3, 'a': {'b': 1, 'c': 2}} produces
  the (ordered) dictionary {'a/b': 1, 'a/c': 2, 'd': 3}. Lists are unpacked as
  well, so {'a': [3, 4, 5]} is flattened to the ordered dictionary
  {'a/0': 3, 'a/1': 4, 'a/2': 5}. The values of the flattened dictionary are
  the leaf-node tensors of the original struct.

  Args:
    struct: A nested dictionary.

  Returns:
    A `collections.OrderedDict` representing a flattened version of `struct`.
  """
  flat_struct = tree.flatten_with_path(struct)
  flat_struct = [('/'.join(map(str, path)), item) for path, item in flat_struct]
  return collections.OrderedDict(sorted(flat_struct))
Example #8
def loss_fn(
    params: hk.Params,
    state: hk.State,
    loss_scale: jmp.LossScale,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
    """Computes a regularized loss for the given batch."""
    logits, state = forward.apply(params, state, None, batch, is_training=True)
    labels = jax.nn.one_hot(batch['labels'], 1000)
    if FLAGS.train_smoothing:
        labels = optax.smooth_labels(labels, FLAGS.train_smoothing)
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    l2_params = [
        p for ((mod_name, _), p) in tree.flatten_with_path(params)
        if 'batchnorm' not in mod_name
    ]
    loss = loss + FLAGS.train_weight_decay * l2_loss(l2_params)
    return loss_scale.scale(loss), (loss, state)
Example #9
    def eval_hook(generator, discriminator, server_state, round_num):
        """Called during TFF GAN IterativeProcess, to compute eval metrics."""
        start_time = time.time()
        metrics = {}

        gen_inputs = next(gen_inputs_iter)
        real_images = next(real_images_iter)

        if round_num % rounds_per_save_images == 0:
            _save_images(
                generator,
                gen_inputs,
                outdir=path_to_output_images,
                file_prefix='emnist_tff_gen_images_step_{:05d}'.format(
                    round_num))

        # Compute eval metrics.
        eval_metrics = _compute_eval_metrics(generator, discriminator,
                                             gen_inputs, real_images,
                                             gan_loss_fns,
                                             emnist_classifier_for_metrics)
        metrics['eval'] = eval_metrics

        # Get counters from the server_state.
        metrics['counters'] = server_state.counters

        # Write metrics to a tf.summary logdir.
        flat_metrics = tree.flatten_with_path(metrics)
        flat_metrics = [('/'.join(map(str, path)), item)
                        for path, item in flat_metrics]
        flat_metrics = collections.OrderedDict(flat_metrics)
        with summary_writer.as_default():
            for name, value in flat_metrics.items():
                tf.summary.scalar(name, value, step=round_num)

        # Print out the counters, and log how long it took to compute/write metrics.
        for k, v in server_state.counters.items():
            print('{:>40s} {:8.0f}'.format(k, v), flush=True)
        logging.info('Doing evaluation took %.2f seconds.',
                     time.time() - start_time)
Example #10
def flatten_with_joined_string_paths(structure, separator='/'):
  """Returns a list of (string path, data element) tuples.

  The order of tuples produced matches that of `nest.flatten`. This allows you
  to flatten a nested structure while keeping information about where in the
  structure each data element was located. See `nest.yield_flat_paths`
  for more information.

  Args:
    structure: the nested structure to flatten.
    separator: string to separate levels of hierarchy in the results, defaults
      to '/'.

  Returns:
    A list of (string, data element) tuples.
  """

  def stringify_and_join(path_elements):
    return separator.join(str(path_element) for path_element in path_elements)

  return [(stringify_and_join(pe), v)
          for pe, v in dm_tree.flatten_with_path(structure)]
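A quick usage sketch with a hypothetical structure:

flatten_with_joined_string_paths({'a': {'b': 1}, 'c': (2,)})
# [('a/b', 1), ('c/0', 2)]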
Example #11
    def __call__(self, train_metrics, eval_metrics, round_num):
        """A function suitable for passing as an eval hook to the training_loop.

    Args:
      train_metrics: A `dict` of training metrics computed in TFF.
      eval_metrics: A `dict` of evaluation metrics computed in TFF.
      round_num: The current round number.
    """
        metrics = {
            'train': train_metrics,
            'eval': eval_metrics,
            'round': round_num,
        }
        flat_metrics = tree.flatten_with_path(metrics)
        flat_metrics = [('/'.join(map(str, path)), item)
                        for path, item in flat_metrics]
        flat_metrics = collections.OrderedDict(flat_metrics)

        logging.info('Evaluation at round {:d}:\n{!s}'.format(
            round_num, pprint.pformat(flat_metrics)))

        # Also write metrics to a tf.summary logdir
        with self._summary_writer.as_default():
            for name, value in flat_metrics.items():
                tf.compat.v2.summary.scalar(name, value, step=round_num)

        if tf.io.gfile.exists(self._results_file):
            metrics = pd.read_csv(self._results_file,
                                  header=0,
                                  index_col=0,
                                  engine='c')
            # Remove everything after `round_num`; if the experiment was
            # restarted at an earlier checkpoint, this avoids duplicate metrics.
            metrics = metrics[:round_num]
            metrics = metrics.append(flat_metrics, ignore_index=True)
        else:
            metrics = pd.DataFrame(flat_metrics, index=[0])
        utils_impl.atomic_write_to_csv(metrics, self._results_file)
Example #12
    def append(self, data: Any, *, partial_step: bool = False):
        """Columnwise append of data leaf nodes to internal buffers.

    If `data` includes fields or sub structures which haven't been present in
    any previous calls then the types and shapes of the new fields are extracted
    and used to validate future `append` calls. The structure of `history` is
    also updated to include the union of the structure across all `append`
    calls.

    When new fields are added after the first step, the newly created history
    field is filled with `None` in all preceding positions. This results in
    equal indexing across columns. That is, `a[i]` and `b[i]` reference the
    same step in the sequence even if `b` was first observed after `a` had
    already been seen.

    It is possible to create a "step" using more than one `append` call by
    setting the `partial_step` flag. Partial steps can be used when some parts
    of the step become available only as a result of inserting (and learning
    from) trajectories that include the fields available first (e.g. learn from
    the SARS trajectory to select the next action in an on-policy agent). In
    the final `append` call of the step, `partial_step` must be set to False.
    Failing to "close" the partial step will result in an error, as the same
    field must NOT be provided more than once in the same step.

    Args:
      data: The (possibly nested) structure to make available for new items to
        reference.
      partial_step: If `True` then the step is not considered "done" with this
        call. See above for more details. Defaults to `False`.

    Returns:
      References to the data, structured just like the provided `data`.

    Raises:
      ValueError: If the same column is provided more than once in the same
        step.
    """
        # If this is the first step, initialize the structure with `None` leaves.
        if self._structure is None:
            self._update_structure(tree.map_structure(lambda _: None, data))

        data_with_path_flat = tree.flatten_with_path(data)
        try:
            # Use our custom mapping to flatten the expanded structure into columns.
            flat_column_data = self._reorder_like_flat_structure(
                data_with_path_flat)
        except KeyError:
            # `data` contains fields which haven't been observed before so we
            # need to expand the spec using the union of the history and `data`.
            self._update_structure(
                _tree_union(self._structure,
                            tree.map_structure(lambda x: None, data)))

            flat_column_data = self._reorder_like_flat_structure(
                data_with_path_flat)

        # If the last step is still open then verify that already populated columns
        # are None in the new `data`.
        if self._last_step_is_open:
            for i, (column, column_data) in enumerate(
                    zip(self._column_history, flat_column_data)):
                if column_data is None or column.can_set_last:
                    continue

                raise ValueError(
                    f'Field {self._get_path_for_column_index(i)} has already been set '
                    f'in the active step by previous (partial) append call and thus '
                    f'must be omitted or set to None but got: {column_data}')

        # Flatten the data and pass it to the C++ writer for column-wise
        # append. All columns where data is provided (i.e. not None) return a
        # reference to the data (`pybind.WeakCellRef`) which is used to define
        # trajectories for `create_item`. Columns which did not receive a value
        # (i.e. None) return None.
        if partial_step:
            flat_column_data_references = self._writer.AppendPartial(
                flat_column_data)
        else:
            flat_column_data_references = self._writer.Append(flat_column_data)

        # Append references to respective columns. Note that we use the expanded
        # structure in order to populate the columns missing from the data with
        # None.
        for column, data_reference in zip(self._column_history,
                                          flat_column_data_references):
            # If the last step is still open (i.e `partial_step` was set) then we
            # populate that step instead of creating a new one.
            if not self._last_step_is_open:
                column.append(data_reference)
            elif data_reference is not None:
                column.set_last(data_reference)

        # Save the flag so the next `append` call either populates the same step
        # or begins a new step.
        self._last_step_is_open = partial_step

        # Unpack the column data into the expanded structure.
        expanded_structured_data_references = self._unflatten(
            flat_column_data_references)

        # Return the references structured in the same way as `data`. If only a
        # subset of the fields was present in the input data then only those
        # fields will exist in the output.
        filtered_data_references_flat = _tree_filter(
            expanded_structured_data_references, data_with_path_flat)
        return tree.unflatten_as(data, filtered_data_references_flat)
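A hedged usage sketch (assuming `writer` is an instance of this class and `observation`/`action` are hypothetical step data): one logical step built from two `append` calls.

# First call provides only the observation and leaves the step open.
obs_ref = writer.append({'obs': observation}, partial_step=True)
# Second call adds the action and closes the step; 'obs' must now be
# omitted (or None) since it was already set within this step.
action_ref = writer.append({'action': action})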
Example #13
def flatten_with_tuple_paths(structure, expand_composites=False):
  if expand_composites:
    raise NotImplementedError(
        '`expand_composites=True` is not supported in JAX.')
  return dm_tree.flatten_with_path(structure)
Example #14
    def testFlattenWithPath(self, test_type):
        self._init_testdata(test_type)

        self.assertEqual(tree.flatten_with_path(self.dcls_with_map),
                         self.dcls_flattened_with_path)
Example #15
    def _flatten(self, data):
        flat_data = [None] * len(self._path_to_column_index)
        for path, value in tree.flatten_with_path(data):
            flat_data[self._path_to_column_index[path]] = value
        return flat_data
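The construction of `self._path_to_column_index` is not shown in this snippet; a plausible initialization (an assumption, not taken from the source) maps each flattened path of the structure to its column index:

# Hypothetical initialization, run wherever `self._structure` is (re)built.
self._path_to_column_index = {
    path: index
    for index, (path, _) in enumerate(
        tree.flatten_with_path(self._structure))
}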
Example #16
def flatten_with_tuple_paths(structure):
  return flatten_with_path(structure)
Example #17
    def testFlattenWithPathUpTo(self):
        def get_paths_and_values(shallow_tree, input_tree):
            path_value_pairs = tree.flatten_with_path_up_to(
                shallow_tree, input_tree)
            paths = [p for p, _ in path_value_pairs]
            values = [v for _, v in path_value_pairs]
            return paths, values

        # Shallow tree ends at scalar.
        input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow_tree = [[True, True], [False, True]]
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [(0, 0), (0, 1), (1, 0),
                                                      (1, 1)])
        self.assertEqual(flattened_input_tree,
                         [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(flattened_shallow_tree_paths, [(0, 0), (0, 1), (1, 0),
                                                        (1, 1)])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        # Shallow tree ends at string.
        input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        input_tree_flattened_paths = [
            p for p, _ in tree.flatten_with_path(input_tree)
        ]
        input_tree_flattened = tree.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                         [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
        self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1),
                                                                ("b", 2),
                                                                ("c", 3),
                                                                ("d", 4)])

        self.assertEqual(input_tree_flattened_paths, [(0, 0, 0), (0, 0, 1),
                                                      (0, 1, 0, 0),
                                                      (0, 1, 0, 1),
                                                      (0, 1, 1, 0, 0),
                                                      (0, 1, 1, 0, 1),
                                                      (0, 1, 1, 1, 0, 0),
                                                      (0, 1, 1, 1, 0, 1)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        # Make sure dicts are correctly flattened, yielding values, not keys.
        input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                         [("a", ), ("b", ), ("d", 0), ("d", 1)])
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [1, {
                             "c": 2
                         }, 3, (4, 5)])

        # Namedtuples.
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        input_tree = ab_tuple(a=[0, 1], b=2)
        shallow_tree = ab_tuple(a=0, b=1)
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a", ),
                                                                      ("b", )])
        self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

        # Nested dicts, OrderedDicts and namedtuples.
        input_tree = collections.OrderedDict([
            ("a", ab_tuple(a=[0, {
                "b": 1
            }], b=2)), ("c", {
                "d": 3,
                "e": collections.OrderedDict([("f", 4)])
            })
        ])
        shallow_tree = input_tree
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                         [("a", "a", 0), ("a", "a", 1, "b"), ("a", "b"),
                          ("c", "d"), ("c", "e", "f")])
        self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
        shallow_tree = collections.OrderedDict([("a", 0),
                                                ("c", {
                                                    "d": 3,
                                                    "e": 1
                                                })])
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                         [("a", ), ("c", "d"), ("c", "e")])
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {
                "b": 1
            }], b=2), 3,
            collections.OrderedDict([("f", 4)])
        ])
        shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
        (input_tree_flattened_as_shallow_tree_paths,
         input_tree_flattened_as_shallow_tree) = get_paths_and_values(
             shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a", ),
                                                                      ("c", )])
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {
                "b": 1
            }], b=2), {
                "d": 3,
                "e": collections.OrderedDict([("f", 4)])
            }
        ])

        ## Shallow non-list edge-case.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ["input_tree_0", "input_tree_1"]
        shallow_tree = "shallow_tree"
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Test case where len(shallow_tree) < len(input_tree)
        input_tree = {"a": "A", "b": "B", "c": "C"}
        shallow_tree = {"a": 1, "c": 2}

        with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
                ValueError,
                tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                    input_length=len(input_tree),
                    shallow_length=len(shallow_tree))):
            get_paths_and_values(shallow_tree, input_tree)

        # Using non-iterable elements.
        input_tree = [0]
        shallow_tree = 9
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = [0, 1]
        shallow_tree = 9
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree_paths, [()])
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree_paths, [()])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = ["shallow_tree"]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            (flattened_input_tree_paths,
             flattened_input_tree) = get_paths_and_values(
                 shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree_paths, [(0, )])
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = "input_tree"
        shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            (flattened_input_tree_paths,
             flattened_input_tree) = get_paths_and_values(
                 shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree_paths, [(0, ), (1, )])
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = [9]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            (flattened_input_tree_paths,
             flattened_input_tree) = get_paths_and_values(
                 shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree_paths, [(0, )])
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = 0
        shallow_tree = [9, 8]
        with self.assertRaisesWithLiteralMatch(
                TypeError,
                tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                    type(input_tree))):
            (flattened_input_tree_paths,
             flattened_input_tree) = get_paths_and_values(
                 shallow_tree, input_tree)
        (flattened_shallow_tree_paths,
         flattened_shallow_tree) = get_paths_and_values(
             shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree_paths, [(0, ), (1, )])
        self.assertEqual(flattened_shallow_tree, shallow_tree)
Example #18
def _graph_to_dot(graph: Graph, args, outputs) -> str:
    """Converts from an internal graph IR to 'dot' format."""
    if tree is None:
        raise ImportError('hk.experimental.to_dot requires dm-tree>=0.1.1.')

    def format_path(path):
        if isinstance(outputs, tuple):
            out = f'output[{path[0]}]'
            if len(path) > 1:
                out += ': ' + '/'.join(map(str, path[1:]))
        else:
            out = 'output'
            if path:
                out += ': ' + '/'.join(map(str, path))
        return out

    lines = []
    used_argids = set()
    argid_usecount = collections.Counter()
    op_outids = set()
    captures = []
    argids = {id(v) for v in jax.tree_leaves(args)}
    outids = {id(v) for v in jax.tree_leaves(outputs)}
    outname = {
        id(v): format_path(p)
        for p, v in tree.flatten_with_path(outputs)
    }

    def render_graph(g: Graph, parent: Optional[Graph] = None, depth: int = 0):
        """Renders a given graph by appending 'dot' format lines."""

        if parent:
            lines.extend([
                f'subgraph cluster_{id(g)} {{',
                '  style="rounded,filled";',
                '  fillcolor="#F0F5F5";',
                '  color="#14234B;";',
                '  pad=0.1;',
                f'  fontsize={_scaled_font_size(depth)};',
                f'  label = <<b>{escape(g.title)}</b>>;',
                '  labelloc = t;',
            ])

        for node in g.nodes:
            label = f'<b>{escape(node.title)}</b>'
            for o in node.outputs:
                label += '<br/>' + _format_val(o)
                op_outids.add(id(o))

            node_id = id(node.id)
            if node_id in outids:
                label = f'<b>{escape(outname[node_id])}</b><br/>' + label
                color = '#0053D6'
                fillcolor = '#AABFFF'
                style = 'filled,bold'
            else:
                color = '#FFDB13'
                fillcolor = '#FFF26E'
                style = 'filled'

            lines.append(f'{node_id} [label=<{label}>, '
                         f' id="node{node_id}",'
                         ' shape=rect,'
                         f' style="{style}",'
                         ' tooltip=" ",'
                         ' fontcolor="black",'
                         f' color="{color}",'
                         f' fillcolor="{fillcolor}"];')

        for s in g.subgraphs:
            render_graph(s, parent=g, depth=depth - 1)

        if parent:
            lines.append(f'}}  // subgraph cluster_{id(g)}')

        for a, b in g.edges:
            if id(a) not in argids and id(a) not in op_outids:
                captures.append(a)

            a, b = map(id, (a, b))
            if a in argids:
                i = argid_usecount[a]
                argid_usecount[a] += 1
                lines.append(f'{a}{i} -> {b};')
            else:
                lines.append(f'{a} -> {b};')
            used_argids.add(a)

    graph_depth = _max_depth(graph)
    render_graph(graph, parent=None, depth=graph_depth)

    # Process inputs and label them in the graph.
    for path, value in tree.flatten_with_path(args):
        if value is None:
            continue

        node_id = id(value)
        if node_id not in used_argids:
            continue

        for i in range(argid_usecount[node_id]):
            label = f'<b>args[{escape(path[0])}]'
            if len(path) > 1:
                label += ': ' + '/'.join(map(str, path[1:]))
            label += '</b>'
            if hasattr(value, 'shape') and hasattr(value, 'dtype'):
                label += f'<br/>{escape(_format_val(value))}'
            fillcolor = '#FFDEAF'
            fontcolor = 'black'

            if i > 0:
                label = '<b>(reuse)</b><br/>' + label
                fillcolor = '#FFEACC'
                fontcolor = '#565858'

            lines.append(f'{node_id}{i} [label=<{label}>'
                         f' id="node{node_id}{i}",'
                         ' shape=rect,'
                         ' style="filled",'
                         f' fontcolor="{fontcolor}",'
                         ' color="#FF8A4F",'
                         f' fillcolor="{fillcolor}"];')

    for value in captures:
        node_id = id(value)
        if (not hasattr(value, 'aval') and hasattr(value, 'size')
                and value.size == 1):
            label = f'<b>{value.item()}</b>'
        else:
            label = f'<b>{escape(_format_val(value))}</b>'

        lines.append(f'{node_id} [label=<{label}>'
                     ' shape=rect,'
                     ' style="filled",'
                     ' fontcolor="black",'
                     ' color="#A261FF",'
                     ' fillcolor="#E6D6FF"];')

    head = [
        'digraph G {',
        'rankdir = TD;',
        'compound = true;',
        f'label = <<b>{escape(graph.title)}</b>>;',
        f'fontsize={_scaled_font_size(graph_depth)};',
        'labelloc = t;',
        'stylesheet = <',
        '  data:text/css,',
        '  @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);',
        '  svg text {',
        '    font-family: \'Roboto\';',
        '  }',
        '  .node text {',
        '    font-size: 12px;',
        '  }',
    ]
    for node_id, use_count in argid_usecount.items():
        if use_count == 1:
            continue
        # Add hover animation for reused args.
        for a in range(use_count):
            for b in range(use_count):
                if a == b:
                    head.append(f'%23node{node_id}{a}:hover '
                                '{ stroke-width: 0.2em; }')
                else:
                    head.append(
                        f'%23node{node_id}{a}:hover ~ %23node{node_id}{b} '
                        '{ stroke-width: 0.2em; }')
    head.append('>')

    lines.append('} // digraph G')
    return '\n'.join(head + lines) + '\n'
Example #19
def secure_sum_then_finalize(
    metric_finalizers: model_lib.MetricFinalizersType,
    local_unfinalized_metrics_type: computation_types.StructWithPythonType,
    metric_value_ranges: Optional[MetricValueRangeDict] = None
) -> computation_base.Computation:
    """Creates a TFF computation that aggregates metrics using secure summation.

  The returned federated TFF computation has the following type signature:

  ```
  (local_unfinalized_metrics@CLIENTS ->
   <aggregated_metrics@SERVER, secure_sum_measurements@SERVER>)
  ```

  where the input is given by
  `tff.learning.Model.report_local_unfinalized_metrics()` at `CLIENTS`, and the
  first output (`aggregated_metrics`) is computed by first securely summing the
  unfinalized metrics from `CLIENTS`, followed by applying the finalizers at
  `SERVER`. The second output (`secure_sum_measurements`) is an `OrderedDict`
  that maps from `factory_key`s to the secure summation measurements (e.g. the
  number of clients whose values were clipped; see
  `tff.aggregators.SecureSumFactory` for details). A `factory_key` is uniquely
  defined by three scalars: lower bound, upper bound, and tensor dtype (denoted
  as a datatype enum). Metric values with the same `factory_key` are grouped
  and aggregated together (and hence the `secure_sum_measurements` are also
  computed at a group level).

  Since secure summation works in fixed-point arithmetic space, floating point
  numbers must be encoded using integer quantization. By default, each tensor
  in `local_unfinalized_metrics_type` will be clipped to `[0, 2**20 - 1]` and
  encoded to integers inside `tff.aggregators.SecureSumFactory`. Callers can
  change this range by setting `metric_value_ranges`, which may be a partial
  tree matching the structure of `local_unfinalized_metrics_type`.

  Example partial value range specification:

  >>> finalizers = ...
  >>> metrics_type = tff.to_type(collections.OrderedDict(
      a=tff.types.TensorType(tf.int32),
      b=tff.types.TensorType(tf.float32),
      c=[tff.types.TensorType(tf.float32), tff.types.TensorType(tf.float32)]))
  >>> value_ranges = collections.OrderedDict(
      b=(0.0, 1.0),
      c=[None, (0.0, 1.0)])
  >>> aggregator = tff.learning.metrics.secure_sum_then_finalize(
      finalizers, metrics_type, value_ranges)

  This sets the range of `b` and of the *second* tensor of `c` in the
  dictionary; the first tensor of `c` and the `a` tensor use the default range.

  Args:
    metric_finalizers: An `OrderedDict` of `string` metric names to finalizer
      functions returned by `tff.learning.Model.metric_finalizers()`. It should
      have the same keys (i.e., metric names) as the `OrderedDict` returned by
      `tff.learning.Model.report_local_unfinalized_metrics()`. A finalizer is a
      callable (typically a `tf.function` or `tff.tf_computation` decorated
      function) that takes in a metric's unfinalized values, and returns the
      finalized values.
    local_unfinalized_metrics_type: A `tff.types.StructWithPythonType` (with
      `OrderedDict` as the Python container) of a client's local unfinalized
      metrics. Let `local_unfinalized_metrics` be the output of
      `tff.learning.Model.report_local_unfinalized_metrics()`. Its type can be
      obtained by `tff.framework.type_from_tensors(local_unfinalized_metrics)`.
    metric_value_ranges: A `collections.OrderedDict` that matches the structure
      of `local_unfinalized_metrics_type` (a value for each
      `tff.types.TensorType` in the type tree). Each leaf in the tree should
      have a 2-tuple that defines the range of expected values for that variable
      in the metric. If the entire structure is `None`, a default range of
      `[0.0, 2.0**20 - 1]` will be applied to all variables. Each leaf may also
      be `None`, which will also get the default range, allowing partial user
      specification. At runtime, values that fall outside the ranges specified
      at the leaves will be clipped to within the range.

  Returns:
    A federated TFF computation that securely sums the unfinalized metrics from
    `CLIENTS`, and applies the corresponding finalizers at `SERVER`.

  Raises:
    TypeError: If the inputs are of the wrong types.
    ValueError: If the keys (i.e., metric names) in `metric_finalizers` are not
      the same as those expected by `local_unfinalized_metrics_type`.
  """
    check_metric_finalizers(metric_finalizers)
    check_local_unfinalzied_metrics_type(local_unfinalized_metrics_type)
    check_finalizers_matches_unfinalized_metrics(
        metric_finalizers, local_unfinalized_metrics_type)

    default_metric_value_ranges = create_default_secure_sum_quantization_ranges(
        local_unfinalized_metrics_type)
    if metric_value_ranges is None:
        metric_value_ranges = default_metric_value_ranges

    # Walk the incoming `metric_value_ranges` and `default_metric_value_ranges`
    # and fill in any missing ranges using the defaults.
    def fill_missing_values_with_defaults(default_values, user_values):
        if isinstance(default_values, collections.abc.Mapping):
            if user_values is None:
                user_values = {}
            return type(default_values)(
                (key,
                 fill_missing_values_with_defaults(default_value,
                                                   user_values.get(key)))
                for key, default_value in default_values.items())
        elif isinstance(default_values, list):
            if user_values is None:
                user_values = [None] * len(default_values)
            return [
                fill_missing_values_with_defaults(default_value,
                                                  user_values[idx])
                for idx, default_value in enumerate(default_values)
            ]
        elif user_values is None:
            return _MetricRange(*default_values)
        else:
            _check_range(user_values)
            return _MetricRange(*user_values)

    try:
        metric_value_ranges = fill_missing_values_with_defaults(
            default_metric_value_ranges, metric_value_ranges)
    except TypeError as e:
        raise TypeError('Failed to create encoding value range from: '
                        f'{metric_value_ranges}') from e

    # Create an aggregator factory for each unique value range, rather than each
    # leaf tensor (which could introduce a lot of duplication).
    aggregator_factories = {
        value_range: secure.SecureSumFactory(value_range.upper,
                                             value_range.lower)
        for value_range in set(tree.flatten(metric_value_ranges))
    }
    # Construct a python container of `tff.TensorType` so we can traverse it in
    # parallel with the value ranges during AggregationProcess construction.
    # Otherwise we have a `tff.Type` but `metric_value_ranges` is a Python
    # container which are difficult to traverse in parallel.
    structure_of_tensor_types = type_conversions.structure_from_tensor_type_tree(
        lambda t: t, local_unfinalized_metrics_type)

    # We will construct groups of tensors with the same dtype and quantization
    # value range so that we can construct fewer aggregations-of-structures,
    # rather than a large structure-of-aggregations. Without this, the TFF
    # compiler pipeline results in large slow downs (see b/218312198).
    factory_key_by_path = collections.OrderedDict()
    value_range_by_factory_key = collections.OrderedDict()
    path_list_by_factory_key = collections.defaultdict(list)
    # Maintain a flattened list of paths. This is later used to order the
    # aggregated values before they are packed by `tf.nest.pack_sequence_as`.
    flattened_path_list = []
    for (path, tensor_spec), (_, value_range) in zip(
            tree.flatten_with_path(structure_of_tensor_types),
            tree.flatten_with_path(metric_value_ranges)):
        factory_key = _create_factory_key(value_range.lower, value_range.upper,
                                          tensor_spec.dtype)
        factory_key_by_path[path] = factory_key
        value_range_by_factory_key[factory_key] = value_range
        path_list_by_factory_key[factory_key].append(path)
        flattened_path_list.append(path)

    @tensorflow_computation.tf_computation(local_unfinalized_metrics_type)
    def group_value_by_factory_key(local_unfinalized_metrics):
        """Groups client local metrics into a map of `factory_key` to value list."""
        # We cannot use `collections.defaultdict(list)` here because its result is
        # incompatible with `structure_from_tensor_type_tree`.
        value_list_by_factory_key = collections.OrderedDict()
        for path, value in tree.flatten_with_path(local_unfinalized_metrics):
            factory_key = factory_key_by_path[path]
            if factory_key in value_list_by_factory_key:
                value_list_by_factory_key[factory_key].append(value)
            else:
                value_list_by_factory_key[factory_key] = [value]
        return value_list_by_factory_key

    def flatten_grouped_values(value_list_by_factory_key):
        """Flatten the values in the same order as in `flattened_path_list`."""
        value_by_path = collections.OrderedDict()
        for factory_key in value_list_by_factory_key:
            path_list = path_list_by_factory_key[factory_key]
            value_list = value_list_by_factory_key[factory_key]
            for path, value in zip(path_list, value_list):
                value_by_path[path] = value
        flattened_value_list = [
            value_by_path[path] for path in flattened_path_list
        ]
        return flattened_value_list

    # Create an aggregation process for each factory key.
    aggregation_process_by_factory_key = collections.OrderedDict()
    # Construct a python container of `tff.TensorType` so we can traverse it and
    # create aggregation processes from the factories.
    tensor_type_list_by_factory_key = (
        type_conversions.structure_from_tensor_type_tree(
            lambda t: t, group_value_by_factory_key.type_signature.result))
    for factory_key, tensor_type_list in tensor_type_list_by_factory_key.items(
    ):
        value_range = value_range_by_factory_key[factory_key]
        aggregation_process_by_factory_key[
            factory_key] = aggregator_factories.get(value_range).create(
                computation_types.to_type(tensor_type_list))

    @federated_computation.federated_computation(
        computation_types.at_clients(local_unfinalized_metrics_type))
    def aggregator_computation(client_local_unfinalized_metrics):
        unused_state = intrinsics.federated_value((), placements.SERVER)

        client_local_grouped_unfinalized_metrics = intrinsics.federated_map(
            group_value_by_factory_key, client_local_unfinalized_metrics)
        metrics_aggregation_output = collections.OrderedDict()
        for factory_key, process in aggregation_process_by_factory_key.items():
            metrics_aggregation_output[factory_key] = process.next(
                unused_state,
                client_local_grouped_unfinalized_metrics[factory_key])

        metrics_aggregation_output = intrinsics.federated_zip(
            metrics_aggregation_output)

        @tensorflow_computation.tf_computation(
            metrics_aggregation_output.type_signature.member)
        def finalizer_computation(grouped_aggregation_output):

            # One minor downside of grouping the aggregation processes is that the
            # SecAgg measurements (e.g., clipped_count) are computed at a group level
            # (a group means all metric values belonging to the same `factory_key`).
            secure_sum_measurements = collections.OrderedDict(
                (factory_key, output.measurements)
                for factory_key, output in grouped_aggregation_output.items())
            finalized_metrics = collections.OrderedDict(
                secure_sum_measurements=secure_sum_measurements)
            grouped_unfinalized_metrics = collections.OrderedDict(
                (factory_key, output.result)
                for factory_key, output in grouped_aggregation_output.items())
            flattened_unfinalized_metrics_list = flatten_grouped_values(
                grouped_unfinalized_metrics)
            unfinalized_metrics = tf.nest.pack_sequence_as(
                structure_of_tensor_types, flattened_unfinalized_metrics_list)
            for metric_name, metric_finalizer in metric_finalizers.items():
                finalized_metrics[metric_name] = metric_finalizer(
                    unfinalized_metrics[metric_name])
            return finalized_metrics

        return intrinsics.federated_map(finalizer_computation,
                                        metrics_aggregation_output)

    return aggregator_computation
Example #20
def flatten_with_tuple_paths(structure):
  return dm_tree.flatten_with_path(structure)
Example #21
    def testFlattenWithPath(self, inputs, expected):
        self.assertEqual(tree.flatten_with_path(inputs), expected)
Example #22
    def _get_path_for_column_index(self, column_index):
        i = self._column_index_to_flat_structure_index[column_index]
        return tree.flatten_with_path(self._structure)[i][0]
Example #23
def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    _, _, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=1)
    environment_spec = specs.make_environment_spec(environment)

    # Load pretrained target policy network.
    policy_net = utils.load_policy_net(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        environment_spec=environment_spec)

    actor = actors.FeedForwardActor(policy_network=policy_net)

    logger = loggers.TerminalLogger('ground_truth')

    discount = problem_config['discount']

    returns = []
    lengths = []

    t_start = time.time()
    timestep = environment.reset()
    actor.observe_first(timestep)
    cur_return = 0.
    cur_step = 0
    while len(returns) < FLAGS.num_episodes:

        action = actor.select_action(timestep.observation)
        timestep = environment.step(action)
        # Have the agent observe the timestep and let the actor update itself.
        actor.observe(action, next_timestep=timestep)

        cur_return += pow(discount, cur_step) * timestep.reward
        cur_step += 1

        if timestep.last():
            # Append return of the current episode, and reset the environment.
            returns.append(cur_return)
            lengths.append(cur_step)
            timestep = environment.reset()
            actor.observe_first(timestep)
            cur_return = 0.
            cur_step = 0

            if len(returns) % (FLAGS.num_episodes // 10) == 0:
                print(
                    f'Run time {time.time() - t_start:0.0f} secs, '
                    f'evaluated episode {len(returns)} / {FLAGS.num_episodes}')

    # Returned data include problem configs.
    results = {
        '_'.join(keys): value
        for keys, value in tree.flatten_with_path(problem_config)
    }

    # And computed results.
    results.update({
        'metric_value':
        np.mean(returns),
        'metric_std_dev':
        np.std(returns, ddof=0),
        'metric_std_err':
        np.std(returns, ddof=0) / np.sqrt(len(returns)),
        'length_mean':
        np.mean(lengths),
        'length_std':
        np.std(lengths, ddof=0),
        'num_episodes':
        len(returns),
    })
    logger.write(results)
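For illustration, the path-joining comprehension above with a hypothetical config:

import tree  # dm-tree

problem_config = {'task_name': 'cartpole', 'noise': {'level': 0.1}}
results = {'_'.join(keys): value
           for keys, value in tree.flatten_with_path(problem_config)}
# results == {'noise_level': 0.1, 'task_name': 'cartpole'}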
Example #24
def flatten_with_joined_string_paths(structure, separator='/'):
    """Replacement for deprecated tf.nest.flatten_with_joined_string_paths."""
    return [(separator.join(map(str, path)), item)
            for path, item in tree.flatten_with_path(structure)]
Example #25
File: dot.py  Project: ibab/haiku
def _graph_to_dot(graph: Graph, args, outputs):
    """Converts from an internal graph IR to 'dot' format."""

    lines = [
        'digraph G {',
        'rankdir = TD;',
        'compound = true;',
        f'label = <<b>{graph.title}</b>>;',
        'labelloc = t;',
        'stylesheet = <',
        '  data:text/css,',
        '  @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);',
        '  svg text {',
        '    font-family: \'Roboto\';',
        '  }',
        '  .node text {',
        '    font-size: 12px;',
        '  }',
        '>',
    ]

    def format_path(path):
        if isinstance(outputs, tuple):
            out = f'output[{path[0]}]'
            if len(path) > 1:
                out += ': ' + '/'.join(map(str, path[1:]))
        else:
            out = 'output'
            if path:
                out += ': ' + '/'.join(map(str, path))
        return out

    used_argids = set()
    argid_usecount = collections.Counter()
    op_outids = set()
    captures = []
    argids = {id(v) for v in jax.tree_leaves(args)}
    outids = {id(v) for v in jax.tree_leaves(outputs)}
    outname = {
        id(v): format_path(p)
        for p, v in tree.flatten_with_path(outputs)
    }

    def render_graph(g: Graph, parent: Optional[Graph] = None):
        """Renders a given graph by appending 'dot' format lines."""

        if parent:
            lines.extend([
                f'subgraph cluster_{id(g)} {{',
                '  style="rounded,filled";',
                '  fillcolor="#F0F5F5";',
                '  color="#14234B;";',
                '  fontsize=14;',
                f'  label = <<b>{g.title}</b>>;',
                '  labelloc = t;',
            ])

        for node in g.nodes:
            label = f'<b>{node.title}</b>'
            for o in node.outputs:
                label += '<br/>' + _format_val(o)
                op_outids.add(id(o))

            node_id = id(node.id)
            if node_id in outids:
                label = f'<b>{outname[node_id]}</b><br/>' + label
                color = '0053D6'
                fillcolor = 'AABFFF'
                style = 'filled,bold'
            else:
                color = 'FFDB13'
                fillcolor = 'FFF26E'
                style = 'filled'
            lines.append(f'{node_id} [label=<{label}>, '
                         ' shape=rect,'
                         f' style="{style}",'
                         ' tooltip=" ",'
                         ' fontcolor="black",'
                         f' color="#{color}",'
                         f' fillcolor="#{fillcolor}"];')

        for s in g.subgraphs:
            render_graph(s, parent=g)

        if parent:
            lines.append(f'}}  // subgraph cluster_{id(g)}')

        for a, b in g.edges:
            if id(a) not in argids and id(a) not in op_outids:
                captures.append(a)

            a, b = map(id, (a, b))
            if a in argids:
                i = argid_usecount[a]
                argid_usecount[a] += 1
                lines.append(f'{a}{i} -> {b};')
            else:
                lines.append(f'{a} -> {b};')
            used_argids.add(a)

    render_graph(graph, parent=None)

    # Process inputs and label them in the graph.
    for path, value in tree.flatten_with_path(args):
        if value is None:
            continue

        node_id = id(value)
        if node_id not in used_argids:
            continue

        for i in range(argid_usecount[node_id]):
            label = f'<b>args[{path[0]}]'
            if len(path) > 1:
                label += ': ' + '/'.join(map(str, path[1:]))
            label += '</b>'
            if hasattr(value, 'shape') and hasattr(value, 'dtype'):
                label += f'<br/>{_format_val(value)}'
            style = 'filled'
            fillcolor = '#FFDEAF'
            fontcolor = 'black'

            if i > 0:
                label = '<b>(reuse)</b><br/>' + label
                style += ',dotted'
                fillcolor = '#FFEACC'
                fontcolor = '#565858'

            lines.append(f'{node_id}{i} [label=<{label}>'
                         ' shape=rect,'
                         f' style="{style}",'
                         f' fontcolor="{fontcolor}",'
                         ' color="#FF8A4F",'
                         f' fillcolor="{fillcolor}"];')

    for value in captures:
        node_id = id(value)
        if value.size == 1:
            label = f'<b>{value.item()}</b>'
        else:
            label = f'<b>{_format_val(value)}</b>'

        lines.append(f'{node_id} [label=<{label}>'
                     ' shape=rect,'
                     ' style="filled",'
                     ' fontcolor="black",'
                     ' color="#A261FF",'
                     ' fillcolor="#E6D6FF"];')

    lines.append('} // digraph G')
    return '\n'.join(lines) + '\n'