def _eval(self, env, dataLoader):
    self._debugPrint("Eval step")
    self.model.eval()
    B = list()
    X = list()
    # First pass: encode every batch and collect both the raw samples and their codes.
    for x in dataLoader:
        X.append(x)
        if self.env.DoNormalizationOnObs:
            x = (x - self._obsMean) / self._obsStd
        batchB = self.model.Encode(x)
        B.append(batchB)
    B = torch.cat(B, 0)
    X = torch.cat(X, 0)
    # Re-estimate the codebook C from the collected codes and get the per-sample error.
    C, qErrors = env.Eval(X, B)
    del B
    B = list()
    # Second pass: re-encode with ICM refinement against the updated codebook.
    for x in dataLoader:
        if self.env.DoNormalizationOnObs:
            x = (x - self._obsMean) / self._obsStd
        batchB = self.model.Encode(x,
                                   icm=True,
                                   C=C.repeat(self._nParallels, 1, 1),
                                   shift=self._obsMeanRepeat,
                                   scale=self._obsStdRepeat)
        B.append(batchB)
    B = torch.cat(B, 0)
    icmQErrors = QuantizationError(X, C, B)
    self.logger.info("After ICM: %f, %.0f%% samples are better.",
                     icmQErrors.mean(),
                     (icmQErrors < qErrors).sum().float() / len(qErrors) * 100.)
    del B, X, C
    self.model.train()
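# `QuantizationError(x, C, b)` is used above but not defined in this excerpt.
# A minimal sketch of what it plausibly computes, assuming a multi-codebook
# (additive) scheme with x: [N, D], C: [M, K, D] and integer codes b: [N, M];
# the name `_quantizationErrorSketch` and these shapes are assumptions, not
# the repo's actual definition:
import torch  # already a dependency of the surrounding module


def _quantizationErrorSketch(x: torch.Tensor, C: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Gather the selected codeword from each of the M codebooks: [N, M, D].
    selected = torch.stack([C[m, b[:, m]] for m in range(C.shape[0])], dim=1)
    # Reconstruction is the sum of the selected codewords: [N, D].
    quantized = selected.sum(1)
    # Per-sample squared error: [N, ].
    return ((x - quantized) ** 2).sum(-1)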
def Eval(self, x: torch.Tensor, b: torch.Tensor, additionalMsg: str = None) -> (torch.Tensor, torch.Tensor):
    # Solve for the codebook that best fits the current code assignments.
    newCodebook = self.solver.solve(x, b, alternateWhenOutlier=True)
    # assignCodes = self._randPerm(assignCodes)
    if b.shape[-1] == self._m * self._k:
        # b is a flattened assignment matrix of shape [N, M*K]; reconstruct by
        # a single matrix product with the flattened codebook.
        newQError = ((x - b @ newCodebook.reshape(self._m * self._k, -1)) ** 2).sum(-1)
    else:
        # b holds code indices; use the generic helper.
        newQError = QuantizationError(x, newCodebook, b)
    self.logger.info("[%4d %s]QError: %3.2f", self._step, additionalMsg or "Eval", newQError.mean())
    if self.summaryWriter is not None:
        self.summaryWriter.add_scalar("eval/QError", newQError.mean(), global_step=self._step)
    return newCodebook, newQError
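# Why the two branches in `Eval` agree: when the codes arrive as a flattened
# one-hot assignment matrix of shape [N, M*K] (an assumption about the caller;
# only the shape check is in the source), `b @ codebook.reshape(M*K, D)` picks
# exactly one codeword per codebook and sums them, which matches the
# index-based reconstruction. A small self-contained check:
def _checkFlattenedReconstruction(N: int = 4, M: int = 2, K: int = 8, D: int = 16):
    codebook = torch.randn(M, K, D)
    codes = torch.randint(K, (N, M))
    # Index-based reconstruction: sum the selected codeword of every codebook.
    byIndex = torch.stack([codebook[m, codes[:, m]] for m in range(M)], 1).sum(1)
    # Flattened one-hot assignments: [N, M*K], row-major over (M, K).
    oneHot = torch.zeros(N, M * K)
    for m in range(M):
        oneHot[torch.arange(N), m * K + codes[:, m]] = 1.
    byMatmul = oneHot @ codebook.reshape(M * K, D)
    assert torch.allclose(byIndex, byMatmul, atol=1e-5)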
def Step(self, x: torch.Tensor, b: torch.Tensor, logStat: bool = True) -> (torch.Tensor, torch.Tensor):
    # Solve for the codebook that best fits the current code assignments.
    newCodebook = self.solver.solve(x, b, alternateWhenOutlier=True)
    if b.shape[-1] == self._m * self._k:
        newQError = ((x - b @ newCodebook.reshape(self._m * self._k, -1)) ** 2).sum(-1)
    else:
        newQError = QuantizationError(x.cuda(), newCodebook, b.cuda())
    if self._oldQError is None:
        self._oldQError = newQError
        self._meanQE = self._oldQError.mean()
    if self._doNormalizeOnRew:
        # Reward is the per-sample drop in quantization error, whitened by an
        # estimated variance once such an estimate is available.
        rewards = self._oldQError - newQError
        if self._firstRun:
            # No variance estimate yet; leave the first reward unnormalized.
            self._firstRun = False
        else:
            if logStat:
                _, variance = self.Estimate(("reward", rewards), (0, ))
                # TODO: mean or not mean?
                rewards = rewards / (variance + 1e-8).sqrt()
            else:
                rewards = rewards / (self._variance["estimate/reward"] + 1e-8).sqrt()
    else:
        # [N, ] improvement scaled by the initial mean quantization error.
        rewards = (self._oldQError - newQError) / self._meanQE
    currentQError = newQError.mean()
    self.logger.debug("[%4d Train]QError: %3.2f", self._step, currentQError)
    if self.summaryWriter is not None:
        self.summaryWriter.add_scalar("env/QError", currentQError, global_step=self._step)
        self.summaryWriter.add_histogram("env/Reward", rewards, global_step=self._step)
    self._step += 1
    if self._doNormalizeOnObs:
        # mean, variance = self.Estimate(("codebook", newCodebook), (1, ))
        # newCodebook = (newCodebook - mean) / (variance + 1e-8).sqrt()
        if not hasattr(self, "_obsMean"):
            raise AttributeError(
                f"Observation mean/std are not set although DoNormalizationOnObs = {self._doNormalizeOnObs}")
        newCodebook = (newCodebook - self._obsMean) / self._obsStd
    self._codebook.data = newCodebook
    del newCodebook
    if logStat:
        self._estimateQEStat(currentQError)
    return rewards.to(x.device), currentQError.to(x.device)
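# The reward in `Step` is the per-sample drop in quantization error, either
# scaled by the initial mean error or whitened by an estimated variance. A
# pure-function sketch of that shaping (the 1e-8 floor mirrors the code above;
# passing the variance explicitly instead of via `Estimate`'s bookkeeping is
# an assumption for illustration):
def _rewardSketch(oldQError: torch.Tensor,
                  newQError: torch.Tensor,
                  meanQE: torch.Tensor,
                  rewardVariance: torch.Tensor = None) -> torch.Tensor:
    improvement = oldQError - newQError  # [N, ]
    if rewardVariance is not None:
        # Normalize-on-reward mode: whiten by the estimated variance.
        return improvement / (rewardVariance + 1e-8).sqrt()
    # Otherwise scale by the mean quantization error recorded at the first step.
    return improvement / meanQE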