def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
        assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)

        if optimizer_idx == 0:
            self.batches_processed_this_epoch += 1
            self.all_batches_processed += 1
        if self._training_step_generator is None:
            if self._training_batch_type and isinstance(batch, dict):
                batch = self._training_batch_type.from_dict(batch)
            self._training_step_generator = self.train_step_gen(
                batch, batch_idx)

        ret = next(self._training_step_generator)

        if optimizer_idx == self._num_optimizing_steps - 1:
            if not self._verified_steps:
                try:
                    next(self._training_step_generator)
                except StopIteration:
                    self._verified_steps = True
                if not self._verified_steps:
                    raise RuntimeError(
                        "training_step_gen() yields too many times. "
                        "The number of yields should match the number of optimizers,"
                        f" in this case {self._num_optimizing_steps}")
            self._training_step_generator = None
            SummaryWriterContext.increase_global_step()

        return ret
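
The training_step above drives self._training_step_generator, a generator that must yield exactly one result per optimizer (so _num_optimizing_steps yields in total) before being exhausted. A minimal sketch of a matching generator, assuming a two-optimizer critic/actor setup and hypothetical compute_critic_loss / compute_actor_loss helpers:

def train_step_gen(self, batch, batch_idx):
    # First yield is consumed on the call with optimizer_idx == 0.
    critic_loss = self.compute_critic_loss(batch)  # hypothetical helper
    yield critic_loss
    # Second (and last) yield is consumed when optimizer_idx == 1;
    # training_step then verifies the generator is exhausted.
    actor_loss = self.compute_actor_loss(batch)  # hypothetical helper
    yield actor_loss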
Example No. 2
    def test_add_custom_scalars(self):
        with TemporaryDirectory() as tmp_dir:
            writer = SummaryWriter(tmp_dir)
            writer.add_custom_scalars = MagicMock()
            with summary_writer_context(writer):
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["a", "b"], category="cat", title="title")
                with self.assertRaisesRegex(
                        AssertionError,
                        "Title \\(title\\) is already in category \\(cat\\)"):
                    SummaryWriterContext.add_custom_scalars_multilinechart(
                        ["c", "d"], category="cat", title="title")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["e", "f"], category="cat", title="title2")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["g", "h"], category="cat2", title="title")

            SummaryWriterContext.add_custom_scalars(writer)
            writer.add_custom_scalars.assert_called_once_with({
                "cat": {
                    "title": ["Multiline", ["a", "b"]],
                    "title2": ["Multiline", ["e", "f"]],
                },
                "cat2": {
                    "title": ["Multiline", ["g", "h"]]
                },
            })
Example No. 3
    def log_to_tensorboard(self, metric_name: str) -> None:
        self.check_estimates_exist()

        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            (
                "CPE/{}/Direct_Method_Reward".format(metric_name),
                # pyre-fixme[16]: `Optional` has no attribute `normalized`.
                self.direct_method.normalized,
            ),
            (
                "CPE/{}/IPS_Reward".format(metric_name),
                self.inverse_propensity.normalized,
            ),
            (
                "CPE/{}/Doubly_Robust_Reward".format(metric_name),
                self.doubly_robust.normalized,
            ),
            (
                "CPE/{}/Sequential_Doubly_Robust".format(metric_name),
                self.sequential_doubly_robust.normalized,
            ),
            (
                "CPE/{}/Weighted_Sequential_Doubly_Robust".format(metric_name),
                self.weighted_doubly_robust.normalized,
            ),
            ("CPE/{}/MAGIC".format(metric_name), self.magic.normalized),
        ]:
            SummaryWriterContext.add_scalar(name, none_to_zero(value))
Example No. 4
 def test_swallowing_exception(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         writer.exceptions_to_ignore = (NotImplementedError, KeyError)
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
Example No. 5
 def test_not_swallowing_exception(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         with self.assertRaisesRegex(
                 NotImplementedError,
                 "test"), summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
Example No. 6
 def test_writing(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock()
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
         writer.add_scalar.assert_called_once_with("test",
                                                   torch.ones(1),
                                                   global_step=0)
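
Outside the tests, the same context-manager pattern applies to a real writer. A minimal sketch, assuming the imports below (the reagent module path is an assumption; some versions expose these names under reagent.tensorboardX) and a caller-supplied log_dir:

from torch.utils.tensorboard import SummaryWriter
from reagent.core.tensorboardX import (  # assumed import path
    SummaryWriterContext,
    summary_writer_context,
)

def log_losses(log_dir, losses):
    writer = SummaryWriter(log_dir)
    # Optional: exceptions of these types raised by the underlying writer are
    # swallowed by SummaryWriterContext (see the swallowing-exception test above).
    writer.exceptions_to_ignore = (NotImplementedError, KeyError)
    with summary_writer_context(writer):
        for loss in losses:
            SummaryWriterContext.add_scalar("training/loss", loss)
            SummaryWriterContext.increase_global_step()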
Example No. 7
 def __init__(self, key: str, title: str, actions: List[str]):
     super().__init__(key)
     self.log_key = f"actions/{title}"
     self.actions = actions
     SummaryWriterContext.add_custom_scalars_multilinechart(
         [f"{self.log_key}/{action_name}" for action_name in actions],
         category="actions",
         title=title,
     )
Example No. 8
def _log_histogram_and_mean(log_key, val):
    try:
        SummaryWriterContext.add_histogram(log_key, val)
        SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
    except ValueError:
        logger.warning(
            f"Cannot create histogram for key: {log_key}; "
            "this is likely because you have a NULL value in your input; "
            f"value: {val}")
        raise
Example No. 9
    def log_to_tensorboard(self, epoch: int) -> None:
        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            ("Training/td_loss", self.get_recent_td_loss()),
            ("Training/reward_loss", self.get_recent_reward_loss()),
            ("Training/imitator_loss", self.get_recent_imitator_loss()),
        ]:
            SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch)
Example No. 10
    def write_summary(self, actions: List[str]):
        if actions:
            for field, log_key in [
                ("logged_actions", "actions/logged"),
                ("model_action_idxs", "actions/model"),
            ]:
                val = getattr(self, field)
                if val is None:
                    continue
                for i, action in enumerate(actions):
                    SummaryWriterContext.add_scalar(
                        "{}/{}".format(log_key, action), (val == i).sum().item()
                    )

        for field, log_key in [
            ("td_loss", "td_loss"),
            ("imitator_loss", "imitator_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
            ("logged_values", "value/logged"),
            ("model_values_on_logged_actions", "value/model_logged_action"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            assert len(val.shape) == 1 or (
                len(val.shape) == 2 and val.shape[1] == 1
            ), "Unexpected shape for {}: {}".format(field, val.shape)
            self._log_histogram_and_mean(log_key, val)

        for field, log_key in [
            ("model_propensities", "propensities/model"),
            ("model_rewards", "reward/model"),
            ("model_values", "value/model"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            if (
                len(val.shape) == 1 or (len(val.shape) == 2 and val.shape[1] == 1)
            ) and not actions:
                self._log_histogram_and_mean(log_key, val)
            elif len(val.shape) == 2 and val.shape[1] == len(actions):
                for i, action in enumerate(actions):
                    self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
            else:
                raise ValueError(
                    "Unexpected shape for {}: {}; actions: {}".format(
                        field, val.shape, actions
                    )
                )
Example No. 11
    def get_log_prob(self, state: rlt.FeatureData,
                     squashed_action: torch.Tensor):
        """
        Action is expected to be squashed with tanh
        """
        if self.use_l2_normalization:
            # TODO: calculate log_prob for l2 normalization
            # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
            # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
            pass

        loc, scale_log = self._get_loc_and_scale_log(state)
        raw_action = torch.atanh(squashed_action)
        r = (raw_action - loc) / scale_log.exp()
        log_prob = self._normal_log_prob(r, scale_log)
        squash_correction = self._squash_correction(squashed_action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/get_log_prob/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/scale_log",
                                               scale_log.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/log_prob",
                                               log_prob.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/get_log_prob/squash_correction",
                squash_correction.detach().cpu())
        return torch.sum(log_prob - squash_correction, dim=1).reshape(-1, 1)
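
For reference, the change-of-variables term that _squash_correction supplies for a tanh-squashed Gaussian is log|d tanh(u)/du| = log(1 - tanh(u)^2), evaluated at the squashed action. A minimal sketch, assuming that is what the helper computes; the epsilon guard is an added assumption for numerical stability near |action| = 1:

def _squash_correction(self, squashed_action, epsilon=1e-6):
    # log-derivative of tanh at the pre-squash action u, written in terms of
    # the squashed action a = tanh(u): log(1 - a^2).
    return torch.log(1 - squashed_action ** 2 + epsilon)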
Example No. 12
    def add_custom_scalars(action_names: Optional[List[str]]):
        if not action_names:
            return

        SummaryWriterContext.add_custom_scalars_multilinechart(
            [
                "propensities/model/{}/mean".format(action_name)
                for action_name in action_names
            ],
            category="propensities",
            title="model",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            [
                "propensities/logged/{}/mean".format(action_name)
                for action_name in action_names
            ],
            category="propensities",
            title="logged",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            ["actions/logged/{}".format(action_name) for action_name in action_names],
            category="actions",
            title="logged",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            ["actions/model/{}".format(action_name) for action_name in action_names],
            category="actions",
            title="model",
        )
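
These multiline-chart registrations only build up an in-memory layout; as in Example No. 2 above, the layout reaches TensorBoard once SummaryWriterContext.add_custom_scalars(writer) is called on a concrete writer. A minimal sketch, assuming hypothetical action names, a caller-supplied writer, and that the static helper above can be invoked directly:

add_custom_scalars(["up", "down"])  # register the per-action charts once
with summary_writer_context(writer):
    ...  # training/evaluation code logs the individual scalars here
SummaryWriterContext.add_custom_scalars(writer)  # flush the buffered layout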
Example No. 13
    def forward(self, state: rlt.FeatureData):
        loc, scale_log = self._get_loc_and_scale_log(state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        raw_action = loc + r * scale_log.exp()
        squashed_action = self._squash_raw_action(raw_action)
        squashed_loc = self._squash_raw_action(loc)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                               scale_log.detach().cpu())

        return rlt.ActorOutput(
            action=squashed_action,
            log_prob=self.get_log_prob(state, squashed_action),
            squashed_mean=squashed_loc,
        )
Example No. 14
 def test_writing_stack(self):
     with TemporaryDirectory() as tmp_dir1, TemporaryDirectory(
     ) as tmp_dir2:
         writer1 = SummaryWriter(tmp_dir1)
         writer1.add_scalar = MagicMock()
         writer2 = SummaryWriter(tmp_dir2)
         writer2.add_scalar = MagicMock()
         with summary_writer_context(writer1):
             with summary_writer_context(writer2):
                 SummaryWriterContext.add_scalar("test2", torch.ones(1))
             SummaryWriterContext.add_scalar("test1", torch.zeros(1))
         writer1.add_scalar.assert_called_once_with("test1",
                                                    torch.zeros(1),
                                                    global_step=0)
         writer2.add_scalar.assert_called_once_with("test2",
                                                    torch.ones(1),
                                                    global_step=0)
Example No. 15
 def __init__(
     self,
     key: str,
     category: str,
     title: str,
     actions: List[str],
     log_key_prefix: Optional[str] = None,
 ):
     super().__init__(key)
     self.log_key_prefix = log_key_prefix or f"{category}/{title}"
     self.actions = actions
     SummaryWriterContext.add_custom_scalars_multilinechart(
         [
             f"{self.log_key_prefix}/{action_name}/mean"
             for action_name in actions
         ],
         category=category,
         title=title,
     )
Example No. 16
def run_episode(env: EnvWrapper,
                agent: Agent,
                mdp_id: int = 0,
                max_steps: Optional[int] = None) -> Trajectory:
    """
    Return sum of rewards from episode.
    After max_steps (if specified), the environment is assumed to be terminal.
    Can also specify the mdp_id and gamma of episode.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= (max_steps - 1):
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
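
A minimal evaluation-loop sketch around run_episode; constructing env and agent is out of scope here, and the assumption that a Trajectory exposes the transitions appended via add_transition as a .transitions list is not guaranteed by the snippet above:

def average_return(env, agent, num_episodes=10, max_steps=200):
    returns = []
    for mdp_id in range(num_episodes):
        trajectory = run_episode(env, agent, mdp_id=mdp_id, max_steps=max_steps)
        # Assumed attribute: the transitions recorded by add_transition above.
        returns.append(sum(t.reward for t in trajectory.transitions))
    return sum(returns) / len(returns)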
Example No. 17
 def test_global_step(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock()
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
             SummaryWriterContext.increase_global_step()
             SummaryWriterContext.add_scalar("test", torch.zeros(1))
         writer.add_scalar.assert_has_calls([
             call("test", torch.ones(1), global_step=0),
             call("test", torch.zeros(1), global_step=1),
         ])
         self.assertEqual(2, len(writer.add_scalar.mock_calls))
Example No. 18
def _log_histogram_and_mean(name, key, x):
    SummaryWriterContext.add_histogram(f"dueling_network/{name}/{key}",
                                       x.detach().cpu())
    SummaryWriterContext.add_scalar(f"dueling_network/{name}/mean_{key}",
                                    x.detach().mean().cpu())
Example No. 19
 def tearDown(self):
     SummaryWriterContext._reset_globals()
Example No. 20
 def setUp(self):
     SummaryWriterContext._reset_globals()
     logging.basicConfig(level=logging.INFO)
     np.random.seed(SEED)
     torch.manual_seed(SEED)
     random.seed(SEED)
Example No. 21
 def aggregate(self, values):
     for i, action in enumerate(self.actions):
         SummaryWriterContext.add_scalar(f"{self.log_key}/{action}",
                                         (values == i).sum().item())
Example No. 22
 def test_noop(self):
     self.assertIsNone(
         SummaryWriterContext.add_scalar("test", torch.ones(1)))
Example No. 23
 def test_with_none(self):
     with summary_writer_context(None):
         self.assertIsNone(
             SummaryWriterContext.add_scalar("test", torch.ones(1)))
Example No. 24
 def update(self, key: str, value):
     SummaryWriterContext.add_scalar(self.logging_key, value)
Example No. 25
 def test_swallowing_histogram_value_error(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         with summary_writer_context(writer):
             SummaryWriterContext.add_histogram("bad_histogram",
                                                torch.ones(100, 1))