Example #1
    def reset(self) -> None:
        """Resets the loop's internal state."""
        self.iteration_count = 0
        self.predictions = PredictionCollection(self.trainer.global_rank,
                                                self.trainer.world_size)
        self.dl_max_batches = None
        self.dataloader_idx = None
        self.num_dataloaders = None
        self.outputs = []
Example #2
    def reset(self) -> None:
        """Resets the loop's internal state."""
        self.predictions = PredictionCollection(self.trainer.global_rank,
                                                self.trainer.world_size)
        self._dl_max_batches = None
        self._num_dataloaders = None
        self.outputs = []

        if not self.restarting:
            self.batch_progress.current.reset()
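This second `reset` only clears batch progress when the loop is not restarting from a checkpoint, so a resumed run can continue mid-epoch. A minimal, self-contained sketch of that pattern; `_BatchProgress` is a hypothetical stand-in, not Lightning's real progress-tracking class:

class _BatchProgress:
    """Hypothetical stand-in for Lightning's progress tracker."""

    def __init__(self) -> None:
        self.current = 0

    def reset(self) -> None:
        self.current = 0


class EvaluationLoop:
    def __init__(self) -> None:
        self.restarting = False
        self.batch_progress = _BatchProgress()
        self.outputs = []

    def reset(self) -> None:
        """Reset per-run state; keep progress when resuming from a checkpoint."""
        self.outputs = []
        if not self.restarting:
            self.batch_progress.reset()


loop = EvaluationLoop()
loop.batch_progress.current = 3
loop.reset()
assert loop.batch_progress.current == 0  # fresh run: progress cleared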
Example #3
    def setup(self, model, max_batches, dataloaders):
        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)
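Each `setup` variant normalizes a single integer limit into one entry per dataloader. A small, self-contained sketch of that broadcast, with plain lists standing in for real DataLoaders:

def normalize_max_batches(max_batches, dataloaders):
    # broadcast a single int limit to one entry per dataloader
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)
    return max_batches

val_loader = [0, 1, 2]   # hypothetical stand-ins for DataLoaders
test_loader = [3, 4]
assert normalize_max_batches(5, [val_loader, test_loader]) == [5, 5]
assert normalize_max_batches([2, 4], [val_loader, test_loader]) == [2, 4]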
Example #4
    def setup(self, model, max_batches, dataloaders):
        # copy properties for forward overrides
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
Example #5
    def setup(self, model, max_batches, dataloaders):
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.trainer.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank,
                                                self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
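This variant brackets evaluation with zero-grad, eval mode, and disabled autograd. A minimal sketch of that bracket in plain PyTorch, with a hypothetical stand-in model; the `try/finally` restores training state even if the eval body raises:

import torch
from torch import nn

model = nn.Linear(4, 2)  # hypothetical stand-in model
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)  # disable gradients to save memory
try:
    preds = model(torch.randn(8, 4))
finally:
    model.train()                  # back to train mode
    torch.set_grad_enabled(True)   # re-enable gradients for training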
Example #6
    def _evaluate(self,
                  model: LightningModule,
                  dataloaders: List[DataLoader],
                  max_batches: Union[int, List[int]],
                  test_mode: bool = False):
        """Run evaluation code.

        Args:
            model: The model to evaluate.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: An integer or list of integers with length of the number of dataloaders. Each
                entry is the number of batches to process in the corresponding dataloader.
            test_mode: Whether to run in test mode (``True``) or validation mode (``False``).
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []
        predictions = PredictionCollection(self.global_rank, self.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        # --------------------------
        # ON_EVAL_EPOCH_START hook
        # --------------------------
        self.__call_eval_loop_hook_start(test_mode)

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device(self.tpu_id)
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            # each dataloader has a max num batches
            dl_max_batches = max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # callbacks
                if test_mode:
                    self.on_test_batch_start(batch, batch_idx, dataloader_idx)
                    if self.is_overridden('on_test_batch_start'):
                        model_ref = self.get_model()
                        with self.profiler.profile('on_test_batch_start'):
                            model_ref.on_test_batch_start(
                                batch, batch_idx, dataloader_idx)
                else:
                    self.on_validation_batch_start(batch, batch_idx,
                                                   dataloader_idx)
                    if self.is_overridden('on_validation_batch_start'):
                        model_ref = self.get_model()
                        with self.profiler.profile(
                                'on_validation_batch_start'):
                            model_ref.on_validation_batch_start(
                                batch, batch_idx, dataloader_idx)

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                    with torch.cuda.amp.autocast():
                        output = self.evaluation_forward(
                            model, batch, batch_idx, dataloader_idx, test_mode)
                else:
                    output = self.evaluation_forward(model, batch, batch_idx,
                                                     dataloader_idx, test_mode)

                is_result_obj = isinstance(output, Result)

                # track batch size for weighted average
                if is_result_obj:
                    output.track_batch_size(len(batch))

                # allow only EvalResult when using structured results (from val_step)
                if is_result_obj and not isinstance(output, EvalResult):
                    m = 'only EvalResults or dicts are allowed from validation_step'
                    raise MisconfigurationException(m)

                # ------------------
                # EVAL STEP END
                # ------------------
                # on dp / ddp2 might still want to do something with the batch parts
                eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
                if self.is_overridden(eval_step_end_hook_name):
                    model_ref = self.get_model()
                    with self.profiler.profile(eval_step_end_hook_name):
                        eval_step_end = getattr(model_ref,
                                                eval_step_end_hook_name)
                        output = eval_step_end(output)

                elif is_result_obj and (self.use_dp or self.use_ddp2):
                    # result auto reduce
                    output.dp_reduce()

                # callbacks (on_test/on_validation batch_end)
                if test_mode:
                    self.on_test_batch_end(batch, batch_idx, dataloader_idx)
                    if self.is_overridden('on_test_batch_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('on_test_batch_end'):
                            model_ref.on_test_batch_end(output)
                else:
                    self.on_validation_batch_end(batch, batch_idx,
                                                 dataloader_idx)
                    if self.is_overridden('on_validation_batch_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('on_validation_batch_end'):
                            model_ref.on_validation_batch_end(output)

                # track outputs for collation
                if output is not None:

                    # Add step predictions to prediction collection to write later
                    do_write_predictions = is_result_obj and test_mode
                    if do_write_predictions:
                        predictions.add(output.pop('predictions', None))

                    dl_outputs.append(output)

                self.__eval_add_step_metrics(output, batch_idx)

                # track debug metrics
                self.dev_debugger.track_eval_loss_history(
                    test_mode, batch_idx, dataloader_idx, output)

            outputs.append(dl_outputs)

        # ---------------------
        # EVAL_EPOCH_END
        # ---------------------
        using_eval_result = (
            len(outputs) > 0 and len(outputs[0]) > 0
            and isinstance(outputs[0][0], EvalResult)
        )
        eval_results = self.__run_eval_epoch_end(test_mode, outputs,
                                                 dataloaders,
                                                 using_eval_result)

        # log callback metrics
        self.__update_callback_metrics(eval_results, using_eval_result)

        # Write predictions to disk if they're available.
        predictions.to_disk()

        # enable train mode again
        model.train()

        # re-enable gradients for training
        torch.set_grad_enabled(True)

        # --------------------------
        # ON_EVAL_EPOCH_END hook
        # --------------------------
        self.__call_eval_loop_hook_end(test_mode)

        return eval_results
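Stripped of hooks, AMP, and TPU handling, the core of `_evaluate` is a nested loop that caps each dataloader at its own batch budget and collates outputs per dataloader. A minimal sketch of that skeleton, with plain lists standing in for dataloaders and dicts for step outputs:

def run_eval(dataloaders, max_batches):
    # outputs[i] holds the step outputs for dataloader i
    outputs = []
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []
        dl_max_batches = max_batches[dataloader_idx]
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= dl_max_batches:
                break  # stop short when running on limited batches
            dl_outputs.append({'batch_idx': batch_idx, 'value': batch})
        outputs.append(dl_outputs)
    return outputs

outs = run_eval([[10, 11, 12], [20, 21]], max_batches=[2, 5])
assert [len(dl) for dl in outs] == [2, 2]  # first loader capped at 2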