def param_reduce(model, log=False):
    """Return a dict containing param counts per submodule.

    Every leaf ``x`` of the parameter tree returned by ``get_params(model)``
    adds ``x.size`` to each *proper* prefix of its path. The empty tuple
    ``()`` therefore accumulates the grand total, and e.g. ``('mlp',)``
    accumulates everything stored under ``mlp``. The full leaf path itself is
    not recorded.

    Args:
      model: Object understood by ``get_params``.
      log: If true, log every ``(prefix, count)`` pair in sorted prefix order.

    Returns:
      A ``collections.defaultdict(int)`` mapping path-prefix tuples to counts.
    """
    params = get_params(model)
    sizes = collections.defaultdict(int)
    for path, x in tree.flatten_with_path(params):
        size = x.size
        # range(len(path)) yields the proper prefixes path[:0] .. path[:-1].
        for i in range(len(path)):
            sizes[path[:i]] += size
    # Sorting is only needed for the log output, so skip it entirely when
    # logging is disabled instead of testing `log` once per key.
    if log:
        for k in sorted(sizes):
            logging.info('%s: %s', k, sizes[k])
    return sizes
def fast_map_structure_with_path(func, *structure):
    """Faster map_structure_with_path implementation.

    Calls ``func(path, leaf_0, leaf_1, ...)`` for every leaf position.
    ``path`` and ``leaf_0`` come from ``tree.flatten_with_path(structure[0])``.
    The remaining leaves are taken positionally from ``tree.flatten`` of the
    other structures. Unlike ``tree.map_structure_with_path``, the nesting of
    the structures is not compared, only their leaf counts, which keeps this
    fast.

    Args:
      func: Callable receiving the path followed by one leaf per structure.
      *structure: One or more nested structures with the same number of
        leaves.

    Returns:
      The results of ``func`` packed into the structure of ``structure[-1]``.

    Raises:
      ValueError: If no structure is given, or if the structures do not all
        have the same number of leaves.
    """
    if not structure:
        raise ValueError('At least one structure must be provided.')
    head_entries_with_path = tree.flatten_with_path(structure[0])
    if len(structure) > 1:
        tail_entries = [tree.flatten(s) for s in structure[1:]]
        # zip() would silently truncate to the shortest input. If the last
        # structure happened to have the truncated length, unflatten_as would
        # then succeed and return a wrong result without any error.
        num_leaves = len(head_entries_with_path)
        for i, leaves in enumerate(tail_entries, start=1):
            if len(leaves) != num_leaves:
                raise ValueError(
                    f'Structure {i} has {len(leaves)} leaves but structure 0 '
                    f'has {num_leaves}.')
        # e[0] is the (path, head_leaf) pair; e[1:] are the tail leaves.
        entries_with_path = [
            e[0] + e[1:] for e in zip(head_entries_with_path, *tail_entries)
        ]
    else:
        entries_with_path = head_entries_with_path
    # Arbitrarily choose one of the structures of the original sequence (the
    # last) to match the structure for the flattened sequence.
    return tree.unflatten_as(structure[-1], [func(*x) for x in entries_with_path])
def loss_fn(forward, params, state, batch, l2=True, weight_decay=5e-4):
    """Computes a regularized loss for the given batch.

    Args:
      forward: Transformed model whose ``apply`` returns ``(logits, state)``.
      params: Model parameters. Every path from ``tree.flatten_with_path``
        must unpack as ``(module_name, param_name)``.
      state: Model state passed to, and returned by, ``forward.apply``.
      batch: Input batch. ``batch[1]`` holds the class indices that are
        one-hot encoded over ``CLASS_NUM`` classes.
      l2: Whether to add the L2 penalty on parameters whose module name does
        not contain ``'batchnorm'``.
      weight_decay: Coefficient of the L2 penalty. Defaults to ``5e-4``, the
        value that used to be hard-coded.

    Returns:
      ``(loss, (loss, new_state, accuracy))``. Both ``loss`` entries are the
      same value and include the L2 term when ``l2`` is true.
    """
    logits, state = forward.apply(
        params, state, None, batch, is_training=True)
    labels = jax.nn.one_hot(batch[1], CLASS_NUM)
    logits = logits.reshape(len(labels), CLASS_NUM)  # match labels shape
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    acc = (labels.argmax(1) == logits.argmax(1)).mean()
    if l2:
        l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
                     if 'batchnorm' not in mod_name]
        loss = loss + weight_decay * l2_loss(l2_params)
    return loss, (loss, state, acc)
def loss_fn(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, hk.State]:
    """Computes a regularized loss for the given batch.

    The data term is the mean ``softmax_cross_entropy`` against one-hot labels
    over 1000 classes. The penalty is ``FLAGS.train_weight_decay`` times
    ``l2_loss`` of every parameter whose module name does not contain
    ``'batchnorm'``.

    Returns:
      A ``(loss, new_state)`` pair, where ``new_state`` is the state returned
      by ``forward.apply``.
    """
    logits, new_state = forward.apply(params, state, None, batch, is_training=True)
    one_hot_labels = jax.nn.one_hot(batch['labels'], 1000)
    per_example = softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    data_term = jnp.mean(per_example)

    # Collect the weights to decay, skipping anything under a batchnorm module.
    decayed = []
    for (module_name, _), value in tree.flatten_with_path(params):
        if 'batchnorm' in module_name:
            continue
        decayed.append(value)
    penalty = FLAGS.train_weight_decay * l2_loss(decayed)

    return data_term + penalty, new_state
def update_table(self, debug_namedtuple):
    """Adds one batch of debug tensors to ``self.mdf``; maybe writes it to CSV.

    Processing steps:

    * Each leaf of the observations / samples / gamestate / distances nests
      is squeezed and converted to a float DataFrame.
    * Each leaf frame is melted into long format with columns ``Time``,
      ``Batch`` and ``'<nest> <leaf path>'``.
    * The per-leaf value columns are gathered into one frame.
    * ``Time`` is offset by ``self.time``.
    * A ``Loss`` column is added as the NaN-ignoring row sum of every column
      whose name contains ``'dist '``.

    The frame is then appended to ``self.mdf``. Once ``self.mdf`` has more
    than ``self.table_length`` rows, it is written to
    ``self.saved_model_path + '<table_name>.csv'``.

    NOTE(review): assumes each squeezed leaf is 2-D (time, batch) and that all
    leaves share that shape, since columns are aligned by index. Confirm with
    callers.

    Args:
      debug_namedtuple: Object with ``observations``, ``samples``,
        ``gamestate`` and ``distances`` attributes holding nested tensors.

    Returns:
      True if the table was written to disk on this call, else False.
    """
    nests = {
        'obs': debug_namedtuple.observations,
        'sample': debug_namedtuple.samples,
        'gamestate': debug_namedtuple.gamestate,
        'dist': debug_namedtuple.distances,
    }
    path_delim = '.'
    df = None
    for nest_name, nest in nests.items():
        # E.g., skip dist for controller_heads that do not yet provide
        # controller component-wise distance.
        if nest_name == 'dist' and not tf.nest.is_nested(nest):
            continue
        for path, array in tree.flatten_with_path(nest):
            leaf_name = path_delim.join([str(subpath) for subpath in path])
            wide_single_df = pd.DataFrame(
                tf.squeeze(array).numpy().astype(float))
            wide_single_df['Time'] = wide_single_df.index
            col_nm = f'{nest_name} {leaf_name}'
            long_df = pd.melt(wide_single_df, id_vars=['Time'],
                              var_name='Batch', value_name=col_nm)
            if df is not None:
                df[col_nm] = long_df[col_nm]
            else:
                df = long_df
    df['Time'] += self.time
    # Advance the clock by the number of time steps in the last leaf frame.
    self.time += len(wide_single_df)
    df['Loss'] = df[[col for col in df.columns if 'dist ' in col]].agg(
        np.nansum, axis='columns')
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat is the documented replacement with the same row semantics.
    self.mdf = pd.concat([self.mdf, df])
    done = False
    if len(self.mdf) > self.table_length:
        out_fn = self.saved_model_path + f'{self.table_name}.csv'
        self.mdf.to_csv(out_fn)
        print(f'Wrote {len(self.mdf)} frames to {out_fn}')
        done = True
    return done
def flatten_with_name(structure: Any) -> List[Tuple[str, Any]]:
    """Flattens ``structure`` into ``(name, value)`` pairs.

    Each name is the leaf's ``tree.flatten_with_path`` path, with every
    element converted via ``str`` and joined by ``'/'`` (for example
    ``'a/0/b'``). Pairs appear in ``tree.flatten_with_path`` order.

    Args:
      structure: A potentially nested structure.

    Returns:
      A ``list`` of ``(name, value)`` tuples.
    """
    named = []
    for path, value in tree.flatten_with_path(structure):
        name = '/'.join(str(part) for part in path)
        named.append((name, value))
    return named
def _flatten_nested_dict(struct: Dict[str, Any]) -> Dict[str, Any]:
    """Flattens a given nested structure of tensors, sorting by flattened keys.

    For example, the nested dictionary {'d': 3, 'a': {'b': 1, 'c': 2}} is
    turned into the ordered dictionary {'a/b': 1, 'a/c': 2, 'd': 3}. Lists are
    unpacked too, so {'a': [3, 4, 5]} becomes the ordered dictionary
    {'a/0': 3, 'a/1': 4, 'a/2': 5}. The values of the flattened dictionary are
    the leaf node tensors of the original struct.

    Keys are sorted as strings, so e.g. 'a/10' sorts before 'a/2'. If two
    leaves flatten to the same key (e.g. {'a/b': 1, 'a': {'b': 2}}), the one
    that comes later in flatten order wins.

    Args:
      struct: A nested dictionary.

    Returns:
      A `collections.OrderedDict` representing a flattened version of `struct`.
    """
    flat_struct = [('/'.join(map(str, path)), item)
                   for path, item in tree.flatten_with_path(struct)]
    # Sort by the flattened key only. Sorting the (key, item) tuples would fall
    # back to comparing leaf values whenever two keys are equal, which can
    # raise for tensors. The stable key-only sort keeps such collisions in
    # flatten order instead.
    return collections.OrderedDict(sorted(flat_struct, key=lambda kv: kv[0]))
def loss_fn(
    params: hk.Params,
    state: hk.State,
    loss_scale: jmp.LossScale,
    batch: dataset.Batch,
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, hk.State]]:
    """Computes a regularized loss for the given batch.

    Labels are one-hot encoded over 1000 classes and, when
    ``FLAGS.train_smoothing`` is truthy, smoothed with
    ``optax.smooth_labels``. The loss is the mean softmax cross-entropy plus
    ``FLAGS.train_weight_decay`` times ``l2_loss`` of every parameter whose
    module name does not contain ``'batchnorm'``.

    Returns:
      ``(loss_scale.scale(loss), (loss, new_state))``. Only the first element
      is scaled; the auxiliary loss is the unscaled value.
    """
    logits, new_state = forward.apply(params, state, None, batch, is_training=True)
    targets = jax.nn.one_hot(batch['labels'], 1000)
    smoothing = FLAGS.train_smoothing
    if smoothing:
        targets = optax.smooth_labels(targets, smoothing)
    xent = optax.softmax_cross_entropy(logits=logits, labels=targets).mean()

    non_bn_params = [
        leaf for (module_name, _), leaf in tree.flatten_with_path(params)
        if 'batchnorm' not in module_name
    ]
    total = xent + FLAGS.train_weight_decay * l2_loss(non_bn_params)
    return loss_scale.scale(total), (total, new_state)
def eval_hook(generator, discriminator, server_state, round_num):
    """Called during TFF GAN IterativeProcess, to compute eval metrics.

    Steps performed:

    * Draws one batch each from ``gen_inputs_iter`` and ``real_images_iter``.
    * Saves generated images every ``rounds_per_save_images`` rounds.
    * Computes the eval metrics.
    * Writes every flattened metric (``eval/...`` and ``counters/...``) as a
      scalar summary at step ``round_num``.
    * Prints the server counters and logs how long all of this took.
    """
    start_time = time.time()

    gen_inputs = next(gen_inputs_iter)
    real_images = next(real_images_iter)

    if round_num % rounds_per_save_images == 0:
        image_prefix = 'emnist_tff_gen_images_step_{:05d}'.format(round_num)
        _save_images(
            generator,
            gen_inputs,
            outdir=path_to_output_images,
            file_prefix=image_prefix)

    metrics = {
        'eval': _compute_eval_metrics(
            generator, discriminator, gen_inputs, real_images, gan_loss_fns,
            emnist_classifier_for_metrics),
        'counters': server_state.counters,
    }

    # Flatten to '/'-joined names so each leaf becomes one summary scalar.
    flat_metrics = collections.OrderedDict(
        ('/'.join(map(str, path)), value)
        for path, value in tree.flatten_with_path(metrics))
    with summary_writer.as_default():
        for name, value in flat_metrics.items():
            tf.summary.scalar(name, value, step=round_num)

    for counter_name, counter_value in server_state.counters.items():
        print('{:>40s} {:8.0f}'.format(counter_name, counter_value), flush=True)
    logging.info('Doing evaluation took %.2f seconds.', time.time() - start_time)
def flatten_with_joined_string_paths(structure, separator='/'):
    """Returns a list of (string path, data element) tuples.

    Each path from ``dm_tree.flatten_with_path`` is turned into one string by
    applying ``str`` to every path element and joining the results with
    ``separator``. Tuples come back in the order produced by
    ``dm_tree.flatten_with_path``. This keeps track of where each leaf was
    located in the nested structure.

    Args:
      structure: the nested structure to flatten.
      separator: string placed between levels of hierarchy; defaults to '/'.

    Returns:
      A list of (string, data element) tuples.
    """
    pairs = dm_tree.flatten_with_path(structure)
    return [(separator.join(map(str, path)), leaf) for path, leaf in pairs]
def __call__(self, train_metrics, eval_metrics, round_num):
    """A function suitable for passing as an eval hook to the training_loop.

    The metrics are flattened into '/'-joined names (e.g. 'train/loss').
    They are then logged, written as scalar summaries at step `round_num`, and
    appended as one row to the CSV at `self._results_file`. If that file
    already exists, only its first `round_num` rows are kept before the new
    row is added.

    Args:
      train_metrics: A `dict` of training metrics computed in TFF.
      eval_metrics: A `dict` of evaluation metrics computed in TFF.
      round_num: The current round number.
    """
    metrics = {
        'train': train_metrics,
        'eval': eval_metrics,
        'round': round_num,
    }
    flat_metrics = tree.flatten_with_path(metrics)
    flat_metrics = [('/'.join(map(str, path)), item)
                    for path, item in flat_metrics]
    flat_metrics = collections.OrderedDict(flat_metrics)

    logging.info('Evaluation at round {:d}:\n{!s}'.format(
        round_num, pprint.pformat(flat_metrics)))

    # Also write metrics to a tf.summary logdir
    with self._summary_writer.as_default():
        for name, value in flat_metrics.items():
            tf.compat.v2.summary.scalar(name, value, step=round_num)

    new_row = pd.DataFrame(flat_metrics, index=[0])
    if tf.io.gfile.exists(self._results_file):
        metrics = pd.read_csv(self._results_file, header=0, index_col=0,
                              engine='c')
        # Remove everything after `round_num`, in case the experiment was
        # restarted at an earlier checkpoint we want to avoid duplicate metrics.
        metrics = metrics[:round_num]
        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
        # concatenating a one-row frame is the documented replacement.
        metrics = pd.concat([metrics, new_row], ignore_index=True)
    else:
        metrics = new_row
    utils_impl.atomic_write_to_csv(metrics, self._results_file)
def append(self, data: Any, *, partial_step: bool = False):
    """Columnwise append of data leaf nodes to internal buffers.

    If `data` includes fields or sub structures which haven't been present in
    any previous calls, then the types and shapes of the new fields are
    extracted and used to validate future `append` calls. The structure of
    `history` is also updated to include the union of the structure across all
    `append` calls.

    When new fields are added after the first step, the newly created history
    field is filled with `None` in all preceding positions. This results in
    equal indexing across columns: `a[i]` and `b[i]` reference the same step
    in the sequence even if `b` was first observed after `a` had already been
    seen.

    It is possible to create a "step" using more than one `append` call by
    setting the `partial_step` flag. Partial steps can be used when some parts
    of the step become available only as a result of inserting (and learning
    from) trajectories that include the fields available first (e.g. learn
    from the SARS trajectory to select the next action in an on-policy agent).
    In the final `append` call of the step, `partial_step` must be set to
    False. Failing to "close" the partial step will result in an error, as the
    same field must NOT be provided more than once in the same step.

    Args:
      data: The (possibly nested) structure to make available for new items to
        reference.
      partial_step: If `True` then the step is not considered "done" with this
        call. See above for more details. Defaults to `False`.

    Returns:
      References to the data, structured just like the provided `data`. Only
      fields present in `data` appear in the output.

    Raises:
      ValueError: If the same column is provided more than once in the same
        step.
    """
    # If no structure has been recorded yet (presumably the first call),
    # initialise it from the shape of `data` with every leaf set to None.
    if self._structure is None:
        self._update_structure(tree.map_structure(lambda _: None, data))

    data_with_path_flat = tree.flatten_with_path(data)
    try:
        # Use our custom mapping to flatten the expanded structure into columns.
        flat_column_data = self._reorder_like_flat_structure(data_with_path_flat)
    except KeyError:
        # `data` contains fields which haven't been observed before, so we need
        # to expand the spec using the union of the history and `data`.
        self._update_structure(
            _tree_union(self._structure,
                        tree.map_structure(lambda x: None, data)))

        flat_column_data = self._reorder_like_flat_structure(data_with_path_flat)

    # If the last step is still open, verify that already populated columns
    # are None in the new `data`.
    if self._last_step_is_open:
        for i, (column, column_data) in enumerate(
                zip(self._column_history, flat_column_data)):
            if column_data is None or column.can_set_last:
                continue

            raise ValueError(
                f'Field {self._get_path_for_column_index(i)} has already been set '
                f'in the active step by previous (partial) append call and thus '
                f'must be omitted or set to None but got: {column_data}')

    # Flatten the data and pass it to the C++ writer for column-wise append.
    # Every column where data is provided (i.e. not None) returns a reference
    # to the data (`pybind.WeakCellRef`), which is used to define trajectories
    # for `create_item`. Columns which did not receive a value (i.e. None)
    # return None.
    if partial_step:
        flat_column_data_references = self._writer.AppendPartial(
            flat_column_data)
    else:
        flat_column_data_references = self._writer.Append(flat_column_data)

    # Append references to the respective columns. Note that we use the
    # expanded structure in order to populate the columns missing from the
    # data with None.
    for column, data_reference in zip(self._column_history,
                                      flat_column_data_references):
        # If the last step is still open (i.e. `partial_step` was set on the
        # previous call), populate that step instead of creating a new one.
        if not self._last_step_is_open:
            column.append(data_reference)
        elif data_reference is not None:
            column.set_last(data_reference)

    # Save the flag so the next `append` call either populates the same step
    # or begins a new step.
    self._last_step_is_open = partial_step

    # Unpack the column data into the expanded structure.
    expanded_structured_data_references = self._unflatten(
        flat_column_data_references)

    # Return the references structured in the same way as `data`. If only a
    # subset of the fields were present in the input data, then only these
    # fields will exist in the output.
    filtered_data_references_flat = _tree_filter(
        expanded_structured_data_references, data_with_path_flat)
    return tree.unflatten_as(data, filtered_data_references_flat)
def flatten_with_tuple_paths(structure, expand_composites=False):
    """Flattens ``structure`` into ``(path_tuple, leaf)`` pairs via dm-tree.

    Args:
      structure: Nested structure to flatten.
      expand_composites: Not supported in JAX; any truthy value raises.

    Returns:
      The list returned by ``dm_tree.flatten_with_path(structure)``.

    Raises:
      NotImplementedError: If ``expand_composites`` is truthy.
    """
    if not expand_composites:
        return dm_tree.flatten_with_path(structure)
    raise NotImplementedError(
        '`expand_composites=True` is not supported in JAX.')
def testFlattenWithPath(self, test_type):
    """tree.flatten_with_path on the dataclass-with-map fixture gives the expected pairs."""
    self._init_testdata(test_type)
    actual = tree.flatten_with_path(self.dcls_with_map)
    expected = self.dcls_flattened_with_path
    self.assertEqual(actual, expected)
def _flatten(self, data):
    """Scatters the leaves of ``data`` into column order.

    Returns a list with one slot per entry of ``self._path_to_column_index``.
    Each leaf of ``data`` is stored at the index its path maps to; slots with
    no matching leaf stay ``None``. A leaf path missing from the mapping raises
    ``KeyError``.
    """
    column_of = self._path_to_column_index
    columns = [None for _ in range(len(column_of))]
    for leaf_path, leaf in tree.flatten_with_path(data):
        columns[column_of[leaf_path]] = leaf
    return columns
def flatten_with_tuple_paths(structure):
    """Alias of ``flatten_with_path``: returns its result for ``structure`` unchanged."""
    flat = flatten_with_path(structure)
    return flat
def testFlattenWithPathUpTo(self):
    """Checks tree.flatten_with_path_up_to paths and values, including error cases."""

    def get_paths_and_values(shallow_tree, input_tree):
        # Splits flatten_with_path_up_to output into parallel path/value lists.
        path_value_pairs = tree.flatten_with_path_up_to(shallow_tree, input_tree)
        paths = [p for p, _ in path_value_pairs]
        values = [v for _, v in path_value_pairs]
        return paths, values

    # Shallow tree ends at scalar.
    # Flattening stops at the shallow tree's leaves, so the innermost input
    # lists are returned whole.
    input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow_tree = [[True, True], [False, True]]
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths,
                     [(0, 0), (0, 1), (1, 0), (1, 1)])
    self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(flattened_shallow_tree_paths,
                     [(0, 0), (0, 1), (1, 0), (1, 1)])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

    # Shallow tree ends at string.
    # Strings in the shallow tree are leaves, so the matching input tuples come
    # back intact. The full flatten_with_path below descends into them for
    # comparison.
    input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    input_tree_flattened_paths = [
        p for p, _ in tree.flatten_with_path(input_tree)
    ]
    input_tree_flattened = tree.flatten(input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])

    self.assertEqual(input_tree_flattened_paths,
                     [(0, 0, 0), (0, 0, 1),
                      (0, 1, 0, 0), (0, 1, 0, 1),
                      (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
                      (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Make sure dicts are correctly flattened, yielding values, not keys.
    input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [("a",), ("b",), ("d", 0), ("d", 1)])
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [1, {"c": 2}, 3, (4, 5)])

    # Namedtuples.
    # Field names become the path elements.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    input_tree = ab_tuple(a=[0, 1], b=2)
    shallow_tree = ab_tuple(a=0, b=1)
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [("a",), ("b",)])
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [[0, 1], 2])

    # Nested dicts, OrderedDicts and namedtuples.
    # First against itself (full depth), then against progressively shallower
    # trees that cut the input off at higher levels.
    input_tree = collections.OrderedDict(
        [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
         ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
    shallow_tree = input_tree
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [("a", "a", 0),
                      ("a", "a", 1, "b"),
                      ("a", "b"),
                      ("c", "d"),
                      ("c", "e", "f")])
    self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
    shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [("a",),
                      ("c", "d"),
                      ("c", "e")])
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      3,
                      collections.OrderedDict([("f", 4)])])
    shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
    (input_tree_flattened_as_shallow_tree_paths,
     input_tree_flattened_as_shallow_tree) = get_paths_and_values(
         shallow_tree, input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                     [("a",),
                      ("c",)])
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

    ## Shallow non-list edge-case.
    # A non-sequence shallow tree is a single leaf at path (), so the whole
    # input is returned as one value.
    # Using iterable elements.
    input_tree = ["input_tree"]
    shallow_tree = "shallow_tree"
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = ["input_tree_0", "input_tree_1"]
    shallow_tree = "shallow_tree"
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Test case where len(shallow_tree) < len(input_tree)
    # A dict shallow tree with fewer keys than the input raises ValueError.
    input_tree = {"a": "A", "b": "B", "c": "C"}
    shallow_tree = {"a": 1, "c": 2}

    with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
        ValueError,
        tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
            input_length=len(input_tree),
            shallow_length=len(shallow_tree))):
        get_paths_and_values(shallow_tree, input_tree)

    # Using non-iterable elements.
    input_tree = [0]
    shallow_tree = 9
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = [0, 1]
    shallow_tree = 9
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Both non-list edge-case.
    # Both trees are single leaves, so each flattens to one value at path ().
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = "shallow_tree"
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = 0
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Input non-list edge-case.
    # When the shallow tree is a sequence, the input must be one too; a str or
    # int input raises TypeError. The assignment inside each `with` therefore
    # never completes. Flattening the shallow tree against itself is still
    # checked afterwards.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = ["shallow_tree"]
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, [(0,)])
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = "input_tree"
    shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = [9]
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, [(0,)])
    self.assertEqual(flattened_shallow_tree, shallow_tree)

    input_tree = 0
    shallow_tree = [9, 8]
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
        (flattened_input_tree_paths,
         flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
    self.assertEqual(flattened_shallow_tree, shallow_tree)
def _graph_to_dot(graph: Graph, args, outputs) -> str:
    """Converts from an internal graph IR to 'dot' format.

    Rendering rules:

    * Nodes of ``graph`` become rectangles. Subgraphs are rendered recursively
      as nested clusters.
    * Nodes whose id matches a leaf of ``outputs`` are highlighted and
      labelled with their output path.
    * Each leaf of ``args`` that is the source of at least one edge gets one
      node per use. Later uses are marked "(reuse)" and linked to the first by
      a CSS hover effect.
    * Edge sources that are neither args nor node outputs are drawn as
      captured values.

    Args:
      graph: Root of the graph IR to render.
      args: The (nested) arguments the graph was traced with.
      outputs: The (nested) outputs of the traced function.

    Returns:
      The DOT source text, terminated by a newline.

    Raises:
      ImportError: If the optional ``tree`` (dm-tree) dependency is missing.
    """
    if tree is None:
        raise ImportError('hk.experimental.to_dot requires dm-tree>=0.1.1.')

    def format_path(path):
        # Human-readable name for an output leaf; tuple outputs are indexed.
        if isinstance(outputs, tuple):
            out = f'output[{path[0]}]'
            if len(path) > 1:
                out += ': ' + '/'.join(map(str, path[1:]))
        else:
            out = 'output'
            if path:
                out += ': ' + '/'.join(map(str, path))
        return out

    lines = []
    used_argids = set()
    argid_usecount = collections.Counter()
    op_outids = set()
    captures = []
    argids = {id(v) for v in jax.tree_leaves(args)}
    outids = {id(v) for v in jax.tree_leaves(outputs)}
    outname = {
        id(v): format_path(p) for p, v in tree.flatten_with_path(outputs)
    }

    def render_graph(g: Graph, parent: Optional[Graph] = None, depth: int = 0):
        """Renders a given graph by appending 'dot' format lines."""
        if parent:
            lines.extend([
                f'subgraph cluster_{id(g)} {{',
                ' style="rounded,filled";',
                ' fillcolor="#F0F5F5";',
                # Plain hex colour. The old value '#14234B;' had a stray ';'
                # inside the quotes, which is not a valid Graphviz colour.
                ' color="#14234B";',
                ' pad=0.1;',
                f' fontsize={_scaled_font_size(depth)};',
                f' label = <<b>{escape(g.title)}</b>>;',
                ' labelloc = t;',
            ])

        for node in g.nodes:
            label = f'<b>{escape(node.title)}</b>'
            for o in node.outputs:
                label += '<br/>' + _format_val(o)
                op_outids.add(id(o))

            node_id = id(node.id)
            if node_id in outids:
                label = f'<b>{escape(outname[node_id])}</b><br/>' + label
                color = '#0053D6'
                fillcolor = '#AABFFF'
                style = 'filled,bold'
            else:
                color = '#FFDB13'
                fillcolor = '#FFF26E'
                style = 'filled'

            lines.append(f'{node_id} [label=<{label}>, '
                         f' id="node{node_id}",'
                         ' shape=rect,'
                         f' style="{style}",'
                         ' tooltip=" ",'
                         ' fontcolor="black",'
                         f' color="{color}",'
                         f' fillcolor="{fillcolor}"];')

        for s in g.subgraphs:
            render_graph(s, parent=g, depth=depth - 1)

        if parent:
            lines.append(f'}} // subgraph cluster_{id(g)}')

        for a, b in g.edges:
            # Sources that are neither args nor node outputs are captured values.
            if id(a) not in argids and id(a) not in op_outids:
                captures.append(a)

            a, b = map(id, (a, b))
            if a in argids:
                # Each use of an arg gets its own numbered node: '<id><use#>'.
                i = argid_usecount[a]
                argid_usecount[a] += 1
                lines.append(f'{a}{i} -> {b};')
            else:
                lines.append(f'{a} -> {b};')
            used_argids.add(a)

    graph_depth = _max_depth(graph)
    render_graph(graph, parent=None, depth=graph_depth)

    # Process inputs and label them in the graph.
    for path, value in tree.flatten_with_path(args):
        if value is None:
            continue

        node_id = id(value)
        if node_id not in used_argids:
            continue

        # Path components may be ints (sequence indices) or arbitrary dict
        # keys. Build the text first and HTML-escape all of it, just as output
        # names are escaped above; previously only path[0] was escaped.
        path_text = f'args[{path[0]}]'
        if len(path) > 1:
            path_text += ': ' + '/'.join(map(str, path[1:]))

        for i in range(argid_usecount[node_id]):
            label = f'<b>{escape(path_text)}</b>'
            if hasattr(value, 'shape') and hasattr(value, 'dtype'):
                label += f'<br/>{escape(_format_val(value))}'
            fillcolor = '#FFDEAF'
            fontcolor = 'black'

            if i > 0:
                label = '<b>(reuse)</b><br/>' + label
                fillcolor = '#FFEACC'
                fontcolor = '#565858'

            lines.append(f'{node_id}{i} [label=<{label}>'
                         f' id="node{node_id}{i}",'
                         ' shape=rect,'
                         ' style="filled",'
                         f' fontcolor="{fontcolor}",'
                         ' color="#FF8A4F",'
                         f' fillcolor="{fillcolor}"];')

    for value in captures:
        node_id = id(value)
        # Single-element concrete values (no `aval`) are shown by value.
        if (not hasattr(value, 'aval') and
                hasattr(value, 'size') and
                value.size == 1):
            label = f'<b>{value.item()}</b>'
        else:
            label = f'<b>{escape(_format_val(value))}</b>'

        lines.append(f'{node_id} [label=<{label}>'
                     ' shape=rect,'
                     ' style="filled",'
                     ' fontcolor="black",'
                     ' color="#A261FF",'
                     ' fillcolor="#E6D6FF"];')

    head = [
        'digraph G {',
        'rankdir = TD;',
        'compound = true;',
        f'label = <<b>{escape(graph.title)}</b>>;',
        f'fontsize={_scaled_font_size(graph_depth)};',
        'labelloc = t;',
        'stylesheet = <',
        ' data:text/css,',
        ' @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);',
        ' svg text {',
        ' font-family: \'Roboto\';',
        ' }',
        ' .node text {',
        ' font-size: 12px;',
        ' }',
    ]
    for node_id, use_count in argid_usecount.items():
        if use_count == 1:
            continue
        # Add hover animation for reused args.
        for a in range(use_count):
            for b in range(use_count):
                if a == b:
                    head.append(f'%23node{node_id}{a}:hover '
                                '{ stroke-width: 0.2em; }')
                else:
                    head.append(
                        f'%23node{node_id}{a}:hover ~ %23node{node_id}{b} '
                        '{ stroke-width: 0.2em; }')
    head.append('>')

    lines.append('} // digraph G')
    return '\n'.join(head + lines) + '\n'
def secure_sum_then_finalize(
    metric_finalizers: model_lib.MetricFinalizersType,
    local_unfinalized_metrics_type: computation_types.StructWithPythonType,
    metric_value_ranges: Optional[MetricValueRangeDict] = None
) -> computation_base.Computation:
  """Creates a TFF computation that aggregates metrics using secure summation.

  The returned federated TFF computation has the following type signature:

  ```
  (local_unfinalized_metrics@CLIENTS ->
   <aggregated_metrics@SERVER, secure_sum_measurements@SERVER>)
  ```

  where the input is given by
  `tff.learning.Model.report_local_unfinalized_metrics()` at `CLIENTS`, and the
  first output (`aggregated_metrics`) is computed by first securely summing the
  unfinalized metrics from `CLIENTS`, followed by applying the finalizers at
  `SERVER`. The second output (`secure_sum_measurements`) is an `OrderedDict`
  that maps from `factory_key`s to the secure summation measurements (e.g. the
  number of clients whose values got clipped. See
  `tff.aggregators.SecureSumFactory` for details). A `factory_key` is uniquely
  defined by three scalars: lower bound, upper bound, and tensor dtype (denoted
  as datatype enum). Metric values with the same `factory_key` are grouped and
  aggregated together (and hence, the `secure_sum_measurements` are also
  computed at a group level).

  Since secure summation works in fixed-point arithmetic space, floating point
  numbers must be encoded using integer quantization. By default, each tensor
  in `local_unfinalized_metrics_type` will be clipped to `[0, 2**20 - 1]` and
  encoded to integers inside `tff.aggregators.SecureSumFactory`. Callers can
  change this range by setting `metric_value_ranges`, which may be a partial
  tree matching the structure of `local_unfinalized_metrics_type`.

  Example partial value range specification:

  >>> finalizers = ...
  >>> metrics_type = tff.to_type(collections.OrderedDict(
  ...     a=tff.types.TensorType(tf.int32),
  ...     b=tff.types.TensorType(tf.float32),
  ...     c=[tff.types.TensorType(tf.float32),
  ...        tff.types.TensorType(tf.float32)]))
  >>> value_ranges = collections.OrderedDict(
  ...     b=(0.0, 1.0),
  ...     c=[None, (0.0, 1.0)])
  >>> aggregator = tff.learning.metrics.secure_sum_then_finalize(
  ...     finalizers, metrics_type, value_ranges)

  This sets the range of `b` and of the *second* tensor of `c`. The first
  tensor of `c` and the `a` tensor fall back to the default range.

  Args:
    metric_finalizers: An `OrderedDict` of `string` metric names to finalizer
      functions returned by `tff.learning.Model.metric_finalizers()`. It should
      have the same keys (i.e., metric names) as the `OrderedDict` returned by
      `tff.learning.Model.report_local_unfinalized_metrics()`. A finalizer is a
      callable (typically a `tf.function` or `tff.tf_computation` decorated
      function) that takes in a metric's unfinalized values, and returns the
      finalized values.
    local_unfinalized_metrics_type: A `tff.types.StructWithPythonType` (with
      `OrderedDict` as the Python container) of a client's local unfinalized
      metrics. Let `local_unfinalized_metrics` be the output of
      `tff.learning.Model.report_local_unfinalized_metrics()`. Its type can be
      obtained by `tff.framework.type_from_tensors(local_unfinalized_metrics)`.
    metric_value_ranges: A `collections.OrderedDict` that matches the structure
      of `local_unfinalized_metrics_type` (a value for each
      `tff.types.TensorType` in the type tree). Each leaf in the tree should
      have a 2-tuple that defines the range of expected values for that
      variable in the metric. If the entire structure is `None`, a default
      range of `[0.0, 2.0**20 - 1]` will be applied to all variables. Each leaf
      may also be `None`, which will also get the default range; this allows
      partial user specialization. At runtime, values that fall outside the
      ranges specified at the leaves will be clipped to within the range.

  Returns:
    A federated TFF computation that securely sums the unfinalized metrics from
    `CLIENTS`, and applies the corresponding finalizers at `SERVER`.

  Raises:
    TypeError: If the inputs are of the wrong types, or if a complete value
      range structure cannot be built from `metric_value_ranges`.
    ValueError: If the keys (i.e., metric names) in `metric_finalizers` are not
      the same as those expected by `local_unfinalized_metrics_type`.
  """
  check_metric_finalizers(metric_finalizers)
  check_local_unfinalzied_metrics_type(local_unfinalized_metrics_type)
  check_finalizers_matches_unfinalized_metrics(
      metric_finalizers, local_unfinalized_metrics_type)

  default_metric_value_ranges = create_default_secure_sum_quantization_ranges(
      local_unfinalized_metrics_type)
  if metric_value_ranges is None:
    metric_value_ranges = default_metric_value_ranges

  # Walk the incoming `metric_value_ranges` and `default_metric_value_ranges`
  # and fill in any missing ranges using the defaults.
  def fill_missing_values_with_defaults(default_values, user_values):
    """Merges user ranges into the default-range structure.

    Returns a structure shaped like `default_values` whose leaves are
    `_MetricRange` objects. A user-provided leaf wins (after `_check_range`);
    a missing (`None`) user leaf or subtree falls back to the default.

    NOTE(review): user keys that are absent from `default_values` are silently
    ignored. A user list shorter than the default list raises `IndexError`,
    which the caller's `except TypeError` below does not catch.
    """
    if isinstance(default_values, collections.abc.Mapping):
      if user_values is None:
        user_values = {}
      # Rebuild with the same mapping type as the defaults (e.g. OrderedDict).
      return type(default_values)(
          (key,
           fill_missing_values_with_defaults(default_value,
                                             user_values.get(key)))
          for key, default_value in default_values.items())
    elif isinstance(default_values, list):
      if user_values is None:
        user_values = [None] * len(default_values)
      return [
          fill_missing_values_with_defaults(default_value, user_values[idx])
          for idx, default_value in enumerate(default_values)
      ]
    elif user_values is None:
      return _MetricRange(*default_values)
    else:
      _check_range(user_values)
      return _MetricRange(*user_values)

  try:
    metric_value_ranges = fill_missing_values_with_defaults(
        default_metric_value_ranges, metric_value_ranges)
  except TypeError as e:
    raise TypeError('Failed to create encoding value range from: '
                    f'{metric_value_ranges}') from e

  # Create an aggregator factory for each unique value range, rather than each
  # leaf tensor (which could introduce a lot of duplication).
  # NOTE(review): this relies on `_MetricRange` (defined elsewhere) being
  # hashable and being treated as a leaf by `tree.flatten`; confirm.
  aggregator_factories = {
      value_range: secure.SecureSumFactory(value_range.upper, value_range.lower)
      for value_range in set(tree.flatten(metric_value_ranges))
  }
  # Construct a python container of `tff.TensorType` so we can traverse it in
  # parallel with the value ranges during AggregationProcess construction.
  # Otherwise we have a `tff.Type` but `metric_value_ranges` is a Python
  # container which are difficult to traverse in parallel.
  structure_of_tensor_types = type_conversions.structure_from_tensor_type_tree(
      lambda t: t, local_unfinalized_metrics_type)

  # We will construct groups of tensors with the same dtype and quantization
  # value range so that we can construct fewer aggregations-of-structures,
  # rather than a large structure-of-aggregations. Without this, the TFF
  # compiler pipeline results in large slow downs (see b/218312198).
  factory_key_by_path = collections.OrderedDict()
  value_range_by_factory_key = collections.OrderedDict()
  path_list_by_factory_key = collections.defaultdict(list)
  # Maintain a flattened list of paths. This is useful to flatten the aggregated
  # values, which will then be used by `tf.nest.pack_sequence_as`.
  flattened_path_list = []
  # The two flattenings are zipped leaf-by-leaf; this presumes they visit leaves
  # in the same order (the ranges were rebuilt from the defaults' structure).
  for (path, tensor_spec), (_, value_range) in zip(
      tree.flatten_with_path(structure_of_tensor_types),
      tree.flatten_with_path(metric_value_ranges)):
    factory_key = _create_factory_key(value_range.lower, value_range.upper,
                                      tensor_spec.dtype)
    factory_key_by_path[path] = factory_key
    value_range_by_factory_key[factory_key] = value_range
    path_list_by_factory_key[factory_key].append(path)
    flattened_path_list.append(path)

  @tensorflow_computation.tf_computation(local_unfinalized_metrics_type)
  def group_value_by_factory_key(local_unfinalized_metrics):
    """Groups client local metrics into a map of `factory_key` to value list."""
    # We cannot use `collections.defaultdict(list)` here because its result is
    # incompatible with `structure_from_tensor_type_tree`.
    value_list_by_factory_key = collections.OrderedDict()
    for path, value in tree.flatten_with_path(local_unfinalized_metrics):
      factory_key = factory_key_by_path[path]
      if factory_key in value_list_by_factory_key:
        value_list_by_factory_key[factory_key].append(value)
      else:
        value_list_by_factory_key[factory_key] = [value]
    return value_list_by_factory_key

  def flatten_grouped_values(value_list_by_factory_key):
    """Flatten the values in the same order as in `flattened_path_list`."""
    value_by_path = collections.OrderedDict()
    for factory_key in value_list_by_factory_key:
      path_list = path_list_by_factory_key[factory_key]
      value_list = value_list_by_factory_key[factory_key]
      for path, value in zip(path_list, value_list):
        value_by_path[path] = value
    flattened_value_list = [
        value_by_path[path] for path in flattened_path_list
    ]
    return flattened_value_list

  # Create an aggregation process for each factory key.
  aggregation_process_by_factory_key = collections.OrderedDict()
  # Construct a python container of `tff.TensorType` so we can traverse it and
  # create aggregation processes from the factories.
  tensor_type_list_by_factory_key = (
      type_conversions.structure_from_tensor_type_tree(
          lambda t: t, group_value_by_factory_key.type_signature.result))
  for factory_key, tensor_type_list in tensor_type_list_by_factory_key.items():
    value_range = value_range_by_factory_key[factory_key]
    aggregation_process_by_factory_key[
        factory_key] = aggregator_factories.get(value_range).create(
            computation_types.to_type(tensor_type_list))

  @federated_computation.federated_computation(
      computation_types.at_clients(local_unfinalized_metrics_type))
  def aggregator_computation(client_local_unfinalized_metrics):
    # The secure-sum processes are driven with an empty server state.
    unused_state = intrinsics.federated_value((), placements.SERVER)

    client_local_grouped_unfinalized_metrics = intrinsics.federated_map(
        group_value_by_factory_key, client_local_unfinalized_metrics)
    metrics_aggregation_output = collections.OrderedDict()
    for factory_key, process in aggregation_process_by_factory_key.items():
      metrics_aggregation_output[factory_key] = process.next(
          unused_state, client_local_grouped_unfinalized_metrics[factory_key])

    metrics_aggregation_output = intrinsics.federated_zip(
        metrics_aggregation_output)

    @tensorflow_computation.tf_computation(
        metrics_aggregation_output.type_signature.member)
    def finalizer_computation(grouped_aggregation_output):

      # One minor downside of grouping the aggregation processes is that the
      # SecAgg measurements (e.g., clipped_count) are computed at a group level
      # (a group means all metric values belonging to the same `factory_key`).
      secure_sum_measurements = collections.OrderedDict(
          (factory_key, output.measurements)
          for factory_key, output in grouped_aggregation_output.items())
      finalized_metrics = collections.OrderedDict(
          secure_sum_measurements=secure_sum_measurements)
      grouped_unfinalized_metrics = collections.OrderedDict(
          (factory_key, output.result)
          for factory_key, output in grouped_aggregation_output.items())
      # Undo the grouping: restore the original metric structure before
      # applying the per-metric finalizers.
      flattened_unfinalized_metrics_list = flatten_grouped_values(
          grouped_unfinalized_metrics)
      unfinalized_metrics = tf.nest.pack_sequence_as(
          structure_of_tensor_types, flattened_unfinalized_metrics_list)
      for metric_name, metric_finalizer in metric_finalizers.items():
        finalized_metrics[metric_name] = metric_finalizer(
            unfinalized_metrics[metric_name])
      return finalized_metrics

    return intrinsics.federated_map(finalizer_computation,
                                    metrics_aggregation_output)

  return aggregator_computation
def flatten_with_tuple_paths(structure):
  """Flattens `structure` into `(path, leaf)` pairs via `dm_tree`.

  Args:
    structure: Any nested structure understood by `dm_tree`.

  Returns:
    The list produced by `dm_tree.flatten_with_path(structure)`; each entry
    pairs a tuple path with the leaf found at that path.
  """
  path_leaf_pairs = dm_tree.flatten_with_path(structure)
  return path_leaf_pairs
def testFlattenWithPath(self, inputs, expected):
  """Checks `tree.flatten_with_path(inputs)` against `expected`."""
  flattened = tree.flatten_with_path(inputs)
  self.assertEqual(flattened, expected)
def _get_path_for_column_index(self, column_index):
  """Returns the path, within `self._structure`, of the leaf for a column.

  `self._column_index_to_flat_structure_index` maps `column_index` to a
  position in the flattened `self._structure`; the path stored at that
  position is returned.
  """
  flat_index = self._column_index_to_flat_structure_index[column_index]
  path, _ = tree.flatten_with_path(self._structure)[flat_index]
  return path
def main(_):
  """Evaluates the pretrained target policy by rolling it out in the env.

  Loads the environment and pretrained policy network described by
  `FLAGS.problem_config`, runs the policy for `FLAGS.num_episodes` episodes,
  and writes the flattened problem config plus statistics of the discounted
  episode returns and episode lengths to a terminal logger.
  """
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  _, _, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=1)
  environment_spec = specs.make_environment_spec(environment)

  # Load pretrained target policy network.
  policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  actor = actors.FeedForwardActor(policy_network=policy_net)

  logger = loggers.TerminalLogger('ground_truth')

  discount = problem_config['discount']

  returns = []
  lengths = []

  # Report progress roughly every 10% of the episodes. Clamp to at least 1 so
  # that requesting fewer than 10 episodes does not cause a modulo by zero.
  report_every = max(1, FLAGS.num_episodes // 10)

  t_start = time.time()
  timestep = environment.reset()
  actor.observe_first(timestep)
  cur_return = 0.
  cur_step = 0
  while len(returns) < FLAGS.num_episodes:
    action = actor.select_action(timestep.observation)
    timestep = environment.step(action)
    # Have the agent observe the timestep and let the actor update itself.
    actor.observe(action, next_timestep=timestep)

    # Discounted return: the reward of step t is weighted by discount**t.
    cur_return += pow(discount, cur_step) * timestep.reward
    cur_step += 1

    if timestep.last():
      # Append return of the current episode, and reset the environment.
      returns.append(cur_return)
      lengths.append(cur_step)
      timestep = environment.reset()
      actor.observe_first(timestep)
      cur_return = 0.
      cur_step = 0

      if len(returns) % report_every == 0:
        print(
            f'Run time {time.time() - t_start:0.0f} secs, '
            f'evaluated episode {len(returns)} / {FLAGS.num_episodes}')

  # Returned data include problem configs. Path elements are stringified so
  # that non-string keys (e.g. list indices) do not break the join.
  results = {
      '_'.join(map(str, keys)): value
      for keys, value in tree.flatten_with_path(problem_config)
  }

  # And computed results.
  results.update({
      'metric_value': np.mean(returns),
      'metric_std_dev': np.std(returns, ddof=0),
      'metric_std_err': np.std(returns, ddof=0) / np.sqrt(len(returns)),
      'length_mean': np.mean(lengths),
      'length_std': np.std(lengths, ddof=0),
      'num_episodes': len(returns),
  })
  logger.write(results)
def flatten_with_joined_string_paths(structure, separator='/'):
  """Replacement for deprecated tf.nest.flatten_with_joined_string_paths.

  Args:
    structure: A nested structure to flatten with `tree.flatten_with_path`.
    separator: String placed between path elements.

  Returns:
    A list of `(joined_path, leaf)` tuples, where `joined_path` is every path
    element converted with `str` and joined with `separator`.
  """

  def _join(path):
    return separator.join(str(element) for element in path)

  return [(_join(path), leaf)
          for path, leaf in tree.flatten_with_path(structure)]
def _graph_to_dot(graph: Graph, args, outputs):
  """Converts from an internal graph IR to 'dot' format.

  Args:
    graph: The root graph to render. Its `title`, `nodes`, `subgraphs` and
      `edges` are read, and subgraphs are rendered recursively as clusters.
    args: The (nested) inputs. Each leaf that is the source of an edge is drawn
      as an input node, once per use; later uses are labelled "(reuse)".
    outputs: The (nested) outputs. Graph nodes whose `node.id` is (by identity)
      an output leaf are highlighted and labelled with their output path.

  Returns:
    The graph as Graphviz 'dot' source, terminated by a newline.
  """
  import html  # pylint: disable=g-import-not-at-top

  def esc(value):
    # Labels are Graphviz HTML-like labels (`<...>`), so `&`, `<` and `>` in
    # titles, paths and formatted values must be escaped to keep the dot valid.
    return html.escape(str(value), quote=False)

  lines = [
      'digraph G {',
      'rankdir = TD;',
      'compound = true;',
      f'label = <<b>{esc(graph.title)}</b>>;',
      'labelloc = t;',
      'stylesheet = <',
      ' data:text/css,',
      ' @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);',
      ' svg text {',
      ' font-family: \'Roboto\';',
      ' }',
      ' .node text {',
      ' font-size: 12px;',
      ' }',
      '>',
  ]

  def format_path(path):
    """Returns a human-readable name for an output leaf at `path`."""
    if isinstance(outputs, tuple):
      out = f'output[{path[0]}]'
      if len(path) > 1:
        out += ': ' + '/'.join(map(str, path[1:]))
    else:
      out = 'output'
      if path:
        out += ': ' + '/'.join(map(str, path))
    return out

  used_argids = set()
  argid_usecount = collections.Counter()
  op_outids = set()
  captures = []
  argids = {id(v) for v in jax.tree_util.tree_leaves(args)}
  outids = {id(v) for v in jax.tree_util.tree_leaves(outputs)}
  outname = {
      id(v): format_path(p) for p, v in tree.flatten_with_path(outputs)
  }

  def render_graph(g: Graph, parent: Optional[Graph] = None):
    """Renders a given graph by appending 'dot' format lines."""

    if parent:
      lines.extend([
          f'subgraph cluster_{id(g)} {{',
          ' style="rounded,filled";',
          ' fillcolor="#F0F5F5";',
          ' color="#14234B";',
          ' fontsize=14;',
          f' label = <<b>{esc(g.title)}</b>>;',
          ' labelloc = t;',
      ])

    for node in g.nodes:
      label = f'<b>{esc(node.title)}</b>'
      for o in node.outputs:
        label += '<br/>' + esc(_format_val(o))
        op_outids.add(id(o))

      node_id = id(node.id)
      if node_id in outids:
        label = f'<b>{esc(outname[node_id])}</b><br/>' + label
        color = '0053D6'
        fillcolor = 'AABFFF'
        style = 'filled,bold'
      else:
        color = 'FFDB13'
        fillcolor = 'FFF26E'
        style = 'filled'

      lines.append(f'{node_id} [label=<{label}>, '
                   ' shape=rect,'
                   f' style="{style}",'
                   ' tooltip=" ",'
                   ' fontcolor="black",'
                   f' color="#{color}",'
                   f' fillcolor="#{fillcolor}"];')

    for s in g.subgraphs:
      render_graph(s, parent=g)

    if parent:
      lines.append(f'}} // subgraph cluster_{id(g)}')

    for a, b in g.edges:
      # Edge sources that are neither inputs nor produced by an op are
      # constants captured by the traced function.
      if id(a) not in argids and id(a) not in op_outids:
        captures.append(a)

      a, b = map(id, (a, b))
      if a in argids:
        # Each use of an input gets its own node: `<id><use index>`.
        i = argid_usecount[a]
        argid_usecount[a] += 1
        lines.append(f'{a}{i} -> {b};')
      else:
        lines.append(f'{a} -> {b};')
      used_argids.add(a)

  render_graph(graph, parent=None)

  # Process inputs and label them in the graph.
  for path, value in tree.flatten_with_path(args):
    if value is None:
      continue

    node_id = id(value)
    if node_id not in used_argids:
      continue

    for i in range(argid_usecount[node_id]):
      label = f'<b>args[{esc(path[0])}]'
      if len(path) > 1:
        label += ': ' + esc('/'.join(map(str, path[1:])))
      label += '</b>'
      if hasattr(value, 'shape') and hasattr(value, 'dtype'):
        label += f'<br/>{esc(_format_val(value))}'
      style = 'filled'
      fillcolor = '#FFDEAF'
      fontcolor = 'black'

      if i > 0:
        label = '<b>(reuse)</b><br/>' + label
        style += ',dotted'
        fillcolor = '#FFEACC'
        fontcolor = '#565858'

      lines.append(f'{node_id}{i} [label=<{label}>'
                   ' shape=rect,'
                   f' style="{style}",'
                   f' fontcolor="{fontcolor}",'
                   ' color="#FF8A4F",'
                   f' fillcolor="{fillcolor}"];')

  for value in captures:
    node_id = id(value)
    # Only concrete size-1 arrays are shown by value. Tracers (which have an
    # `aval`) and objects without `.size` fall back to `_format_val`.
    if (not hasattr(value, 'aval') and
        hasattr(value, 'size') and
        value.size == 1):
      label = f'<b>{esc(value.item())}</b>'
    else:
      label = f'<b>{esc(_format_val(value))}</b>'

    lines.append(f'{node_id} [label=<{label}>'
                 ' shape=rect,'
                 ' style="filled",'
                 ' fontcolor="black",'
                 ' color="#A261FF",'
                 ' fillcolor="#E6D6FF"];')

  lines.append('} // digraph G')
  return '\n'.join(lines) + '\n'