示例#1
0
文件: cpe.py 项目: t-triobox/ReAgent
    def log_to_tensorboard(self, metric_name: str) -> None:
        """Write each CPE estimator's normalized score as a TensorBoard scalar.

        Every score is logged under ``CPE/<metric_name>/<estimator label>``.
        Scores that are ``None`` or NaN are written as ``0.0``.

        Args:
            metric_name: Middle component of every scalar tag.
        """
        self.check_estimates_exist()

        def _zero_if_missing(score: Optional[float]) -> float:
            # Replace absent/NaN scores so a numeric value is always written.
            return 0.0 if score is None or math.isnan(score) else score

        prefix = f"CPE/{metric_name}"
        # All scores are read before anything is logged.
        scalars = [
            # pyre-fixme[16]: `Optional` has no attribute `normalized`.
            (f"{prefix}/Direct_Method_Reward", self.direct_method.normalized),
            (f"{prefix}/IPS_Reward", self.inverse_propensity.normalized),
            (f"{prefix}/Doubly_Robust_Reward", self.doubly_robust.normalized),
            (
                f"{prefix}/Sequential_Doubly_Robust",
                self.sequential_doubly_robust.normalized,
            ),
            (
                f"{prefix}/Weighted_Sequential_Doubly_Robust",
                self.weighted_doubly_robust.normalized,
            ),
            (f"{prefix}/MAGIC", self.magic.normalized),
        ]
        for tag, score in scalars:
            SummaryWriterContext.add_scalar(tag, _zero_if_missing(score))
示例#2
0
 def test_swallowing_exception(self):
     """Errors listed in the writer's ``exceptions_to_ignore`` are swallowed.

     The writer's ``add_scalar`` is replaced by a mock that raises
     ``NotImplementedError``. That type is in ``exceptions_to_ignore``, so
     logging through ``SummaryWriterContext`` is expected not to raise.
     """
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         writer.exceptions_to_ignore = (NotImplementedError, KeyError)
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
         # Without this check the test would also pass if the call never
         # reached the writer; confirm the raising mock was actually invoked.
         writer.add_scalar.assert_called_once()
示例#3
0
 def test_not_swallowing_exception(self):
     """Writer errors propagate when they are not configured to be ignored."""
     with TemporaryDirectory() as log_dir:
         failing_writer = SummaryWriter(log_dir)
         # No ``exceptions_to_ignore`` is set on this writer, so the error is
         # expected to propagate out of the context.
         failing_writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         with self.assertRaisesRegex(NotImplementedError, "test"):
             with summary_writer_context(failing_writer):
                 SummaryWriterContext.add_scalar("test", torch.ones(1))
示例#4
0
 def test_writing(self):
     """A scalar logged inside the context reaches the writer at step 0."""
     with TemporaryDirectory() as log_dir:
         writer = SummaryWriter(log_dir)
         recorder = MagicMock()
         writer.add_scalar = recorder
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
         recorder.assert_called_once_with(
             "test", torch.ones(1), global_step=0)
示例#5
0
def _log_histogram_and_mean(log_key, val):
    """Log a histogram of ``val`` under ``log_key`` and its mean under ``<log_key>/mean``.

    Raises:
        ValueError: Re-raised unchanged from either logging call, after a
            warning that names the key and shows the offending value.
    """
    try:
        SummaryWriterContext.add_histogram(log_key, val)
        SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
    except ValueError:
        # Record which key failed before propagating the original error.
        warning_text = (
            f"Cannot create histogram for key: {log_key}; "
            "this is likely because you have NULL value in your input; "
            f"value: {val}"
        )
        logger.warning(warning_text)
        raise
示例#6
0
    def log_to_tensorboard(self, epoch: int) -> None:
        """Write the most recent training losses as TensorBoard scalars.

        Each loss is logged with ``epoch`` as its step. Losses that are
        ``None`` or NaN are logged as ``0.0``.

        Args:
            epoch: Step value passed to every ``add_scalar`` call.
        """
        # All three getters are called up front, before anything is logged.
        latest = (
            ("Training/td_loss", self.get_recent_td_loss()),
            ("Training/reward_loss", self.get_recent_reward_loss()),
            ("Training/imitator_loss", self.get_recent_imitator_loss()),
        )
        for tag, loss in latest:
            if loss is None or math.isnan(loss):
                loss = 0.0
            SummaryWriterContext.add_scalar(tag, loss, epoch)
示例#7
0
    def write_summary(self, actions: List[str]):
        """Log the stored batch statistics to TensorBoard.

        Fields that are ``None`` are skipped.

        - When ``actions`` is non-empty, for each action index ``i`` the number
          of elements of ``logged_actions`` / ``model_action_idxs`` equal to
          ``i`` is logged as a scalar under ``<key>/<action name>``.
        - Loss, propensity, reward and value fields must be flat ``(N,)`` or
          ``(N, 1)``. They are logged via ``self._log_histogram_and_mean``.
        - Model propensities/rewards/values are logged whole when they are
          column-shaped and ``actions`` is empty. When they are 2-D with one
          column per action, each column is logged under its action name.

        Args:
            actions: Action names, in index order.

        Raises:
            ValueError: If a stored field has a shape this method cannot log.
        """

        def is_column(tensor) -> bool:
            # Accept either a flat (N,) vector or an (N, 1) column.
            return len(tensor.shape) == 1 or (
                len(tensor.shape) == 2 and tensor.shape[1] == 1
            )

        if actions:
            for field, log_key in [
                ("logged_actions", "actions/logged"),
                ("model_action_idxs", "actions/model"),
            ]:
                val = getattr(self, field)
                if val is None:
                    continue
                for i, action in enumerate(actions):
                    # Number of elements of ``val`` equal to action index ``i``.
                    SummaryWriterContext.add_scalar(
                        "{}/{}".format(log_key, action), (val == i).sum().item()
                    )

        for field, log_key in [
            ("td_loss", "td_loss"),
            ("imitator_loss", "imitator_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
            ("logged_values", "value/logged"),
            ("model_values_on_logged_actions", "value/model_logged_action"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            # Raise rather than ``assert``: asserts are stripped under
            # ``python -O``, which would let malformed tensors be logged.
            if not is_column(val):
                raise ValueError(
                    "Unexpected shape for {}: {}".format(field, val.shape)
                )
            self._log_histogram_and_mean(log_key, val)

        for field, log_key in [
            ("model_propensities", "propensities/model"),
            ("model_rewards", "reward/model"),
            ("model_values", "value/model"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            if is_column(val) and not actions:
                self._log_histogram_and_mean(log_key, val)
            elif len(val.shape) == 2 and val.shape[1] == len(actions):
                # One logged series per action column.
                for i, action in enumerate(actions):
                    self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
            else:
                raise ValueError(
                    "Unexpected shape for {}: {}; actions: {}".format(
                        field, val.shape, actions
                    )
                )
示例#8
0
 def test_global_step(self):
     """Each write is tagged with the global step current at call time."""
     with TemporaryDirectory() as log_dir:
         writer = SummaryWriter(log_dir)
         writer.add_scalar = MagicMock()
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
             # Advance the step counter between the two writes.
             SummaryWriterContext.increase_global_step()
             SummaryWriterContext.add_scalar("test", torch.zeros(1))
         expected_calls = [
             call("test", torch.ones(1), global_step=0),
             call("test", torch.zeros(1), global_step=1),
         ]
         writer.add_scalar.assert_has_calls(expected_calls)
         self.assertEqual(2, len(writer.add_scalar.mock_calls))
示例#9
0
 def test_writing_stack(self):
     """Nested contexts route writes to the innermost active writer."""
     with TemporaryDirectory() as outer_dir, TemporaryDirectory() as inner_dir:
         outer_writer = SummaryWriter(outer_dir)
         outer_writer.add_scalar = MagicMock()
         inner_writer = SummaryWriter(inner_dir)
         inner_writer.add_scalar = MagicMock()
         with summary_writer_context(outer_writer):
             with summary_writer_context(inner_writer):
                 # Expected to go to the inner writer while it is active.
                 SummaryWriterContext.add_scalar("test2", torch.ones(1))
             # Expected to go to the outer writer once the inner one exits.
             SummaryWriterContext.add_scalar("test1", torch.zeros(1))
         outer_writer.add_scalar.assert_called_once_with(
             "test1", torch.zeros(1), global_step=0)
         inner_writer.add_scalar.assert_called_once_with(
             "test2", torch.ones(1), global_step=0)
示例#10
0
def _log_histogram_and_mean(name, key, x):
    """Log a histogram of ``x`` and its mean under ``dueling_network/<name>/``.

    The histogram goes to ``.../<key>`` and the mean to ``.../mean_<key>``.
    Assumes ``x`` is a torch tensor. It is detached and moved to CPU before
    logging.
    """
    scope = f"dueling_network/{name}"
    detached = x.detach()
    SummaryWriterContext.add_histogram(f"{scope}/{key}", detached.cpu())
    SummaryWriterContext.add_scalar(f"{scope}/mean_{key}", detached.mean().cpu())
示例#11
0
 def aggregate(self, values):
     """For each action index, log how many elements of ``values`` equal it.

     Each count is written under ``<self.log_key>/<action name>``.
     """
     for index, action_name in enumerate(self.actions):
         tag = f"{self.log_key}/{action_name}"
         count = (values == index).sum().item()
         SummaryWriterContext.add_scalar(tag, count)
示例#12
0
 def update(self, key: str, value):
     """Log ``value`` as a scalar under ``self.logging_key``.

     ``key`` is not used by this implementation; presumably it is kept to
     match a shared ``update`` signature (confirm against callers).
     """
     log_scalar = SummaryWriterContext.add_scalar
     log_scalar(self.logging_key, value)
示例#13
0
 def test_with_none(self):
     """With a ``None`` writer in the context, logging returns ``None``."""
     with summary_writer_context(None):
         result = SummaryWriterContext.add_scalar("test", torch.ones(1))
         self.assertIsNone(result)
示例#14
0
 def test_noop(self):
     """With no writer context active, logging returns ``None``."""
     result = SummaryWriterContext.add_scalar("test", torch.ones(1))
     self.assertIsNone(result)