コード例 #1
0
    def on_save_checkpoint(self) -> Dict:
        """Extend the parent checkpoint dict with the dataloader iterator state.

        The dataloader state is only captured for a fault-tolerant run that is
        mid-epoch: at least one batch started but not all batches completed.
        """
        state_dict = super().on_save_checkpoint()

        # Guard clauses: bail out unless a fault-tolerant, in-flight epoch is
        # being checkpointed.
        if self.trainer is None or not self.trainer.state._fault_tolerant_mode.is_enabled:
            return state_dict
        if self._data_fetcher is None:
            return state_dict
        if self._num_completed_batches_reached():  # epoch already finished
            return state_dict
        if not self.batch_progress.current.ready:  # epoch never started
            return state_dict

        state = CombinedLoader._state_dict_fn(self._data_fetcher.dataloader_iter, self._has_completed())
        if state:
            state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)
        return state_dict
コード例 #2
0
    def on_save_checkpoint(self) -> Dict:
        """Extend the parent checkpoint dict with the train dataloader state.

        The state is only saved when the epoch is in flight: the dataloader
        exists, at least one batch started, and not all batches completed.
        """
        state_dict = super().on_save_checkpoint()

        loader = self.trainer.train_dataloader
        if loader is None:
            return state_dict
        if self._num_completed_batches_reached():  # epoch already finished
            return state_dict
        # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
        if self.batch_progress.current.ready == 0:  # epoch never started
            return state_dict

        loader_state = loader.state_dict(has_completed=self._has_completed())
        state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(loader_state)
        return state_dict
コード例 #3
0
    def on_save_checkpoint(self) -> Dict:
        """Extend the parent checkpoint dict with the dataloader iterator state.

        State is captured only mid-epoch: the fetcher exists, at least one
        batch started, and not all batches completed.
        """
        state_dict = super().on_save_checkpoint()

        fetcher = self._data_fetcher
        mid_epoch = (
            fetcher is not None
            and not self._num_completed_batches_reached()  # did not finish
            # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
            and self.batch_progress.current.ready != 0  # did start
        )
        if mid_epoch:
            # TODO: this should use `pytorch_lightning/trainer/supporters.py::CombinedLoader._state_dict_fn`
            attr_name = "state" if self._has_completed() else "previous_state"
            state: Optional[MergedIteratorState] = getattr(fetcher.dataloader_iter, attr_name, None)
            if state:
                state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(asdict(state))
        return state_dict
コード例 #4
0
    def on_save_checkpoint(self) -> Dict:
        """Extend the parent checkpoint dict with the train dataloader state.

        Only applies to a fault-tolerant run checkpointed mid-epoch: the
        dataloader exists, at least one batch started, and not all completed.
        """
        state_dict = super().on_save_checkpoint()

        trainer = self.trainer
        # Guard clauses replacing the single compound condition.
        if trainer is None or not trainer.state._fault_tolerant_mode.is_enabled:
            return state_dict
        if trainer.train_dataloader is None:
            return state_dict
        if self._num_completed_batches_reached():  # epoch already finished
            return state_dict
        # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
        if not self.batch_progress.current.ready:  # epoch never started
            return state_dict

        loader: CombinedLoader = trainer.train_dataloader
        state = loader.state_dict(has_completed=self._has_completed())
        if state:
            state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)
        return state_dict