Пример #1
0
    def _eval(self, env, dataLoader):
        """Evaluate the model on ``dataLoader`` and log the effect of ICM.

        Pass 1 encodes every batch with ``self.model.Encode`` and hands the raw
        samples ``X`` and their codes ``B`` to ``env.Eval``, which returns a
        codebook ``C`` and per-sample errors ``qErrors``. Pass 2 re-encodes the
        same samples, in the same order, with ``icm=True`` against ``C``. It then
        logs the mean resulting ``QuantizationError`` and the percentage of
        samples whose error dropped below ``qErrors``.

        The model is in eval mode for the duration and is always switched back
        to train mode afterwards, even if evaluation raises.

        Args:
            env: Object exposing ``Eval(X, B) -> (C, qErrors)``.
            dataLoader: Iterable of tensors that can be concatenated along
                dim 0. It is iterated exactly once.
        """
        self._debugPrint("Eval step")
        self.model.eval()
        try:
            # Pass 1: keep the raw (un-normalized) samples for the error
            # computation; encode the optionally normalized ones.
            rawBatches = list()
            codes = list()
            for x in dataLoader:
                rawBatches.append(x)
                if self.env.DoNormalizationOnObs:
                    x = (x - self._obsMean) / self._obsStd
                codes.append(self.model.Encode(x))
            batchSizes = [batch.shape[0] for batch in rawBatches]
            X = torch.cat(rawBatches, 0)
            del rawBatches
            B = torch.cat(codes, 0)
            C, qErrors = env.Eval(X, B)
            # Free the first-pass codes before the second encoding pass.
            del B, codes

            # Pass 2: split X back into the original batches instead of
            # iterating dataLoader again. A shuffling loader would yield a
            # different order, so the ICM codes would not line up with X or
            # with qErrors.
            codes = list()
            for x in X.split(batchSizes, 0):
                if self.env.DoNormalizationOnObs:
                    x = (x - self._obsMean) / self._obsStd
                codes.append(
                    self.model.Encode(x,
                                      icm=True,
                                      C=C.repeat(self._nParallels, 1, 1),
                                      shift=self._obsMeanRepeat,
                                      scale=self._obsStdRepeat))
            B = torch.cat(codes, 0)
            del codes
            icmQErrors = QuantizationError(X, C, B)
            self.logger.info("After ICM: %f, %.0f%% samples are better.",
                             icmQErrors.mean(),
                             (icmQErrors < qErrors).sum().float() /
                             len(qErrors) * 100.)
        finally:
            # Restore training mode even when evaluation fails part-way.
            self.model.train()
Пример #2
0
 def Eval(self,
          x: torch.Tensor,
          b: torch.Tensor,
          additionalMsg: str = None) -> (torch.Tensor, torch.Tensor):
     """Solve a codebook for codes ``b`` and report the quantization error.

     Only reads ``self._step``. The side effects are an info log line and,
     when ``self.summaryWriter`` is set, an ``eval/QError`` scalar.

     Args:
         x: Samples to quantize.
         b: Codes for ``x``. When its last dim equals ``m * k`` it is
             multiplied directly against the codebook reshaped to
             ``(m * k, -1)``. Otherwise it is passed to
             ``QuantizationError``.
         additionalMsg: Optional tag for the log line; defaults to "Eval".

     Returns:
         ``(newCodebook, newQError)``: the solved codebook and the per-sample
         error. In the ``m * k`` branch the error is the summed squared
         reconstruction error per row of ``x``.
     """
     newCodebook = self.solver.solve(x, b, alternateWhenOutlier=True)
     if b.shape[-1] == self._m * self._k:
         # b addresses all m * k codewords at once, so reconstruction is a
         # single matmul against the flattened codebook.
         newQError = ((
             x - b @ newCodebook.reshape(self._m * self._k, -1))**2).sum(-1)
     else:
         newQError = QuantizationError(x, newCodebook, b)
     meanQError = newQError.mean()
     self.logger.info("[%4d %s]QError: %3.2f", self._step, additionalMsg
                      or "Eval", meanQError)
     if self.summaryWriter is not None:
         self.summaryWriter.add_scalar("eval/QError",
                                       meanQError,
                                       global_step=self._step)
     return newCodebook, newQError
Пример #3
0
 def Step(self,
          x: torch.Tensor,
          b: torch.Tensor,
          logStat: bool = True) -> (torch.Tensor, torch.Tensor):
     """Advance one environment step and compute rewards for codes ``b``.

     The step does the following:

     - Solve a new codebook for ``(x, b)`` and measure the per-sample
       quantization error.
     - Turn the improvement over ``self._oldQError`` into rewards. That
       baseline is captured on the first call and never updated by this
       method.
     - Log statistics.
     - Store the codebook, optionally normalized, into
       ``self._codebook.data``.

     Args:
         x: Samples to quantize. Outputs are moved to ``x.device``.
         b: Codes for ``x``. When its last dim equals ``m * k`` it is
             multiplied directly against the codebook reshaped to
             ``(m * k, -1)``. Otherwise it is passed to
             ``QuantizationError``.
         logStat: Controls the statistics updates.
             - When True, reward variance is updated via ``self.Estimate``
               (if reward normalization is on) and ``self._estimateQEStat``
               is called.
             - When False, the stored ``estimate/reward`` variance is reused
               and QError stats are not updated.

     Returns:
         ``(rewards, currentQError)``: per-sample rewards and the mean
         quantization error, both on ``x.device``.

     Raises:
         AttributeError: If observation normalization is enabled but
             ``self._obsMean`` was never set. This is checked before any
             state is mutated.
     """
     # Fail fast so a misconfigured env does not advance _step, _firstRun or
     # the reward baseline before erroring out.
     if self._doNormalizeOnObs and not hasattr(self, "_obsMean"):
         raise AttributeError(
             f"Not feed obs mean and var with DoNormalizationOnObs = {self._doNormalizeOnObs}"
         )

     newCodebook = self.solver.solve(x, b, alternateWhenOutlier=True)
     if b.shape[-1] == self._m * self._k:
         newQError = ((
             x - b @ newCodebook.reshape(self._m * self._k, -1))**2).sum(-1)
     else:
         # Follow the codebook's device instead of hard-coding .cuda(), so
         # this also runs on CPU-only machines and non-default GPUs.
         device = newCodebook.device
         newQError = QuantizationError(x.to(device), newCodebook,
                                       b.to(device))

     if self._oldQError is None:
         # The first observed error becomes the reward baseline.
         self._oldQError = newQError
         self._meanQE = self._oldQError.mean()

     # [N, ]
     rewards = self._oldQError - newQError
     if not self._doNormalizeOnRew:
         rewards = rewards / self._meanQE
     elif self._firstRun:
         # The first run leaves rewards unscaled; presumably no variance
         # estimate exists yet.
         self._firstRun = False
     elif logStat:
         _, variance = self.Estimate(("reward", rewards), (0, ))
         # TODO: mean or not mean?
         rewards = rewards / (variance + 1e-8).sqrt()
     else:
         rewards = rewards / (self._variance["estimate/reward"] +
                              1e-8).sqrt()

     currentQError = newQError.mean()
     self.logger.debug("[%4d Train]QError: %3.2f", self._step,
                       currentQError)
     if self.summaryWriter is not None:
         self.summaryWriter.add_scalar("env/QError",
                                       currentQError,
                                       global_step=self._step)
         self.summaryWriter.add_histogram("env/Reward",
                                          rewards,
                                          global_step=self._step)
     self._step += 1

     if self._doNormalizeOnObs:
         newCodebook = (newCodebook - self._obsMean) / self._obsStd
     self._codebook.data = newCodebook
     if logStat:
         self._estimateQEStat(currentQError)
     return rewards.to(x.device), currentQError.to(x.device)