コード例 #1
0
def test_progress_raises():
    """`Progress` rejects mismatched tracker classes and unsupported increments.

    Building a `Progress` from two different tracker classes raises
    ``ValueError``; incrementing a counter that the tracker class does not
    define raises ``TypeError``.
    """
    with pytest.raises(ValueError, match="instances should be of the same class"):
        Progress(ReadyCompletedTracker(), ProcessedTracker())

    progress = Progress(ReadyCompletedTracker(), ReadyCompletedTracker())
    unsupported = (
        ("increment_started",
         "ReadyCompletedTracker` doesn't have a `started` attribute"),
        ("increment_processed",
         "ReadyCompletedTracker` doesn't have a `processed` attribute"),
    )
    for method_name, message in unsupported:
        with pytest.raises(TypeError, match=message):
            getattr(progress, method_name)()
コード例 #2
0
def test_progress_increment(attr):
    """A single ``increment_<attr>`` call sets that counter to 1 in both trackers.

    ``attr`` is presumably supplied by a ``pytest.mark.parametrize`` decorator
    that is not visible here — confirm.
    """
    progress = Progress()
    getattr(progress, f"increment_{attr}")()
    want = ProcessedTracker(**{attr: 1})
    assert progress.total == want
    assert progress.current == want
コード例 #3
0
def test_progress_increment_sequence():
    """Walk a fresh `Progress` through every stage, one increment at a time.

    After each increment, both ``total`` and ``current`` must equal a
    `ProcessedTracker` whose counters for all stages reached so far are 1.
    """
    batch = Progress()
    counts = {}
    for stage in ("ready", "started", "processed", "completed"):
        getattr(batch, f"increment_{stage}")()
        counts[stage] = 1
        assert batch.total == ProcessedTracker(**counts)
        assert batch.current == ProcessedTracker(**counts)
コード例 #4
0
    def __init__(self,
                 label_epoch_frequency: int,
                 inference_iteration: int = 2,
                 should_reset_weights: bool = True) -> None:
        """Loop implementing an active-learning training procedure.

        This loop is meant to be used together with the `ActiveLearningTrainer`.
        The procedure it describes is:

        Example::

            while unlabelled data remains or the budget criteria are not reached:

                if labelled data:
                    trainer.fit(model, labelled data)

                if unlabelled data:
                    predictions = trainer.predict(model, unlabelled data)
                    uncertainties = heuristic(predictions)
                    request labelling for the samples with the highest uncertainties, within a given budget

        Args:
            label_epoch_frequency: Number of epochs to train for before requesting labelling.
            inference_iteration: Number of inference passes performed to compute the uncertainty.
            should_reset_weights: Stored as ``self.should_reset_weights``. Presumably controls whether
                the model weights are reset between labelling rounds; that logic is not in this
                method (NOTE(review): confirm).
        """
        super().__init__()
        self.label_epoch_frequency = label_epoch_frequency
        self.inference_iteration = inference_iteration
        self.should_reset_weights = should_reset_weights
        # Not provided at construction time; None until attached elsewhere (NOTE(review): confirm where).
        self.fit_loop: Optional[FitLoop] = None
        self.progress = Progress()
        # Cached state, None until populated by other methods. Presumably used to restore the
        # model / datamodule between labelling rounds — TODO confirm.
        self._model_state_dict: Optional[Dict[str, torch.Tensor]] = None
        self._datamodule_state_dict: Optional[Dict[str, Any]] = None
        self._lightning_module: Optional[flash.Task] = None
コード例 #5
0
 def __init__(self) -> None:
     """Initialize per-run state: no dataloader, no limits and no collected outputs yet."""
     super().__init__()
     # Iterator over the current dataloader; None until assigned elsewhere.
     self.dataloader: Optional[Iterator] = None
     # Presumably the batch limit and the dataloader count; None until configured — confirm.
     self._dl_max_batches: Optional[int] = None
     self._num_dataloaders: Optional[int] = None
     # Starts empty; presumably accumulates step outputs during the run.
     self.outputs: List[STEP_OUTPUT] = []
     # Fresh progress counters for the batches handled by this loop.
     self.batch_progress = Progress()
コード例 #6
0
 def __init__(self,
              min_epochs: Optional[int] = None,
              max_epochs: Optional[int] = None) -> None:
     """Store the epoch bounds and create a fresh epoch-level `Progress`.

     Args:
         min_epochs: Minimum number of epochs; stored unchanged, may be ``None``.
         max_epochs: Maximum number of epochs; stored unchanged, may be ``None``.
             No validation is performed here.
     """
     super().__init__()
     self.max_epochs = max_epochs
     self.min_epochs = min_epochs
     # The epoch loop is not created here; None until attached elsewhere.
     self.epoch_loop: Optional[TrainingEpochLoop] = None
     self.epoch_progress = Progress()
コード例 #7
0
 def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None) -> None:
     """Store the epoch bounds and set up epoch progress and dataloader-state caching.

     Args:
         min_epochs: Minimum number of epochs; stored unchanged, may be ``None``.
         max_epochs: Maximum number of epochs; stored unchanged, may be ``None``.
             No validation is performed here.
     """
     super().__init__()
     self.max_epochs = max_epochs
     self.min_epochs = min_epochs
     # The epoch loop is not created here; None until attached elsewhere.
     self.epoch_loop: Optional[TrainingEpochLoop] = None
     self.epoch_progress = Progress()
     # caches the loaded dataloader state until dataloader objects are available
     self._dataloader_state_dict: Dict[str, Any] = {}
コード例 #8
0
    def __init__(self) -> None:
        """Initialize prediction state with empty buffers and zeroed counters."""
        super().__init__()
        # Off by default; presumably toggled by the caller to keep predictions — confirm.
        self.return_predictions = False
        # Start empty; presumably filled as batches are predicted.
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches = 0
        self._num_dataloaders = 0
        self._warning_cache = WarningCache()
        # Nested list; presumably one list of batch indices per dataloader — TODO confirm.
        self._seen_batch_indices: List[List[int]] = []
コード例 #9
0
def test_loop_progress_increment_sequence():
    """Check batch and epoch counters of `LoopProgress` across an epoch boundary.

    The batch ``total`` tracker is built with ``started=None``. The assertions
    expect ``increment_started`` to leave both batch trackers unchanged.
    Completing an epoch, or calling ``reset_on_epoch``, clears the batch
    ``current`` tracker, while the batch ``total`` keeps accumulating.
    """

    def expect(prog, total, current):
        assert prog.total == total
        assert prog.current == current

    progress = LoopProgress(batch=Progress(total=Tracker(started=None)))

    # --- first epoch: one batch goes through every stage ---
    progress.batch.increment_ready()
    expect(progress.batch, Tracker(ready=1, started=None), Tracker(ready=1))

    progress.batch.increment_started()
    expect(progress.batch, Tracker(ready=1, started=None), Tracker(ready=1))

    progress.batch.increment_processed()
    expect(progress.batch,
           Tracker(ready=1, started=None, processed=1),
           Tracker(ready=1, processed=1))

    progress.batch.increment_completed()
    batch_total = Tracker(ready=1, started=None, processed=1, completed=1)
    expect(progress.batch, batch_total, Tracker(ready=1, processed=1, completed=1))

    # --- epoch boundary ---
    expect(progress.epoch, Tracker(), Tracker())
    progress.increment_epoch_completed()
    expect(progress.batch, batch_total, Tracker())
    expect(progress.epoch, Tracker(completed=1), Tracker())

    # --- next epoch: another batch becomes ready, then the epoch is reset ---
    progress.batch.increment_ready()
    batch_total = Tracker(ready=2, started=None, processed=1, completed=1)
    expect(progress.batch, batch_total, Tracker(ready=1))
    expect(progress.epoch, Tracker(completed=1), Tracker())

    progress.reset_on_epoch()
    expect(progress.batch, batch_total, Tracker())
    expect(progress.epoch, Tracker(completed=1), Tracker())
コード例 #10
0
ファイル: fit_loop.py プロジェクト: kazhang/pytorch-lightning
    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
        """Validate and store the epoch bounds, and set up epoch progress tracking.

        Args:
            min_epochs: Minimum number of epochs; stored unchanged, may be ``None``.
            max_epochs: Maximum number of epochs, may be ``None``. ``0`` and ``-1``
                are both accepted.

        Raises:
            MisconfigurationException: If ``max_epochs`` is less than ``-1``.
        """
        super().__init__()
        # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
        if max_epochs is not None and max_epochs < -1:
            # Zero is accepted above, so the message must say "non-negative", not "positive".
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        # The epoch loop is not created here; None until attached elsewhere.
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()
        # caches the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}
コード例 #11
0
    def __init__(self, min_steps: int, max_steps: int) -> None:
        """Store the step bounds and initialize per-epoch counters, progress and child loops.

        Args:
            min_steps: Minimum number of steps; stored unchanged.
            max_steps: Maximum number of steps; stored unchanged (not validated here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps

        self.global_step: int = 0
        # manually tracking which is the last batch is necessary for iterable dataset support
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # Child loops are not created here; None until attached elsewhere.
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
コード例 #12
0
    def __init__(
        self,
        min_epochs: Optional[int] = 1,
        max_epochs: int = 1000,
    ) -> None:
        """Validate and store the epoch bounds, and set up epoch progress tracking.

        Args:
            min_epochs: Minimum number of epochs; stored unchanged.
            max_epochs: Maximum number of epochs. ``0`` and ``-1`` are accepted.

        Raises:
            MisconfigurationException: If ``max_epochs`` is less than ``-1``.
        """
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        # The epoch loop is not created here; None until attached elsewhere.
        self.epoch_loop: Optional[TrainingEpochLoop] = None
        self.epoch_progress = Progress()
        # Starts True; presumably cleared once an epoch has begun — confirm in the methods that read it.
        self._is_fresh_start_epoch: bool = True
コード例 #13
0
    def __init__(self, min_steps: int, max_steps: int) -> None:
        """Store the step bounds and initialize counters, progress, child loops and caches.

        Args:
            min_steps: Minimum number of steps; stored unchanged.
            max_steps: Maximum number of steps; stored unchanged (not validated here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps
        self.global_step: int = 0
        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # Child loops are not created here; None until attached elsewhere.
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._dataloader_idx: Optional[int] = None
        self._warning_cache: WarningCache = WarningCache()
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
コード例 #14
0
    def __init__(
        self,
        min_epochs: int = 0,
        max_epochs: int = 1000,
    ) -> None:
        """Validate the epoch bounds, and create the epoch loop and epoch progress tracking.

        Args:
            min_epochs: Minimum number of epochs; stored unchanged.
            max_epochs: Maximum number of epochs. ``0`` and ``-1`` are accepted.

        Raises:
            MisconfigurationException: If ``max_epochs`` is less than ``-1``.
        """
        super().__init__()
        if max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        # Unlike variants that leave this as None, the epoch loop is created eagerly with defaults.
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()

        # Starts True; presumably cleared once an epoch has begun — confirm in the methods that read it.
        self._is_fresh_start_epoch: bool = True
        self._outputs: _EPOCH_OUTPUTS_TYPE = []
        # No data fetcher yet; None until assigned elsewhere.
        self._data_fetcher: Optional[AbstractDataFetcher] = None
コード例 #15
0
    def __init__(self, min_steps: int, max_steps: int):
        """Validate and store the step bounds, and initialize counters, progress and child loops.

        Args:
            min_steps: Minimum number of steps; stored unchanged.
            max_steps: Maximum number of steps. ``0`` and ``-1`` are accepted; a falsy
                value (``0`` or ``None``) skips validation.

        Raises:
            MisconfigurationException: If ``max_steps`` is less than ``-1``.
        """
        super().__init__()
        self.min_steps: int = min_steps

        if max_steps is not None and max_steps < -1:
            # Zero passes this check, so the message must say "non-negative", not "positive".
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1. You passed in {max_steps}."
            )
        self.max_steps: int = max_steps

        self.global_step: int = 0
        # manually tracking which is the last batch is necessary for iterable dataset support
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # Child loops are not created here; None until attached elsewhere.
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
コード例 #16
0
def test_epoch_loop_progress_increment_sequence():
    """Step a fresh `Progress` through ready -> started -> processed -> completed.

    After the k-th increment, both ``total`` and ``current`` must equal a
    `Tracker` with the first k stage counters set to 1.
    """
    batch = Progress()
    stages = ("ready", "started", "processed", "completed")
    for reached, stage in enumerate(stages, start=1):
        getattr(batch, f"increment_{stage}")()
        expected = Tracker(**dict.fromkeys(stages[:reached], 1))
        assert batch.total == expected
        assert batch.current == expected
コード例 #17
0
    def __init__(self, min_steps: int, max_steps: int) -> None:
        """Store the step bounds and initialize counters, progress, child loops and caches.

        Args:
            min_steps: Minimum number of steps; stored unchanged.
            max_steps: Maximum number of steps; stored unchanged (not validated here).
        """
        super().__init__()
        self.min_steps: int = min_steps
        self.max_steps: int = max_steps
        self.global_step: int = 0
        # the total batch index across all epochs
        self.total_batch_idx: int = 0
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        # the number of batches seen this run, updates immediately after batch_loop.run()
        # TODO: replace by progress tracking
        self.batches_seen: int = 0
        self.is_last_batch: Optional[bool] = None
        self.batch_progress = Progress()
        self.scheduler_progress = SchedulerProgress()

        # Child loops are not created here; None until attached elsewhere.
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        self._results = ResultCollection(training=True)
        self._dataloader_idx: Optional[int] = None
        self._warning_cache: WarningCache = WarningCache()
        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
コード例 #18
0
def test_progress_from_defaults():
    """`Progress.from_defaults` builds identical total/current trackers of the given class."""
    tracker_kwargs = {"completed": 5}
    produced = Progress.from_defaults(StartedTracker, **tracker_kwargs)
    assert produced == Progress(
        total=StartedTracker(**tracker_kwargs),
        current=StartedTracker(**tracker_kwargs),
    )
コード例 #19
0
def test_deepcopy():
    """Freshly built progress-tracking objects can be deep-copied without error."""
    for factory in (BaseProgress, Progress, ProcessedTracker):
        deepcopy(factory())
コード例 #20
0
def test_base_progress_from_defaults():
    """Without an explicit tracker class, `from_defaults` is expected to build plain `Tracker`s."""
    defaults = dict(completed=5, started=None)
    produced = Progress.from_defaults(**defaults)
    assert produced == Progress(total=Tracker(**defaults), current=Tracker(**defaults))
コード例 #21
0
 def _reset_fitting(self) -> None:
     """Put the trainer back into fitting mode and start the fit loop's epoch counters afresh."""
     self.trainer.state.fn = TrainerFn.FITTING
     self.trainer.training = True
     # Invoke the LightningModule's `on_train_dataloader` hook.
     self.trainer.lightning_module.on_train_dataloader()
     # Helper not shown here; presumably (re)attaches the stored module to the trainer — confirm.
     self._connect(self._lightning_module)
     # Replace (not reset in place) the epoch progress so previous epoch counts are discarded.
     self.fit_loop.epoch_progress = Progress()