Example #1
def _sync_params_and_buffers(
    module,
    process_group,
    broadcast_bucket_size,
    src,
    params_and_buffers_to_ignore,
):
    """
    Syncs ``module``'s parameters and buffers so that all ranks hold the same
    module state. Note that this API assumes that all parameter shapes are
    consistent before running the synchronization; this can be checked with
    ``_verify_param_shape_across_processes``.
    """
    module_states = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    for name, buffer in module.named_buffers():
        if name not in params_and_buffers_to_ignore:
            module_states.append(buffer.detach())

    if len(module_states) > 0:
        dist._broadcast_coalesced(process_group, module_states,
                                  broadcast_bucket_size, src)
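
For context, here is a minimal sketch of how a helper like this might be invoked when preparing a model for data-parallel training. The process-group setup, the gloo backend, the 250 MB bucket size, and the empty ignore set are illustrative assumptions, not part of the example above.

import torch
import torch.distributed as dist
import torch.nn as nn

# Illustrative setup: assumes MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
# are already set in the environment (e.g. by torchrun).
dist.init_process_group(backend="gloo")
pg = dist.new_group()  # process group spanning all ranks

model = nn.Linear(16, 4)

# Broadcast rank 0's parameters and buffers to every other rank before training.
_sync_params_and_buffers(
    module=model,
    process_group=pg,
    broadcast_bucket_size=250 * 1024 * 1024,
    src=0,
    params_and_buffers_to_ignore=set(),
)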
Example #2
def _broadcast_buffers(self):
    """Explicitly synchronize buffers across all devices."""
    if self.distributed_model is None:
        return
    buffers = list(self.base_model.buffers())
    if len(buffers) > 0:
        logging.info("Synchronizing buffers before evaluation.")
        # Assumes ``_broadcast_coalesced`` is imported from ``torch.distributed``
        # at module level; 256 MB is the broadcast bucket size.
        _broadcast_coalesced(self.distributed_model.process_group, buffers,
                             256 * 1024 * 1024)
Example #3
def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
):
    """
    Synchronizes ``module_states`` (a list of tensors) across all processes by
    broadcasting them from rank ``src``.
    """
    if len(module_states) > 0:
        dist._broadcast_coalesced(process_group, module_states,
                                  broadcast_bucket_size, src)
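
Unlike Example #1, this variant does not walk the module itself: the caller is expected to collect the (detached) parameter and buffer tensors, drop any it wants to ignore, and pass the flattened list in as ``module_states``.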
Example #4
def _distributed_broadcast_coalesced(self, tensors, buffer_size):
    # Broadcast ``tensors`` from rank 0 to all ranks in coalesced buckets of
    # at most ``buffer_size`` bytes.
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
Example #5
def _distributed_broadcast_coalesced(self,
                                     tensors,
                                     buffer_size,
                                     authoritative_rank=0):
    # Same as above, but ``authoritative_rank`` selects which rank's tensors
    # are broadcast (rank 0 by default).
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size,
                              authoritative_rank)
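
This variant extends Example #4 with an ``authoritative_rank`` argument (defaulting to 0), letting the caller choose which rank's tensors serve as the source of the broadcast instead of always using rank 0.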