def string_list_all_gather(strings: List[str]) -> List[str]:
    """
    Utility function for distributed data parallel to all gather a list of strings.
    Note that if an item in `strings` is longer than 1024 chars, it will be truncated to 1024:
    https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92

    Args:
        strings: a list of strings to all gather.

    """
    world_size = idist.get_world_size()
    if world_size <= 1:
        return strings

    result: List[List[str]] = [[] for _ in range(world_size)]
    # gather the list lengths from all ranks
    length = len(strings)
    all_lens = idist.all_gather(length)
    max_len = max(all_lens)
    # pad the local list so that every rank all-gathers the same number of items
    if length < max_len:
        strings = strings + ["" for _ in range(max_len - length)]

    if get_torch_version_tuple() > (1, 6, 0):
        for s in strings:
            gathered = idist.all_gather(s)
            for i, g in enumerate(gathered):
                if len(g) > 0:
                    result[i].append(g)
    else:
        raise RuntimeError("string all_gather is not supported in PyTorch < 1.7.0.")

    return [i for k in result for i in k]
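
# Hypothetical usage sketch for the function above, assuming torch.distributed / ignite.distributed
# has been initialised on every rank. Ranks may hold lists of different lengths; the padding logic
# above keeps the per-item all_gather calls aligned, and the empty padding strings are dropped again.
# The filenames below are made-up illustration values.
import ignite.distributed as idist

rank = idist.get_rank()
local_files = [f"pred_rank{rank}_{i}.nii.gz" for i in range(rank + 1)]  # length differs per rank
all_files = string_list_all_gather(local_files)
# every rank now holds the filenames contributed by all ranks, in rank order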
def __init__(
    self,
    nn_module,
    target_layer_names: Union[str, Sequence[str]],
    register_forward: bool = False,
    register_backward: bool = False,
):
    """
    Args:
        nn_module: the model to be wrapped.
        target_layer_names: the names of the target layers to cache.
        register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
        register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.
    """
    self.model = nn_module
    self.target_layers = ensure_tuple(target_layer_names)

    self.gradients: Dict[str, torch.Tensor] = {}
    self.activations: Dict[str, torch.Tensor] = {}
    self.score = None
    self.class_idx = None
    self.register_backward = register_backward
    self.register_forward = register_forward

    _registered = []
    for name, mod in nn_module.named_modules():
        if name not in self.target_layers:
            continue
        _registered.append(name)
        if self.register_backward:
            if get_torch_version_tuple() < (1, 8):
                # register_full_backward_hook is only available in PyTorch >= 1.8
                mod.register_backward_hook(self.backward_hook(name))
            else:
                if "inplace" in mod.__dict__ and mod.__dict__["inplace"]:
                    # inplace=True causes errors for register_full_backward_hook
                    mod.__dict__["inplace"] = False
                mod.register_full_backward_hook(self.backward_hook(name))
        if self.register_forward:
            mod.register_forward_hook(self.forward_hook(name))
    if len(_registered) != len(self.target_layers):
        warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")
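
# Hypothetical usage sketch for the __init__ above. It assumes the enclosing class is a model
# wrapper (called ModelWithHooks here purely for illustration) whose forward_hook/backward_hook
# methods return closures that store the hooked tensors in self.activations / self.gradients
# keyed by layer name; neither the class name nor those methods are shown in this excerpt.
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),  # if this were a target layer with register_backward=True on PyTorch >= 1.8,
                            # the code above would flip inplace to False before hooking it
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
)
# named_modules() reports nn.Sequential submodule names as "0", "1", ...
wrapper = ModelWithHooks(net, target_layer_names="0", register_forward=True, register_backward=True)
out = wrapper.model(torch.rand(2, 1, 8, 8))
out.sum().backward()
# if the hook closures behave as assumed, the cached tensors are now available as
# wrapper.activations["0"] and wrapper.gradients["0"]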
def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
    """
    Utility function for distributed data parallel to all gather a list of strings.

    Args:
        strings: a list of strings to all gather.
        delimiter: use the delimiter to join the string list into one long string,
            then all gather across ranks and split it back into a list. Defaults to "\t".

    """
    if idist.get_world_size() <= 1:
        return strings

    _joined = delimiter.join(strings)
    if get_torch_version_tuple() > (1, 6, 0):
        # all gather across all ranks
        _joined = delimiter.join(idist.all_gather(_joined))
    else:
        raise RuntimeError("string all_gather is not supported in PyTorch < 1.7.0.")

    return _joined.split(delimiter)
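
# The delimiter-based variant above avoids the per-item all_gather and padding of the earlier
# implementation: each rank joins its list into one string, a single string per rank is gathered,
# and the concatenation is split again. One caveat worth keeping in mind (an observation, not part
# of the source): items must not contain the delimiter, or the final split() breaks them apart.
# Minimal single-process sketch of the join/gather/split round trip with made-up values:
_rank0 = ["liver_1.nii.gz", "liver_2.nii.gz"]
_rank1 = ["spleen_7.nii.gz"]
_gathered = ["\t".join(_rank0), "\t".join(_rank1)]  # stands in for what idist.all_gather would return
assert "\t".join(_gathered).split("\t") == _rank0 + _rank1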
def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):
    expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
    # persistent_workers is only a valid DataLoader argument in PyTorch > 1.7
    _kwg = {"persistent_workers": persistent_workers} if get_torch_version_tuple() > (1, 7) else {}

    data_list = list(range(1, 11))
    # CacheDataset: fully cached, iterated directly and twice through a DataLoader
    dataset = CacheDataset(
        data=data_list,
        transform=_StatefulTransform(),
        cache_rate=1.0,
        num_workers=cache_workers,
        progress=False,
    )
    self.assertListEqual(expected, list(dataset))
    loader = DataLoader(
        CacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=1.0,
            num_workers=cache_workers,
            progress=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **_kwg,
    )
    self.assertListEqual(expected, [y.item() for y in loader])
    self.assertListEqual(expected, [y.item() for y in loader])

    # SmartCacheDataset: cache_rate=0.7 caches 7 of the 10 items, so only those are iterated
    dataset = SmartCacheDataset(
        data=data_list,
        transform=_StatefulTransform(),
        cache_rate=0.7,
        replace_rate=0.5,
        num_replace_workers=cache_workers,
        progress=False,
        shuffle=False,
    )
    self.assertListEqual(expected[:7], list(dataset))
    loader = DataLoader(
        SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **_kwg,
    )
    self.assertListEqual(expected[:7], [y.item() for y in loader])
    self.assertListEqual(expected[:7], [y.item() for y in loader])

    # PersistentDataset: results cached on disk in a temporary directory
    with tempfile.TemporaryDirectory() as tempdir:
        pdata = PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir)
        self.assertListEqual(expected, list(pdata))
        loader = DataLoader(
            PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir),
            batch_size=1,
            num_workers=loader_workers,
            shuffle=False,
            **_kwg,
        )
        self.assertListEqual(expected, [y.item() for y in loader])
        self.assertListEqual(expected, [y.item() for y in loader])
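
# The test above relies on a `_StatefulTransform` helper that is not shown in this excerpt.
# The sketch below is a reconstruction from the expected values, not the original definition:
# each call bumps an internal counter and returns `data * 100 + counter`, so
# `expected == [x * 100 + 2 for x in data_list]` only holds when every item is transformed with
# the counter still at its initial value, i.e. the caching datasets must not let the mutable
# state leak across items or worker threads/processes.
class _StatefulTransform:
    """Hypothetical reconstruction of a transform with mutable internal state."""

    def __init__(self):
        self.property = 1

    def __call__(self, data):
        self.property = self.property + 1
        return data * 100 + self.property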