예제 #1
0
    def eval_queue(self,
                   queue,
                   criterions,
                   steps=1,
                   mode="eval",
                   aggregate_fns=None,
                   **kwargs):
        """Evaluate ``steps`` batches drawn from ``queue`` and aggregate the results.

        Args:
            queue: iterator of batches. ``data[0]`` of each batch is fed to
                ``self.forward_data``; ``data[1]`` is passed to every criterion
                as its third argument (presumably the targets).
            criterions: callables ``c(data[0], outputs, data[1])``. Their results
                are flattened with ``utils.flatten_list`` into one row per batch.
            steps: number of batches to draw from ``queue``.
            mode: mode set via ``self._set_mode`` for the forward pass. It is
                also the mode the object is left in when this method returns.
            aggregate_fns: one callable per flattened criterion output. Each is
                applied to that output's per-batch values. Defaults to the mean
                (``0.`` when there are no values).
            **kwargs: forwarded to ``self.forward_data``.

        Returns:
            list with one aggregated value per flattened criterion output.
            ``zip`` truncates to the shorter of ``aggregate_fns`` and the outputs.
        """
        self._set_mode(mode)

        aggr_ans = []
        # Skip autograd bookkeeping only when the object is configured to.
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for _ in range(steps):
                data = next(queue)
                data = _to_device(data, self.get_device())
                outputs = self.forward_data(data[0], **kwargs)
                # mAP is only calculated in "eval" mode, so criterions run in
                # "eval" mode. Always switch back to `mode` afterwards, even if
                # a criterion raises, so the object is not left in "eval" mode.
                self._set_mode("eval")
                try:
                    ans = utils.flatten_list(
                        [c(data[0], outputs, data[1]) for c in criterions])
                finally:
                    self._set_mode(mode)
                aggr_ans.append(ans)
        # Rows are batches before the transpose; afterwards, rows are criterion outputs.
        aggr_ans = np.asarray(aggr_ans).transpose()
        if aggregate_fns is None:
            # by default, aggregate batch rewards with MEAN
            aggregate_fns = [
                lambda perfs: np.mean(perfs) if len(perfs) > 0 else 0.
            ] * len(aggr_ans)
        return [aggr_fn(ans) for aggr_fn, ans in zip(aggregate_fns, aggr_ans)]
예제 #2
0
    def eval_queue(self,
                   queue,
                   criterions,
                   steps=1,
                   mode="eval",
                   aggregate_fns=None,
                   **kwargs):
        """Optionally calibrate BN statistics, then evaluate ``steps`` batches.

        When ``self.calib_bn_batch > 0``, that many batches are drawn from
        ``queue`` and passed to ``self.calib_bn``. They are then reused as the
        first evaluation batches; any further batches come from ``queue``.

        Args:
            queue: iterator of batches. ``data[0]`` is fed to
                ``self.forward_data``; ``data[1]`` is passed to every criterion
                as its third argument.
            criterions: callables ``c(data[0], outputs, data[1])``. Results are
                flattened with ``utils.flatten_list`` into one row per batch.
            steps: number of evaluation batches.
            mode: mode set via ``self._set_mode`` after calibration.
            aggregate_fns: one callable per flattened criterion output.
                Defaults to the mean (``0.`` when there are no values).
            **kwargs: forwarded to ``self.forward_data``.

        Returns:
            list with one aggregated value per flattened criterion output.
        """
        # BN running statistics calibration. `calib_data` is always bound, so
        # the evaluation loop does not depend on re-checking the same guard.
        calib_data = []
        if self.calib_bn_batch > 0:
            calib_data = [next(queue) for _ in range(self.calib_bn_batch)]
            self.calib_bn(calib_data)

        self._set_mode(mode)

        aggr_ans = []
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for i in range(steps):
                # Reuse the calibration batches first, then pull fresh ones.
                if i < len(calib_data):
                    data = calib_data[i]
                else:
                    data = next(queue)
                data = _to_device(data, self.get_device())
                outputs = self.forward_data(data[0], **kwargs)
                ans = utils.flatten_list(
                    [c(data[0], outputs, data[1]) for c in criterions])
                aggr_ans.append(ans)
        # Rows are batches before the transpose; afterwards, rows are criterion outputs.
        aggr_ans = np.asarray(aggr_ans).transpose()
        if aggregate_fns is None:
            # by default, aggregate batch rewards with MEAN
            aggregate_fns = [
                lambda perfs: np.mean(perfs) if len(perfs) > 0 else 0.
            ] * len(aggr_ans)
        return [aggr_fn(ans) for aggr_fn, ans in zip(aggregate_fns, aggr_ans)]
예제 #3
0
파일: base.py 프로젝트: zeta1999/aw_nas
    def forward_queue(self, queue, steps=1, mode=None, **kwargs):
        """Forward ``steps`` batches from ``queue`` and concatenate the outputs.

        Each batch is moved to ``self.get_device()`` and unpacked as the
        positional arguments of ``self.forward_data``. The per-batch results
        (presumably tensors; TODO confirm) are joined along dim 0 with
        ``torch.cat``. Because ``torch.cat`` rejects an empty list, ``steps``
        must be at least 1.

        Args:
            queue: iterator of batches.
            steps: number of batches to forward.
            mode: passed to ``self._set_mode`` before forwarding.
            **kwargs: forwarded to ``self.forward_data``.
        """
        self._set_mode(mode)

        # Lazy: each batch is fetched and moved only when the comprehension
        # below asks for it, keeping fetch/forward interleaved per step.
        device_batches = (
            _to_device(next(queue), self.get_device()) for _ in range(steps)
        )
        return torch.cat(
            [self.forward_data(*batch, **kwargs) for batch in device_batches],
            dim=0)
예제 #4
0
    def train_queue(self,
                    queue,
                    optimizer,
                    criterion=lambda i, l, t: nn.CrossEntropyLoss()(l, t),
                    eval_criterions=None,
                    steps=1,
                    aggregate_fns=None,
                    **kwargs):
        """Run ``steps`` optimization steps on batches drawn from ``queue``.

        For every batch, the following happens in order:

        1. The batch is moved to ``self.get_device()`` and unpacked as
           ``(inputs, targets)``.
        2. The whole batch is fed to ``self.forward_data``.
        3. The loss ``criterion(inputs, outputs, targets)`` is computed.
        4. If ``eval_criterions`` is truthy, their flattened results are
           recorded.
        5. ``self.zero_grad()`` is called, the loss is back-propagated,
           ``optimizer.step()`` is applied and ``self.clear_cache()`` runs.

        Args:
            queue: iterator of ``(inputs, targets)`` batches.
            optimizer: object whose ``step()`` is called once per batch.
            criterion: loss callable ``criterion(inputs, outputs, targets)``.
            eval_criterions: optional callables with the same signature whose
                results are aggregated and returned.
            steps: number of batches; must be positive.
            aggregate_fns: one callable per flattened eval output. Defaults
                to the mean (``0.0`` when there are no values).
            **kwargs: forwarded to ``self.forward_data``.

        Returns:
            One aggregated value per flattened ``eval_criterions`` output, or
            an empty list when ``eval_criterions`` is falsy.
        """
        assert steps > 0

        self._set_mode("train")

        batch_metrics = []
        for _ in range(steps):
            batch = _to_device(next(queue), self.get_device())
            inputs, targets = batch
            outputs = self.forward_data(*batch, **kwargs)
            loss = criterion(inputs, outputs, targets)
            if eval_criterions:
                batch_metrics.append(utils.flatten_list(
                    [crit(inputs, outputs, targets) for crit in eval_criterions]))
            self.zero_grad()
            loss.backward()
            optimizer.step()
            self.clear_cache()

        if not eval_criterions:
            return []

        def _mean_or_zero(perfs):
            # by default, aggregate batch rewards with MEAN
            return np.mean(perfs) if len(perfs) > 0 else 0.0

        # Rows are batches before the transpose; afterwards, rows are eval outputs.
        per_metric = np.asarray(batch_metrics).transpose()
        if aggregate_fns is None:
            aggregate_fns = [_mean_or_zero] * len(per_metric)
        return [fn(values) for fn, values in zip(aggregate_fns, per_metric)]
예제 #5
0
    def eval_queue(self,
                   queue,
                   criterions,
                   steps=1,
                   mode="eval",
                   aggregate_fns=None,
                   **kwargs):
        """Optionally calibrate BN statistics, then evaluate ``steps`` batches.

        Calibration batches are selected as follows:

        * ``self.calib_bn_num > 0``: draw batches until at least
          ``calib_bn_num`` samples (counted as ``len(batch[1])``) have been
          collected. Stop early, with a warning, once ``steps`` batches have
          been drawn.
        * otherwise, if ``self.calib_bn_batch > 0``: draw
          ``min(calib_bn_batch, steps)`` batches, warning when capped.
        * otherwise: no calibration.

        The calibration batches are passed to ``self.calib_bn`` and then reused
        as the first evaluation batches. The remaining batches come from ``queue``.

        Args:
            queue: iterator of batches. ``data[0]`` is fed to
                ``self.forward_data``; ``data[1]`` is passed to every criterion
                as its third argument.
            criterions: callables ``c(data[0], outputs, data[1])``. Results are
                flattened with ``utils.flatten_list`` into one row per batch.
            steps: number of evaluation batches.
            mode: accepted for interface compatibility but not used. This
                variant always evaluates in ``"eval"`` mode after calibration.
            aggregate_fns: one callable per flattened criterion output.
                Defaults to the mean (``0.`` when there are no values).
            **kwargs: forwarded to ``self.forward_data``.

        Returns:
            list with one aggregated value per flattened criterion output.
        """
        # BN running statistics calibration
        if self.calib_bn_num > 0:
            # check `calib_bn_num` first
            calib_num = 0
            calib_data = []
            calib_batch = 0
            while calib_num < self.calib_bn_num:
                if calib_batch == steps:
                    # Arguments follow the format order: steps, then the
                    # number of samples actually collected.
                    utils.getLogger("robustness plugin.{}".format(self.__class__.__name__)).warning(
                        "steps (%d) reached, true calib bn num (%d)", steps, calib_num)
                    break
                calib_data.append(next(queue))
                calib_num += len(calib_data[-1][1])
                calib_batch += 1
            self.calib_bn(calib_data)
        elif self.calib_bn_batch > 0:
            if self.calib_bn_batch > steps:
                # Never draw more calibration batches than evaluation steps.
                utils.getLogger("robustness plugin.{}".format(self.__class__.__name__)).warning(
                    "eval steps (%d) < `calib_bn_batch` (%d). Only use %d batches.",
                    steps, self.calib_bn_batch, steps)
                calib_bn_batch = steps
            else:
                calib_bn_batch = self.calib_bn_batch
            # check `calib_bn_batch` then
            calib_data = [next(queue) for _ in range(calib_bn_batch)]
            self.calib_bn(calib_data)
        else:
            calib_data = []

        self._set_mode("eval")  # Use eval mode after BN calibration

        aggr_ans = []
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for i in range(steps):
                # Reuse the calibration batches first, then pull fresh ones.
                if i < len(calib_data):
                    data = calib_data[i]
                else:
                    data = next(queue)
                data = _to_device(data, self.get_device())
                outputs = self.forward_data(data[0], **kwargs)
                ans = utils.flatten_list(
                    [c(data[0], outputs, data[1]) for c in criterions])
                aggr_ans.append(ans)
                # Drop the reference to this batch's outputs before the next batch.
                del outputs
                # 1-based progress so the last step reads "steps/steps".
                print("\reval step {}/{} ".format(i + 1, steps), end="", flush=True)

        # Rows are batches before the transpose; afterwards, rows are criterion outputs.
        aggr_ans = np.asarray(aggr_ans).transpose()

        if aggregate_fns is None:
            # by default, aggregate batch rewards with MEAN
            aggregate_fns = [
                lambda perfs: np.mean(perfs) if len(perfs) > 0 else 0.
            ] * len(aggr_ans)
        return [aggr_fn(ans) for aggr_fn, ans in zip(aggregate_fns, aggr_ans)]