Code example #1
File: accumulator.py Project: wintersurvival/adaptdl
    def __init__(self, *args, **kwargs):
        if current_dataloader() is not None:
            raise RuntimeError("accumulator may not be initialized during "
                               "dataloader iteration")
        epoch = current_epoch()
        # Name the state by (epoch, creation order) so it matches up with its
        # saved counterpart when application code is replayed after a restart.
        count = _AccumulatorState.init_count[epoch]
        super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count))
        _AccumulatorState.init_count[epoch] += 1

        self.results_history = collections.defaultdict(list)
        self.results = dict(*args, **kwargs)
        self.updates = {}
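
Since the constructor forwards *args and **kwargs straight to dict(), an
accumulator can be seeded with initial values like a regular dict. A minimal
usage sketch, assuming Accumulator is exported from adaptdl.torch as in the
project's documentation:

    from adaptdl.torch import Accumulator

    # Seed the initial results exactly as dict() would accept them.
    accum = Accumulator(loss_sum=0.0, total=0)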
Code example #2
File: data.py Project: rohitpandey13/adaptdl
    def __init__(self):
        if current_dataloader() is not None:
            raise RuntimeError("dataloader may not be initialized during "
                               "dataloader iteration")
        epoch = current_epoch()
        count = _AdaptiveDataLoaderState.init_count[epoch]
        super().__init__("adaptdl-dataloader-epoch{}-{}".format(epoch, count))
        _AdaptiveDataLoaderState.init_count[epoch] += 1

        self.current_index = 0  # Index within the current dataloader loop.
        self.end_index = 0  # End index of the current DataLoader loop.
        self.last_position = {}  # Epoch -> position of last completed loop.
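
The current_index and last_position fields drive the replay logic shown in
the next two examples. As a plain-Python illustration (not adaptdl code),
resuming iteration from a saved index looks roughly like this:

    state = {"current_index": 0}  # persisted across a checkpoint-restart

    def resumable(seq):
        """Yield items of seq, skipping any processed before a restart."""
        for i, item in enumerate(seq):
            if i < state["current_index"]:
                continue  # already processed in a previous run
            state["current_index"] = i + 1
            yield item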
Code example #3
File: data.py Project: Milkigit/adaptdl
    def context(self):
        """
        All iterators should be iterated under this context. It ensures
        proper cleanup of the elastic context at the end of each epoch.
        """
        epoch = current_epoch()
        try:
            if AdaptiveDataLoaderHelper._current is not None:
                raise RuntimeError("overlapping dataloader "
                                   "iterations detected")
            AdaptiveDataLoaderHelper._current = self
            yield
        finally:
            self._state.current_index = 0
            self._state.end_index = 0
            self._state.last_position[epoch] = self._position[epoch]
            self._position[epoch] += 1
            AdaptiveDataLoaderHelper._current = None
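
The bare yield indicates that context() is wrapped with
contextlib.contextmanager in the original source (the decorator is cut off by
the excerpt). A stripped-down sketch of the nesting guard it implements, with
hypothetical names:

    import contextlib

    class Helper:
        _current = None  # class-level: at most one loop may be active

        @contextlib.contextmanager
        def context(self):
            if Helper._current is not None:
                raise RuntimeError("overlapping dataloader iterations "
                                   "detected")
            Helper._current = self
            try:
                yield
            finally:
                Helper._current = None  # reset on success or error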
Code example #4
File: data.py Project: Milkigit/adaptdl
    def skipdone(self):
        """
        Should be called just after entering the `_elastic` context to make
        sure that the dataloader loop is not replayed if it has already
        finished before a restart.
        """

        epoch = current_epoch()
        position = self._position[epoch]
        if position <= self._state.last_position.get(epoch, -1):
            # Already completed the dataloader loop at the current
            # position, skip this loop and keep replaying the application
            # code.
            LOG.info("skipping %s loop at position %s in epoch %s",
                     self.__class__.__name__, position, epoch)
            self._position[epoch] += 1
            return True
        else:
            return False
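
A plain-Python illustration (not adaptdl code) of the position bookkeeping:
last_position records the last loop that finished before the restart, and
every loop at or before that position is skipped on replay.

    position = {0: 0}       # epoch -> next loop position to run
    last_position = {0: 1}  # epoch -> last position completed pre-restart

    def skipdone(epoch):
        if position[epoch] <= last_position.get(epoch, -1):
            position[epoch] += 1  # advance past the finished loop
            return True
        return False

    assert skipdone(0)      # the loop at position 0 finished earlier: skip
    assert skipdone(0)      # so did the loop at position 1
    assert not skipdone(0)  # position 2 never completed: run it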
Code example #5
File: accumulator.py Project: wintersurvival/adaptdl
    def synchronized(self):
        """
        A context manager which can be used to define the code to execute in
        *synchronized* mode. Within the context manager, any code can interact
        with this accumulator as if it were a regular Python ``dict``. The
        application must ensure that whatever operations are performed within
        this context block are the same across all replicas.

        .. warning::
            Entering this context manager is a distributed synchronization
            point! Please ensure that all replicas enter this context manager
            at the same point in their code.
        """
        if self._synchronized is not None:
            # Already synchronized, don't need to do anything.
            yield self
            return
        epoch = current_epoch()
        # Remove saved results from all finished epochs. Since finished
        # epochs are never replayed, they should never be needed again.
        for key in list(self._state.results_history.keys()):
            if key is not None and key < epoch:
                self._state.results_history.pop(key)
        # Get the number of synchronizations so far in the current epoch.
        count = self._sync_count[epoch]
        self._sync_count[epoch] += 1
        results_list = self._state.results_history[epoch]
        assert count <= len(results_list)
        if count < len(results_list):
            # Results for this synchronization are saved in the history.
            self._synchronized = results_list[count]
            self._state.updates.clear()
        else:
            self._state.sync()  # Sync results and updates across replicas.
            if current_dataloader() is None:
                # Only save into results history if outside of a dataloader
                # iteration, since code inside iterations is not replayed.
                results_list.append(copy.deepcopy(self._state.results))
            self._synchronized = self._state.results
        try:
            yield self
        finally:
            self._synchronized = None
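
Typical usage, sketched after the adaptdl documentation: outside the
synchronized() block only += updates are allowed and are buffered locally;
entering the block merges them across replicas and exposes the results as a
plain dict. The validate function and its arguments are hypothetical.

    import adaptdl.torch as adl

    def validate(model, validloader, criterion):
        accum = adl.Accumulator()
        for inputs, targets in validloader:
            loss = criterion(model(inputs), targets)
            accum["loss_sum"] += loss.item()  # buffered, replica-local
            accum["total"] += targets.size(0)
        # Synchronization point: every replica must reach this together.
        with accum.synchronized():
            return accum["loss_sum"] / accum["total"]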
Code example #6
File: data.py Project: wintersurvival/adaptdl
    def __iter__(self):
        """
        Iterate over batches of data. When adaptive batch size is disabled,
        stops after the entire dataset has been processed once in total by all
        replicas. This means if there are K replicas, then this method will
        iterate over ~1/K of the dataset. When adaptive batch size is enabled,
        stops after making enough statistical progress roughly equivalent to
        one pass over the dataset with non-adaptive batch size. In this case,
        the dataset may be processed more than once.

        A checkpoint-restart may be triggered between batches. In this case,
        the current iteration state will be saved, restored after the restart,
        and iteration will continue where it left off.
        """
        epoch = current_epoch()
        num_replicas = adaptdl.env.num_replicas()
        with self._elastic.context():
            if self._elastic.skipdone():
                return
            done = False
            while not done:
                self.sampler.set_epoch(epoch,
                                       index=self._elastic.current_index)
                self.batch_sampler.batch_size = self._elastic._sync_local_bsz()
                for idx, batch in enumerate(super().__iter__()):
                    with self._elastic.profile(self.training and idx >= 1):
                        yield batch
                        # Increment by the number of data samples processed
                        self._elastic.current_index += \
                            num_replicas * self.batch_sampler.batch_size
                        if self._elastic.max_batch_size is not None and \
                                get_progress() >= len(self.dataset) * \
                                (epoch + 1) / self.batch_size:
                            done = True
                            break
                if self._elastic.max_batch_size is None:
                    done = True
                # Round current_index up to the next multiple of the dataset
                # size so the next loop starts at an epoch boundary.
                self._elastic.current_index -= \
                    self._elastic.current_index % -len(self.dataset)
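
For context, a hedged sketch of the enclosing training loop following
adaptdl's documented API (init_process_group, AdaptiveDataLoader,
autoscale_batch_size, remaining_epochs_until); dataset and train_step are
placeholders:

    import adaptdl.torch as adl

    adl.init_process_group("gloo")  # "nccl" for multi-GPU training
    dataloader = adl.AdaptiveDataLoader(dataset, batch_size=64, shuffle=True)
    dataloader.autoscale_batch_size(4096)  # sets max_batch_size, enabling
                                           # the adaptive branch of __iter__
    for epoch in adl.remaining_epochs_until(90):
        for batch in dataloader:  # the __iter__ shown in this example
            train_step(batch)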