def gather(
    self,
    dst: int = 0,
    out: Optional[torch.Tensor] = None,
) -> None:
    """
    Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
    sharded tensor.

    The API needs to be called on all ranks in SPMD fashion. All ranks should have
    the same ``dst``. ``out`` should be a tensor of the same size as the overall
    size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

    Args:
        dst(int): The rank where the full tensor is constructed.
            Default: 0
        out (:class:`torch.Tensor`, optional): The output full tensor.
            Must be provided ONLY on the ``dst`` rank.
            Default: ``None``
    """
    rank = dist.get_rank(self._process_group)
    full_size = self.metadata().size
    _validate_output_tensor_for_gather(rank, dst, full_size, out)

    local_shards = self.local_shards()

    world_size = dist.get_world_size(self._process_group)

    gathered_shards = [None] * world_size
    # will revise this part with CPU support and use dist.gather()
    # once NCCL support for gather() is ready
    # https://github.com/pytorch/pytorch/issues/66187
    device = torch.device(f"cuda:{rank % world_size}")
    with torch.cuda.device(device):
        dist.all_gather_object(
            obj=local_shards,
            object_list=gathered_shards,
            group=self._process_group,
        )

    if rank == dst:
        dims = len(full_size)
        for shards in gathered_shards:
            if shards is None:
                raise RuntimeError(
                    f'Gathered shards cannot be None on dst rank {dst}'
                )
            for shard in shards:
                metadata = shard.metadata
                tensor = shard.tensor

                out_narrow_view = out
                for dim in range(dims):
                    out_narrow_view = out_narrow_view.narrow(
                        dim,
                        metadata.shard_offsets[dim],
                        metadata.shard_lengths[dim],
                    )

                out_narrow_view.copy_(tensor)
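# A minimal usage sketch for the gather() API above, NOT part of the original
# snippet. The sharding spec, shapes, and the sharded_tensor.rand creation op
# are illustrative assumptions about the torch.distributed._shard namespace
# used elsewhere in this listing; module paths vary across PyTorch versions.
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def example_gather(rank: int):
    # Assumes init_process_group() was called with NCCL and
    # torch.cuda.set_device(rank) has been set on every rank.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    st = sharded_tensor.rand(spec, 16, 8)
    # Only the destination rank allocates the full output tensor.
    out = torch.empty(16, 8, device=f"cuda:{rank}") if rank == 0 else None
    st.gather(dst=0, out=out)
    return out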
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    # STEP 1: Validate the shard metadata locally
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    global_tensor_size = _flatten_tensor_size(global_size)

    if len(local_shards) > 0:
        local_sharded_tensor_metadata = \
            build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

    # STEP 2. Validate metadata across ranks, and build a global sharded tensor
    # metadata by gathering local ShardedTensorMetadata
    gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
    if world_size > 1:
        gathered_metadatas = [None for _ in range(world_size)]

        dist.all_gather_object(
            gathered_metadatas,
            local_sharded_tensor_metadata,
            group=process_group
        )
    else:
        gathered_metadatas = [local_sharded_tensor_metadata]

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
    tensor_properties = global_sharded_tensor_metadata.tensor_properties

    # STEP 3: Validation done, create the actual ShardedTensor and populate fields
    # prepare initialization
    spec = shard_spec._infer_sharding_spec_from_shards_metadata(
        global_sharded_tensor_metadata.shards_metadata
    )
    sharded_tensor = cls.__new__(
        cls,
        spec,
        global_sharded_tensor_metadata.size,
        dtype=tensor_properties.dtype,
        layout=tensor_properties.layout,
        pin_memory=tensor_properties.pin_memory,
        requires_grad=tensor_properties.requires_grad,
    )
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # attach local_shards to the ShardedTensor created
    sharded_tensor._local_shards = local_shards

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
def _init_from_local_shards(
    cls,
    local_shards: List[Shard],
    *global_size,
    process_group=None,
    init_rrefs=False,
):
    # STEP 1: Validate the shard metadata locally
    process_group = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    current_rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)

    local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
    local_shards_device = torch.device("cpu")
    global_tensor_size = _flatten_tensor_size(global_size)

    if len(local_shards) > 0:
        local_sharded_tensor_metadata, local_shards_device = \
            build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

    # STEP 2. Validate metadata across ranks, and build a global sharded tensor
    # metadata by gathering local ShardedTensorMetadata
    gathered_metadatas = [None for _ in range(world_size)]

    if local_shards_device.type == "cuda":
        # with GPU/NCCL, we need to set a device for all_gather_object
        # to use as we need to know which device we should put the
        # serialized tensor on before the NCCL collective.
        with torch.cuda.device(local_shards_device):
            dist.all_gather_object(
                gathered_metadatas,
                local_sharded_tensor_metadata,
                group=process_group
            )
    else:
        dist.all_gather_object(
            gathered_metadatas,
            local_sharded_tensor_metadata,
            group=process_group
        )

    global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)

    # STEP 3: Validation done, create the actual ShardedTensor and populate fields
    # prepare initialization
    sharded_tensor = cls.__new__(cls)
    sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

    # add to metadata and local_shards
    sharded_tensor._metadata = global_sharded_tensor_metadata
    sharded_tensor._local_shards = local_shards

    # make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
    # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
    #       see issue https://github.com/pytorch/pytorch/issues/67244
    sharded_tensor._sharding_spec = EnumerableShardingSpec(
        global_sharded_tensor_metadata.shards_metadata
    )

    # run post initialization, i.e. map registration, rpc initialization
    sharded_tensor._post_init()
    return sharded_tensor
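# A minimal sketch of how _init_from_local_shards is typically reached, NOT part
# of the original snippets. The two-rank layout, shapes, and import paths are
# illustrative assumptions; note that the metadata field name differs across the
# versions shown in this listing (shard_sizes vs. shard_lengths).
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.sharding_spec import ShardMetadata

def build_sharded_tensor_example():
    rank = dist.get_rank()
    # Each of two ranks owns one 2x4 row-slice of a 4x4 global tensor.
    metadata = ShardMetadata(
        shard_offsets=[rank * 2, 0],
        shard_sizes=[2, 4],
        placement=f"rank:{rank}/cuda:{rank}",
    )
    local_shard = Shard(tensor=torch.rand(2, 4, device=f"cuda:{rank}"), metadata=metadata)
    return ShardedTensor._init_from_local_shards([local_shard], 4, 4)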
def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)

    return st
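# A minimal usage sketch for _shard_tensor above, NOT part of the original
# snippet. The tensor shape, the two-rank placement, and the import path are
# illustrative assumptions.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def shard_full_tensor_example():
    rank = dist.get_rank()
    # Every rank passes a tensor of the same size; src_rank's copy is treated
    # as the ground truth and scattered as shards.
    full = torch.rand(8, 4, device=f"cuda:{rank}")
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    st = _shard_tensor(full, spec, src_rank=0)
    return st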
def _verify_sequence_number_across_pg(self, pg, verify_pg):
    seq_num = pg._get_sequence_number_for_group()
    obj_list = [None for _ in range(dist.get_world_size(verify_pg))]

    # We use a separate pg to verify the sequence numbers, otherwise these
    # collectives will themselves increment the sequence number.
    dist.all_gather_object(obj_list, seq_num, group=verify_pg)
    self.assertEqual(len(set(obj_list)), 1)
    return obj_list[0]
def _validate(model, process_group, assert_fn):
    module_states = [param.detach().cpu() for param in model.parameters()]
    module_states.extend([buffer.detach().cpu() for buffer in model.buffers()])
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, module_states, group=process_group)
    rank0_states = olist[0]
    for state in olist[1:]:
        for p1, p2 in zip(rank0_states, state):
            assert_fn(p1, p2)
def all_gather_object(self, object: T) -> List[T]:
    """
    Same as c10d::all_gather_object but works without distributed enabled.
    """
    if self.use_dist:
        gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))

        dist.all_gather_object(
            object_list=gather_objs,
            obj=object,
            group=self.group
        )
    else:
        gather_objs = [object]
    return gather_objs
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    data_list = [None] * world_size
    dist.all_gather_object(data_list, data)
    return data_list
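# A minimal usage sketch for the all_gather helper above, NOT part of the
# original snippet; the per-rank payload is an illustrative assumption.
def collect_predictions_example(local_predictions):
    # Each rank contributes its own picklable list; every rank receives the
    # concatenation of all of them, e.g. for rank-0 metric computation.
    gathered = all_gather(local_predictions)
    merged = []
    for part in gathered:
        merged.extend(part)
    return merged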
def _test_sequence_num_set_default_pg(self, backend):
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )

    default_pg = c10d._get_default_group()
    seq_num = default_pg._get_sequence_number_for_group()
    obj_list = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(obj_list, seq_num)
    self.assertEqual(len(set(obj_list)), 1)
def _test_sequence_num_set_new_group(self, backend):
    store = c10d.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )

    subgroup = dist.new_group([0, 1])
    subgroup_seq = subgroup._get_sequence_number_for_group()
    obj_list = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(obj_list, subgroup_seq)
    self.assertEqual(len(set(obj_list)), 1)
def _test_sequence_num_incremented(self, process_group, ranks):
    # verify initial sequence numbers. Use a distinct process group for
    # verification to keep counts as expected with respect to process_group.
    verify_pg = dist.new_group(
        ranks=ranks,
        backend="gloo",
    )
    assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

    initial_num = (
        self._verify_sequence_number_across_pg(pg=process_group, verify_pg=verify_pg)
        if not c10d.distributed_c10d._rank_not_in_group(process_group)
        else -1
    )

    # Verify sequence numbers are appropriately incremented
    for i in range(10):
        t = torch.ones(1, device=torch.cuda.current_device())
        dist.all_reduce(t, group=process_group)
        if not c10d.distributed_c10d._rank_not_in_group(process_group):
            seq_num = self._verify_sequence_number_across_pg(
                pg=process_group,
                verify_pg=verify_pg,
            )
            self.assertEqual(initial_num + i + 1, seq_num)

    if dist.get_world_size(process_group) > 2:
        # Test when certain ranks don't call collectives
        if dist.get_rank(process_group) not in [0, 2]:
            dist.all_reduce(t, group=process_group, async_op=True)
        # Now ranks 0 and 2 should be lagging by 1.
        if not c10d.distributed_c10d._rank_not_in_group(process_group):
            seq_num = process_group._get_sequence_number_for_group()
            rank = dist.get_rank(process_group)
            obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
            dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
            rank_to_seq_num = {rank: num for (rank, num) in obj_list}
            self.assertEqual(len(set(rank_to_seq_num.values())), 2)
            self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
            expected_same = {
                rank_to_seq_num[i]
                for i in rank_to_seq_num.keys()
                if i not in [0, 2]
            }
            self.assertEqual(len(expected_same), 1)
            self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
def all_gather_object(data: Any) -> List[Any]:
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Note: For NCCL-based process groups, internal tensor representations of
    objects must be moved to the GPU device before communication takes place.
    In this case, the device used is given by torch.cuda.current_device() and
    it is the user's responsibility to ensure that this is set so that each
    rank has an individual GPU, via torch.cuda.set_device().

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]

    data_list = [None] * get_world_size()
    dist.all_gather_object(data_list, data)
    return data_list
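# A minimal sketch of how the NCCL note above plays out in practice, NOT part
# of the original snippet; the LOCAL_RANK environment variable and the metrics
# dict are illustrative assumptions.
import os
import torch
import torch.distributed as dist

def gather_metrics_example(metrics: dict):
    # Pin this process to one GPU so the pickled payload is staged on the
    # correct device before the NCCL collective runs.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, metrics)
    return gathered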
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # passed all checks, construct reverse mapping and get list of devices handled by this agent
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)

    return reverse_device_maps, my_devices
def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    All-gathers module states across ranks and calls ``assert_fn`` on each pair
    of corresponding states from rank 0 and a nonzero rank. For example, if
    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
    states are equal across ranks.
    """
    # Include names for debugging convenience
    named_module_states = [
        (param_name, param.detach().cpu())
        for param_name, param in model.named_parameters()
    ]
    named_module_states += [
        (buffer_name, buffer.detach().cpu())
        for buffer_name, buffer in model.named_buffers()
    ]
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    rank0_states = olist[0]
    for state in olist[1:]:
        for (_, p1), (_, p2) in zip(rank0_states, state):
            assert_fn(p1, p2)
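# A minimal usage sketch for _assert_module_states above, NOT part of the
# original snippet; the model and the choice of process group are illustrative
# assumptions.
import torch
import torch.nn as nn
import torch.distributed as dist

def check_replicas_in_sync(model: nn.Module):
    # torch.testing.assert_close raises if any parameter or buffer diverged
    # between rank 0 and another rank.
    _assert_module_states(
        model,
        process_group=dist.group.WORLD,
        assert_fn=torch.testing.assert_close,
    )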
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output
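# A minimal sketch of the gloo-group pattern the wrapper above relies on, NOT
# part of the original snippet. The helper name _get_global_gloo_group comes
# from the snippet; its body here is an assumption about how such a helper is
# usually written.
import functools
import torch.distributed as dist

@functools.lru_cache()
def _get_global_gloo_group():
    # Build (once) a CPU/gloo group spanning all ranks so pickled payloads do
    # not have to be staged on GPU even when the default backend is NCCL.
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    return dist.group.WORLD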
def main_per_process(rank, world_size, args):
    init_process(rank, world_size)
    start_epoch = 1

    if args.wandb and rank == 0:
        run_name = get_run_name(args)
        wandb.init(project='myproject', entity='myaccount')
        wandb.run.name = run_name
        wandb.config.update(args)

    if rank == 0:
        output_cuda_info()

    # load dataset
    train_val_split = 0.2
    batch_size_per_proc = int(args.batch_size / world_size)
    train_set, val_set, test_set = load_cifar10(train_val_split, args.pretrained)

    # create sampler for ddp
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=world_size, rank=rank, shuffle=False)

    # create data loader for ddp
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size_per_proc,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size_per_proc,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             sampler=val_sampler)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    # create ddp model
    model = make_model(args.model, 10, pretrained=args.pretrained, fix_param=args.fixparam)
    model = model.to(rank)
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    if rank == 0:
        output_summary(ddp_model, train_loader)

    # settings for training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch)

    # synchronize
    dist.barrier()

    # start training
    print(f'[{datetime.now()}]#{rank}: start training')
    for epoch in range(start_epoch, start_epoch + args.epoch):
        if rank == 0:
            print(f'Epoch[{epoch}/{args.epoch}]')
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        dist.barrier()  # synchronize

        # train and validate
        train_loss, train_acc = train_epoch(ddp_model, train_loader, optimizer, criterion, rank)
        val_loss, val_acc = validate_epoch(ddp_model, val_loader, criterion, rank)
        dist.barrier()  # synchronize

        # sharing loss and accuracy among all gpus (processes)
        train_loss_list = [0.] * world_size
        train_acc_list = [0.] * world_size
        val_loss_list = [0.] * world_size
        val_acc_list = [0.] * world_size
        dist.all_gather_object(train_loss_list, train_loss)
        dist.all_gather_object(train_acc_list, train_acc)
        dist.all_gather_object(val_loss_list, val_loss)
        dist.all_gather_object(val_acc_list, val_acc)

        # save data to wandb
        if args.wandb and rank == 0:
            avg_train_loss = sum(train_loss_list) / world_size
            avg_train_acc = sum(train_acc_list) / world_size
            avg_val_loss = sum(val_loss_list) / world_size
            avg_val_acc = sum(val_acc_list) / world_size
            wandb.log({
                'acc': avg_train_acc,
                'loss': avg_train_loss,
                'val_acc': avg_val_acc,
                'val_loss': avg_val_loss,
                'lr': scheduler.get_last_lr()[0]
            })
        scheduler.step()
    print(f'[{datetime.now()}]#{rank}: finished training')

    if rank == 0:
        print('# final test')
        test_loss, test_acc, class_acc = final_test(model, test_loader, criterion, rank)
        for key, value in class_acc.items():
            print(f'{key} : {value: .3f}')

        # save data to wandb
        if args.wandb:
            wandb.log({'test_acc': test_acc, 'test_loss': test_loss})
            wandb.finish()
        print('# all finished')
def _shard_tensor(tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None):
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter,
    # this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)
def eval(
    model: torch.nn.Module,
    eval_loader: torch.utils.data.DataLoader,
    is_main: bool,
    previous_best: dict,
    img_size: List[int],
    save_dir: pathlib.Path = None,
) -> Tuple[dict, List[str]]:
    """
    Evaluate the model against the evaluation set. Save the best weights if
    specified. Use the pycocotools package for metrics.

    Args:
        model: The model to evaluate.
        eval_loader: The eval dataset loader.
        previous_best: The current best eval metrics.
        save_dir: Where to save the model weights.

    Returns:
        The updated best metrics and a list of the metrics that improved.
    """
    detections = []
    labels = []
    for images_batch, category_ids_batch, boxes_batch in eval_loader:
        # Send ground truth to BoundingBox
        for boxes, categories in zip(boxes_batch, category_ids_batch):
            image_boxes = []
            for box, category in zip(boxes, categories.squeeze(0)):
                image_boxes.append(
                    pascal_voc.BoundingBox(
                        box / torch.Tensor(img_size * 2), 1.0, category.int().item()
                    )
                )
            labels.append(image_boxes)

        if torch.cuda.is_available():
            images_batch = images_batch.cuda()

        if isinstance(model, parallel.DistributedDataParallel):
            detection_batch = model.module.get_boxes(images_batch)
        else:
            detection_batch = model.get_boxes(images_batch)
        detections.extend(detection_batch)

    if torch.distributed.is_initialized():
        labels_list = [None] * distributed.get_world_size()
        detections_list = [None] * distributed.get_world_size()
        distributed.all_gather_object(detections_list, detections)
        distributed.all_gather_object(labels_list, labels)

    if is_main:
        if torch.distributed.is_initialized():
            labels = []
            for label_group in labels_list:
                labels.extend(label_group)
            detections = []
            for detections_group in detections_list:
                detections.extend(detections_group)

        if isinstance(model, parallel.DistributedDataParallel):
            num_classes = model.module.num_classes
        else:
            num_classes = model.num_classes

        metrics = pascal_voc.compute_metrics(
            detections, labels, class_ids=list(range(num_classes))
        )

        # If these are the first results, use them as the previous best.
        previous_best = metrics if not previous_best else previous_best

        improved = []
        for (metric, old), new in zip(previous_best.items(), metrics.values()):
            if new >= old:
                improved.append(metric)
                previous_best[metric] = new

        return previous_best, improved
    else:
        return None, None
def allgather_object(obj):
    out = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(out, obj)
    return out
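# A minimal usage sketch for the tiny helper above, NOT part of the original
# snippet; the per-rank record is an illustrative assumption.
import torch.distributed as dist

def report_example():
    # Each rank contributes a small picklable record; every rank receives the
    # full list, indexed by rank.
    record = {"rank": dist.get_rank(), "num_samples": 128}
    all_records = allgather_object(record)
    total = sum(r["num_samples"] for r in all_records)
    return total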
import torch.distributed as dist

if dist.get_rank() == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# ruleid: pickles-in-torch-distributed
dist.broadcast_object_list(objects, src=0)

# ruleid: pickles-in-torch-distributed
dist.all_gather_object(output, gather_objects[dist.get_rank()])

# ruleid: pickles-in-torch-distributed
dist.gather_object(
    gather_objects[dist.get_rank()],
    output if dist.get_rank() == 0 else None,
    dst=0
)

# ruleid: pickles-in-torch-distributed
dist.scatter_object_list(output_list, objects, src=0)
def test_all_gather_object(self):
    output = [None] * dist.get_world_size()
    dist.all_gather_object(object_list=output, obj=self.rank)
    for i, v in enumerate(output):
        self.assertEqual(i, v, f"rank: {self.rank}")
def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')
    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}'
        )

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    with torch.cuda.device(tensor.device):
        dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_lengths=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter,
    # this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_lengths[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).contiguous()

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=local_shard.dtype,
            layout=local_shard.layout,
            requires_grad=local_shard.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=local_shard.is_pinned(),
        )
    )
    st = ShardedTensor._init_from_local_shards(
        local_shards, sharded_tensor_metadata, process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
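# A minimal usage sketch for shard_parameter above, NOT part of the original
# snippet. The two-rank layout, the Linear module, and the import path (the
# newer torch.distributed._shard namespace used by other snippets in this
# listing) are illustrative assumptions.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def shard_linear_weight_example():
    # Assumes init_process_group() has been called and each rank owns one GPU.
    rank = dist.get_rank()
    linear = torch.nn.Linear(32, 64).cuda(rank)
    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # Replaces linear.weight in place with a ShardedTensor.
    shard_parameter(linear, "weight", spec, src_rank=0)
    return linear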
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        if not set(all_device_maps[source_node].keys()).issubset(all_names):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices[target_node]:
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )

    # passed all checks, construct reverse mapping for return values
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }

    if not my_devices:
        devices_set: Set[torch.device] = set()
        for _, map_ in my_device_maps.items():
            devices_set.update(map_.keys())
        for _, map_ in reverse_device_maps.items():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)

    return reverse_device_maps, my_devices
def eval_epoch_end(self, outputs, mode):
    # if the user specifies one validation dataloader, then PTL reverts to
    # giving a list of dictionaries instead of a list of lists of dictionaries
    if isinstance(outputs[0], dict):
        outputs = [outputs]

    loss_list = []
    sb_score_list = []
    for dataloader_idx, output in enumerate(outputs):
        if dataloader_idx == 0:
            eval_loss = getattr(self, f'{mode}_loss').compute()
        else:
            eval_loss = getattr(self, f'{mode}_loss_{dataloader_idx}').compute()

        translations = list(itertools.chain(*[x['translations'] for x in output]))
        ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output]))
        assert len(translations) == len(ground_truths)

        # Gather translations and ground truths from all workers
        tr_and_gt = [None for _ in range(self.world_size)]
        # we also need to drop pairs where ground truth is an empty string
        dist.all_gather_object(
            tr_and_gt,
            [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != '']
        )
        if self.global_rank == 0:
            _translations = []
            _ground_truths = []
            for rank in range(0, self.world_size):
                _translations += [t for (t, g) in tr_and_gt[rank]]
                _ground_truths += [g for (t, g) in tr_and_gt[rank]]

            if self.tgt_language in ['ja']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab")
            elif self.tgt_language in ['zh']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh")
            else:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a")

            # because the reduction op later is average (over world_size)
            sb_score = sacre_bleu.score * self.world_size

            dataset_name = "Validation" if mode == 'val' else "Test"
            logging.info(
                f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(translations)}"
            )
            logging.info(
                f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Val Loss = {eval_loss}"
            )
            logging.info(
                f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Sacre BLEU = {sb_score / self.world_size}"
            )
            logging.info(
                f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
            )
            for i in range(0, 3):
                ind = random.randint(0, len(translations) - 1)
                logging.info(" " + '\u0332'.join(f"Example {i}:"))
                logging.info(f" Prediction: {translations[ind]}")
                logging.info(f" Ground Truth: {ground_truths[ind]}")
        else:
            sb_score = 0.0

        loss_list.append(eval_loss.cpu().numpy())
        sb_score_list.append(sb_score)
        if dataloader_idx == 0:
            self.log(f"{mode}_loss", eval_loss, sync_dist=True)
            self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True)
            getattr(self, f'{mode}_loss').reset()
        else:
            self.log(f"{mode}_loss_dl_index_{dataloader_idx}", eval_loss, sync_dist=True)
            self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}", sb_score, sync_dist=True)
            getattr(self, f'{mode}_loss_{dataloader_idx}').reset()

    if len(loss_list) > 1:
        self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
        self.log(f"{mode}_sacreBLEU_avg", np.mean(sb_score_list), sync_dist=True)
def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    """
    Load a distributed state_dict in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested `state_dict`.

    When loading ShardedTensor instances, each rank only
    reads data for their local shards.

    All tensors in ``state_dict`` must be allocated on their
    destination device prior to calling this function.

    All non-tensor data is loaded using `torch.load()` and modified in place
    on state_dict.

    Users must call `load_state_dict` on the root module to ensure load
    post-processing and non-tensor data properly propagates.

    This function can be used for local inference and to load a checkpoint
    produced by ``save_state_dict`` without having a process group initialized,
    by passing ``no_dist=True`` and by using Tensors instead of ShardedTensors.

    Args:
        state_dict (Dict[str, Any]) : The state_dict to load. Note that this
            state dict will be updated in place.
        storage_reader (StorageReader): StorageReader used to load data from.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank synchronization
        coordinator_rank (int): Rank to use to coordinate the checkpoint, rank0 is used by default
        no_dist (bool): Don't attempt to load in SPMD style. Default to False

    Returns:
        None.

    Examples
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed._shard.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_loader,
        >>> )

        >>> # module.load_state_dict() function might have customized steps
        >>> # to flush the state_dict, must call it to
        >>> # ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note:: load_state_dict uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place. In this
        case, the device used is given by ``torch.cuda.current_device()`` and it
        is the user's responsibility to ensure that this is set so that each rank
        has an individual GPU, via ``torch.cuda.set_device()``.
    """
    try:
        metadata = storage_reader.read_metadata()
        bytes_read_requests, tensor_read_requests = _reshard_and_prepare_read_request(
            state_dict=state_dict,
            metadata_from_storage=metadata
        )
        bytes_futures = storage_reader.read_bytes(bytes_read_requests)
        tensor_futures = storage_reader.read_tensors(tensor_read_requests)

        bytes_futures.wait()

        # Additional steps are required to convert the bytes to their original type.
        # Note that this is NOT in place; it creates a new object and replaces
        # what's in the state dict.
        for req in bytes_read_requests:
            # Ensure the BytesIO is rewound
            req.bytes.seek(0)
            state_dict[req.fqn] = torch.load(req.bytes)

        tensor_futures.wait()
        result = None
    except BaseException as e:
        result = e

    global_result: Optional[CheckpointException] = None
    if not no_dist:
        all_errors = [None] * dist.get_world_size(process_group)

        dist.all_gather_object(
            object_list=all_errors,
            obj=result,
            group=process_group)

        node_failures = cast(
            Dict[int, BaseException],
            {i: err for i, err in enumerate(all_errors) if err is not None})

        if len(node_failures) > 0:
            global_result = CheckpointException("failed to read checkpoint", node_failures)
    elif result is not None:
        global_result = CheckpointException("failed to read storage", {coordinator_rank: result})

    if global_result is not None:
        raise global_result
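# A standalone sketch of the error-propagation pattern used above, NOT part of
# the original snippet: every rank gathers per-rank exceptions with
# all_gather_object so all ranks either succeed together or raise together.
# The wrapper name and the RuntimeError used here are illustrative assumptions.
from typing import Callable, Optional
import torch.distributed as dist

def run_collectively(fn: Callable[[], None], process_group: Optional[dist.ProcessGroup] = None):
    try:
        fn()
        local_error = None
    except BaseException as e:  # mirror the broad catch used above
        local_error = e

    # Exchange per-rank results; ranks that succeeded contribute None.
    all_errors = [None] * dist.get_world_size(process_group)
    dist.all_gather_object(all_errors, local_error, group=process_group)

    failures = {rank: err for rank, err in enumerate(all_errors) if err is not None}
    if failures:
        raise RuntimeError(f"collective step failed on ranks: {sorted(failures)}")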