def _log_histogram_and_mean(log_key, val):
    """Log ``val`` as a histogram under ``log_key`` and its mean as a scalar.

    The scalar is written under ``"<log_key>/mean"`` using ``val.mean()``.

    Raises:
        ValueError: re-raised unchanged after a warning naming the key and
            the offending value has been logged.
    """
    try:
        SummaryWriterContext.add_histogram(log_key, val)
        mean_key = f"{log_key}/mean"
        SummaryWriterContext.add_scalar(mean_key, val.mean())
    except ValueError:
        # Record which key/value failed, then let the caller see the error.
        warning_text = (
            f"Cannot create histogram for key: {log_key}; "
            "this is likely because you have NULL value in your input; "
            f"value: {val}"
        )
        logger.warning(warning_text)
        raise
def get_log_prob(self, state: rlt.FeatureData, squashed_action: torch.Tensor):
    """Return the log-probability of an already tanh-squashed action.

    The action is un-squashed with ``atanh``, standardized against the
    location/log-scale produced for ``state``, scored with
    ``self._normal_log_prob`` and corrected with ``self._squash_correction``.
    The corrected per-dimension values are summed over ``dim=1`` and the
    result is reshaped to ``(-1, 1)``.

    Whenever ``SummaryWriterContext._global_step`` is a multiple of 1000,
    the intermediate tensors are also written out as histograms.
    """
    if self.use_l2_normalization:
        # TODO: calculate log_prob for l2 normalization
        # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
        # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
        pass

    mean, log_std = self._get_loc_and_scale_log(state)
    # Invert the tanh squash to recover the pre-squash action.
    pre_squash = torch.atanh(squashed_action)
    standardized = (pre_squash - mean) / log_std.exp()
    unsquashed_log_prob = self._normal_log_prob(standardized, log_std)
    correction = self._squash_correction(squashed_action)

    if SummaryWriterContext._global_step % 1000 == 0:
        histograms = (
            ("actor/get_log_prob/loc", mean),
            ("actor/get_log_prob/scale_log", log_std),
            ("actor/get_log_prob/log_prob", unsquashed_log_prob),
            ("actor/get_log_prob/squash_correction", correction),
        )
        for tag, tensor in histograms:
            SummaryWriterContext.add_histogram(tag, tensor.detach().cpu())

    per_dim = unsquashed_log_prob - correction
    return torch.sum(per_dim, dim=1).reshape(-1, 1)
def forward(self, state: rlt.FeatureData):
    """Sample a squashed action for ``state`` and package it with its log-prob.

    Standard-normal noise shaped like the log-scale is drawn, scaled by
    ``exp(scale_log)`` and shifted by the location; both the sampled raw
    action and the location itself are passed through
    ``self._squash_raw_action``.

    Whenever ``SummaryWriterContext._global_step`` is a multiple of 1000,
    the location and log-scale are written out as histograms.

    Returns:
        ``rlt.ActorOutput`` with ``action`` (squashed sample), ``log_prob``
        (from ``self.get_log_prob``) and ``squashed_mean`` (squashed location).
    """
    mean, log_std = self._get_loc_and_scale_log(state)
    noise = torch.randn_like(log_std, device=log_std.device)
    std = log_std.exp()
    pre_squash = mean + noise * std

    sampled_action = self._squash_raw_action(pre_squash)
    squashed_mean = self._squash_raw_action(mean)

    if SummaryWriterContext._global_step % 1000 == 0:
        for tag, tensor in (
            ("actor/forward/loc", mean),
            ("actor/forward/scale_log", log_std),
        ):
            SummaryWriterContext.add_histogram(tag, tensor.detach().cpu())

    action_log_prob = self.get_log_prob(state, sampled_action)
    return rlt.ActorOutput(
        action=sampled_action,
        log_prob=action_log_prob,
        squashed_mean=squashed_mean,
    )
def _log_histogram_and_mean(name, key, x):
    """Log ``x`` as a histogram and its mean as a scalar for a dueling network.

    The histogram goes under ``dueling_network/<name>/<key>`` and the scalar
    under ``dueling_network/<name>/mean_<key>``. Values are detached and
    moved to CPU before being handed to ``SummaryWriterContext``.
    """
    detached = x.detach()

    histogram_tag = f"dueling_network/{name}/{key}"
    SummaryWriterContext.add_histogram(histogram_tag, detached.cpu())

    mean_tag = f"dueling_network/{name}/mean_{key}"
    SummaryWriterContext.add_scalar(mean_tag, detached.mean().cpu())
def test_swallowing_histogram_value_error(self):
    """Adding an all-ones histogram through the context must not raise.

    A ``SummaryWriter`` rooted in a temporary directory is installed via
    ``summary_writer_context`` and a ``(100, 1)`` tensor of ones is logged
    under ``"bad_histogram"``. There are no explicit assertions: the test
    passes as long as the call returns normally.

    NOTE(review): going by the test name, the constant-valued input is
    presumably one that can make histogram creation raise ``ValueError``,
    which the context is expected to swallow -- confirm against
    ``SummaryWriterContext``.
    """
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        try:
            with summary_writer_context(writer):
                SummaryWriterContext.add_histogram(
                    "bad_histogram", torch.ones(100, 1)
                )
        finally:
            # Close the writer so it releases whatever it opened under
            # tmp_dir before the temporary directory is removed.
            writer.close()