def _sync_dist(self, dist_sync_fn=gather_all_tensors):
    """Gather every registered metric state across processes and reduce it in place.

    Each attribute named in ``self._reductions`` is read from ``self``. Every
    ``Tensor`` in those states is passed through ``dist_sync_fn`` via
    ``apply_to_collection`` with ``group=self.process_group``. The gathered
    result is stacked or flattened, reduced with the registered reduction
    function (or left as-is when that function is ``None``), and written back
    with ``setattr``.

    Args:
        dist_sync_fn: gather function applied to each tensor state. ``group`` is
            passed as a keyword argument through ``apply_to_collection``
            (presumably forwarded to ``dist_sync_fn`` -- confirm).

    Raises:
        TypeError: if a registered reduction function is neither callable nor ``None``.
    """
    input_dict = {attr: getattr(self, attr) for attr in self._reductions}

    for attr, reduction_fn in self._reductions.items():
        # pre-concatenate metric states that are lists to reduce number of all_gather operations
        if (
            reduction_fn == dim_zero_cat
            and isinstance(input_dict[attr], list)
            and len(input_dict[attr]) > 1
        ):
            input_dict[attr] = [dim_zero_cat(input_dict[attr])]

    output_dict = apply_to_collection(
        input_dict,
        Tensor,
        dist_sync_fn,
        group=self.process_group,
    )

    for attr, reduction_fn in self._reductions.items():
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (callable(reduction_fn) or reduction_fn is None):
            raise TypeError('reduction_fn must be callable or None')

        gathered = output_dict[attr]
        # An empty list state has no first element to inspect: indexing [0] would
        # raise IndexError (and torch.stack rejects empty input), so keep it empty.
        if isinstance(gathered, list) and not gathered:
            setattr(self, attr, [])
            continue

        # pre-processing ops (stack or flatten for inputs)
        if isinstance(gathered[0], Tensor):
            gathered = torch.stack(gathered)
        elif isinstance(gathered[0], list):
            gathered = _flatten(gathered)

        reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
        setattr(self, attr, reduced)
def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
    """Gather every registered metric state across processes and reduce it in place.

    Each attribute named in ``self._reductions`` is read from ``self``. Every
    ``Tensor`` in those states is passed through ``dist_sync_fn`` via
    ``apply_to_collection``. The gathered result is stacked or flattened,
    reduced with the registered reduction function (or left as-is when that
    function is ``None``), and written back with ``setattr``.

    Args:
        dist_sync_fn: gather function applied to each tensor state.
        process_group: group passed as ``group=`` through ``apply_to_collection``.
            When falsy, ``self.process_group`` is used instead.

    Raises:
        TypeError: if a registered reduction function is neither callable nor ``None``.
    """
    input_dict = {attr: getattr(self, attr) for attr in self._reductions}

    for attr, reduction_fn in self._reductions.items():
        # pre-concatenate metric states that are lists to reduce number of all_gather operations
        if (
            reduction_fn == dim_zero_cat
            and isinstance(input_dict[attr], list)
            and len(input_dict[attr]) > 1
        ):
            input_dict[attr] = [dim_zero_cat(input_dict[attr])]

    output_dict = apply_to_collection(
        input_dict,
        Tensor,
        dist_sync_fn,
        group=process_group or self.process_group,
    )

    for attr, reduction_fn in self._reductions.items():
        if not (callable(reduction_fn) or reduction_fn is None):
            raise TypeError('reduction_fn must be callable or None')

        gathered = output_dict[attr]
        # An empty list state has no first element to inspect: indexing [0] would
        # raise IndexError (and torch.stack rejects empty input), so keep it empty.
        if isinstance(gathered, list) and not gathered:
            setattr(self, attr, [])
            continue

        # pre-processing ops (stack or flatten for inputs)
        if isinstance(gathered[0], Tensor):
            gathered = torch.stack(gathered)
        elif isinstance(gathered[0], list):
            gathered = _flatten(gathered)

        reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
        setattr(self, attr, reduced)
def _sync_dist(self, dist_sync_fn=gather_all_tensors):
    """Gather every registered metric state across processes and reduce it in place.

    Each attribute named in ``self._reductions`` is read from ``self``. Every
    ``Tensor`` in those states is passed through ``dist_sync_fn`` via
    ``apply_to_collection`` with ``group=self.process_group``. The gathered
    result is stacked or flattened, reduced with the registered reduction
    function (or left as-is when that function is ``None``), and written back
    with ``setattr``.

    Args:
        dist_sync_fn: gather function applied to each tensor state. ``group`` is
            passed as a keyword argument through ``apply_to_collection``
            (presumably forwarded to ``dist_sync_fn`` -- confirm).

    Raises:
        TypeError: if a registered reduction function is neither callable nor ``None``.
    """
    input_dict = {attr: getattr(self, attr) for attr in self._reductions}
    output_dict = apply_to_collection(
        input_dict,
        Tensor,
        dist_sync_fn,
        group=self.process_group,
    )

    for attr, reduction_fn in self._reductions.items():
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (callable(reduction_fn) or reduction_fn is None):
            raise TypeError('reduction_fn must be callable or None')

        gathered = output_dict[attr]
        # An empty list state has no first element to inspect: indexing [0] would
        # raise IndexError (and torch.stack rejects empty input), so keep it empty.
        if isinstance(gathered, list) and not gathered:
            setattr(self, attr, [])
            continue

        # pre-processing ops (stack or flatten for inputs)
        if isinstance(gathered[0], Tensor):
            gathered = torch.stack(gathered)
        elif isinstance(gathered[0], list):
            gathered = _flatten(gathered)

        reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
        setattr(self, attr, reduced)
def test_flatten_list():
    """Check that _flatten utility function works as expected."""
    # One level of nesting, with sublists of varying length; order must be kept.
    nested = [[1, 2, 3], [4, 5], [6]]
    expected = [1, 2, 3, 4, 5, 6]
    assert _flatten(nested) == expected