Example #1
    def load_checkpoint(self, state, rank: int):
        """
        Loads checkpoint if the checkpoint manager has been configured and
        at least one worker has already loaded the checkpoint
        """
        if not self.checkpoint_manager:
            # checkpoint not enabled
            return state

        # all-gather `checkpoint_loaded` from all trainers; the result is 1
        # if any trainer has ever loaded a checkpoint
        any_checkpoint_loaded = (edist.all_gather_return_max_long(
            1 if self.checkpoint_loaded else 0) == 1)

        if any_checkpoint_loaded:
            # checkpoint already loaded by one of the existing trainers
            return state

        # We load the checkpoint only when every trainer starts from scratch;
        # if a healthy trainer already exists, a new trainer can simply sync
        # state from it instead of reading the checkpoint.
        # Start with the simple scenario: a single trainer (rank 0) loads the
        # checkpoint and the other trainers sync from it.
        if rank == 0:
            state = self._do_load_checkpoint(state)

        return state
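
Example #1 compares the result of `edist.all_gather_return_max_long` directly against 1, i.e. it treats the helper as returning the maximum long across all trainers. The helper's implementation is not shown in this listing; below is a minimal sketch of that scalar behavior, assuming an initialized torch.distributed process group (the `_sketch` name is invented, not the library's):

import torch
import torch.distributed as dist

def all_gather_return_max_long_sketch(value: int) -> int:
    # Gather one long from every rank, then reduce to the maximum locally.
    world_size = dist.get_world_size()
    local = torch.LongTensor([value])
    gathered = [torch.LongTensor([0]) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return max(int(t.item()) for t in gathered)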
Example #2
    def _sync_state(self, rank):
        # broadcast from the rank that has the biggest data_start_index
        max_rank, _ = edist.all_gather_return_max_long(self.data_start_index)

        # Broadcast the state from max_rank. Every rank serializes its own
        # state, but only max_rank's bytes are actually transmitted.
        buffer = io.BytesIO()
        self.save(buffer)
        state_tensor = torch.ByteTensor(list(buffer.getvalue()))
        # broadcast the payload size first so the other ranks can allocate
        # a receive buffer of the right length
        state_size = torch.LongTensor([state_tensor.numel()])
        dist.broadcast(state_size, src=max_rank)

        if rank != max_rank:
            state_tensor = torch.zeros(int(state_size[0]), dtype=torch.uint8)

        dist.broadcast(state_tensor, src=max_rank)

        buffer = io.BytesIO(state_tensor.numpy().tobytes())
        self.load(buffer)

        log.info(f"Rank {rank}: Model state synced from rank: {max_rank}\n"
                 f"\tbatch_size={self.total_batch_size}\n"
                 f"\tnum_data_workers={self.params.num_data_workers}\n"
                 f"\tdata_start_index={self.data_start_index}\n"
                 f"\titeration={self.iteration}\n"
                 f"\tepoch={self.epoch}/{self.num_epochs}")
Example #3
    def _compute_most_tenured_rank(self, rank):
        logging.warning("RANK {}: syncing, I have {} updates".format(
            rank, self.task.num_updates))
        # Propagate state to new trainer processes.
        # First, figure out which process has a copy of the most recent
        # state by getting a copy of everybody's iteration counter.
        max_rank, max_num_updates = dist.all_gather_return_max_long(
            self.task.num_updates)

        logging.warning("RANK {}: rank {} has the most updates {}".format(
            rank, max_rank, max_num_updates))

        return max_rank
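
The returned max_rank would presumably feed a state broadcast like the one in Example #2. A hypothetical call site (the `_broadcast_state_from` helper name is invented for illustration):

    # Find the most tenured rank, then have every trainer sync from it,
    # e.g. via the byte-buffer broadcast shown in Example #2.
    max_rank = self._compute_most_tenured_rank(rank)
    self._broadcast_state_from(max_rank)  # hypothetical helper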
Example #4
def broadcast_run(rank, world_size, input):
    return edist.all_gather_return_max_long(input[rank])
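
Example #4 reads like a test helper. A self-contained harness for it might look like the sketch below, using torch.multiprocessing.spawn and a gloo process group; the master address/port are placeholders, and `edist` is assumed to be the same util module imported by `broadcast_run`:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size, input):
    # Each rank contributes input[rank]; every rank should receive the max.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    result = broadcast_run(rank, world_size, input)
    print(f"rank {rank}: result = {result}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(4, [3, 1, 4, 1]), nprocs=4)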