    def evaluate(self, samples=1000):
        """
        Evaluate after training procedure finished

        :return: A log that contains information about validation
        """
        # raw outputs for the whole test set: [n_samples, num_classes, samples]
        Outputs = torch.zeros(self.test_data_loader.n_samples,
                              self.model.num_classes, samples).to(self.device)
        targets = torch.zeros(self.test_data_loader.n_samples)

        self.model.eval()

        with torch.no_grad():
            # torch.no_grad() is a context manager; gradients are not tracked
            # for anything executed inside this block.
            start = 0
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                end = len(data) + start
                data, target = data.to(self.device), target.to(self.device)

                loss = 0
                # raw outputs for every stochastic forward pass of this batch
                outputs = torch.zeros(data.shape[0], self.model.num_classes,
                                      samples).to(self.device)

                if samples == 1:
                    # single deterministic forward pass
                    out, _ = self.model(data)
                    loss = self.criterion(out, target)
                    outputs[:, :, 0] = out

                elif samples > 1:
                    # accumulate the criterion over several stochastic passes
                    mlpdw_cum = 0

                    for i in range(samples):
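                        # each forward pass is assumed to draw a fresh set of
                        # weights from the learned posterior (hence sample=True)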
                        out, _ = self.model(data, sample=True)
                        mlpdw_i = self.criterion(out, target)
                        mlpdw_cum = mlpdw_cum + mlpdw_i
                        outputs[:, :, i] = out

                    # Monte-Carlo estimate: average the criterion over all passes
                    mlpdw = mlpdw_cum / samples
                    loss = mlpdw

                Outputs[start:end, :, :] = outputs
                targets[start:end] = target
                start = end

                self.test_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__,
                                             met(outputs, target, type="VD"))

        result = self.test_metrics.result()
        # print logged information to the screen
        for key, value in result.items():
            self.logger.info('    {:15s}: {}'.format(str(key), value))

        # self._visualization(Outputs, targets)
        test_uncertainities(Outputs,
                            targets,
                            self.model.num_classes,
                            self.logger,
                            save_path=str(self.result_dir))
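
For reference, here is a minimal sketch of how a predictive distribution and an uncertainty score could be derived from the collected Outputs tensor, assuming it holds raw logits of shape [n_samples, num_classes, samples]; the project's own test_uncertainities is defined elsewhere and may differ:

import torch
import torch.nn.functional as F

def predictive_entropy(mc_logits):
    """mc_logits: [N, num_classes, S] raw outputs from S stochastic passes."""
    probs = F.softmax(mc_logits, dim=1)      # class probabilities per MC pass
    mean_probs = probs.mean(dim=2)           # predictive distribution, [N, C]
    # entropy of the averaged distribution captures total predictive uncertainty
    entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=1)
    return mean_probs, entropy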
Example #2
    def evaluate(self):
        """
        Evaluate after training procedure finished

        :return: A log that contains information about validation
        """
        # raw outputs for the whole test set: [n_samples, num_classes, n_ensembles]
        Outputs = torch.zeros(self.test_data_loader.n_samples,
                              self.models[0].num_classes, self.n_ensembles)
        targets = torch.zeros(self.test_data_loader.n_samples)

        for i, model in enumerate(self.models):
            model.eval()

            outputs = torch.zeros(self.test_data_loader.n_samples,
                                  model.num_classes).to(self.device)

            with torch.no_grad():
                # torch.no_grad() is a context manager; gradients are not tracked
                # for anything executed inside this block.
                start = 0
                for batch_idx, (data,
                                target) in enumerate(self.test_data_loader):
                    end = len(data) + start
                    data, target = data.to(self.device), target.to(self.device)

                    output = model(data)
                    outputs[start:end, :] = output
                    targets[start:end] = target
                    start = end

                    loss = self.criterion(output, target)
                    self.test_metrics.update('loss_' + str(i), loss.item())

                    for met in self.metric_ftns:
                        self.test_metrics.update(
                            met.__name__ + '_' + str(i),
                            met(output, target, type="DE"))

            self._visualization(outputs, targets, i)
            Outputs[:, :, i] = outputs

        self._info_ensemble(Outputs, targets)
        result = self.test_metrics.result()

        # print logged information to the screen
        for key, value in result.items():
            self.logger.info('    {:15s}: {}'.format(str(key), value))

        # self._visualization_ensemble(Outputs, targets)
        test_uncertainities(Outputs,
                            targets,
                            self.models[0].num_classes,
                            self.logger,
                            save_path=str(self.result_dir))
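
For reference, a minimal sketch (assumption: Outputs holds raw logits of shape [n_samples, num_classes, n_ensembles], with at least two members) of how the members' predictions could be combined into one predictive distribution with a simple disagreement score:

import torch
import torch.nn.functional as F

def aggregate_ensemble(member_logits):
    """member_logits: [N, num_classes, M] raw outputs from M ensemble members."""
    probs = F.softmax(member_logits, dim=1)   # per-member class probabilities
    mean_probs = probs.mean(dim=2)            # ensemble predictive distribution
    preds = mean_probs.argmax(dim=1)          # final class decision
    # member disagreement: variance of the probability assigned to the
    # predicted class across the ensemble
    var = probs.var(dim=2)                    # [N, C]
    disagreement = var.gather(1, preds.unsqueeze(1)).squeeze(1)
    return preds, mean_probs, disagreement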