def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
    """Advance the per-batch step generator by one yield and return the result.

    A generator is created from ``self.train_step_gen(batch, batch_idx)`` when
    none is active, and each call pulls one value from it. On the call for the
    last optimizer (``optimizer_idx == self._num_optimizing_steps - 1``) the
    generator is dropped and the global summary step is increased.

    Args:
        batch: The training batch. A ``dict`` is converted with
            ``self._training_batch_type.from_dict`` when a batch type is set.
        batch_idx: Index of the batch, forwarded to ``train_step_gen``.
        optimizer_idx: Index of the optimizer this call is made for.

    Returns:
        The value yielded by the step generator for this call.

    Raises:
        RuntimeError: If the generator yields more values than there are
            optimizing steps (checked until one batch has been verified).
    """
    assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)

    if optimizer_idx == 0:
        # Batch counters advance once per batch, not once per optimizer.
        self.batches_processed_this_epoch += 1
        self.all_batches_processed += 1

    if self._training_step_generator is None:
        if self._training_batch_type and isinstance(batch, dict):
            batch = self._training_batch_type.from_dict(batch)
        self._training_step_generator = self.train_step_gen(batch, batch_idx)

    ret = next(self._training_step_generator)

    if optimizer_idx == self._num_optimizing_steps - 1:
        if not self._verified_steps:
            # The generator must be exhausted now; one extra yield means the
            # number of yields does not match the number of optimizers.
            try:
                next(self._training_step_generator)
            except StopIteration:
                self._verified_steps = True
            else:
                raise RuntimeError(
                    "train_step_gen() yields too many times. "
                    "The number of yields should match the number of optimizers,"
                    f" in this case {self._num_optimizing_steps}"
                )
        self._training_step_generator = None
        SummaryWriterContext.increase_global_step()

    return ret
def test_add_custom_scalars(self):
    """Check multiline-chart registration and the layout handed to the writer.

    Registering the same title twice in one category must raise an
    AssertionError; the rejected chart must not show up in the final layout.
    """
    expected_layout = {
        "cat": {
            "title": ["Multiline", ["a", "b"]],
            "title2": ["Multiline", ["e", "f"]],
        },
        "cat2": {
            "title": ["Multiline", ["g", "h"]]
        },
    }
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        tb_writer.add_custom_scalars = MagicMock()
        with summary_writer_context(tb_writer):
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["a", "b"], category="cat", title="title")
            # Same (category, title) pair again: expected to be rejected.
            duplicate_error = self.assertRaisesRegex(
                AssertionError,
                "Title \\(title\\) is already in category \\(cat\\)")
            with duplicate_error:
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["c", "d"], category="cat", title="title")
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["e", "f"], category="cat", title="title2")
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["g", "h"], category="cat2", title="title")

        SummaryWriterContext.add_custom_scalars(tb_writer)
        tb_writer.add_custom_scalars.assert_called_once_with(expected_layout)
def log_to_tensorboard(self, metric_name: str) -> None:
    """Write the normalized value of each estimate as a scalar under ``CPE/<metric_name>/``.

    Values that are ``None`` or NaN are written as ``0.0``.
    """
    self.check_estimates_exist()

    # All estimates are read before the first scalar is written.
    tagged_estimates = [
        (
            "CPE/{}/Direct_Method_Reward".format(metric_name),
            # pyre-fixme[16]: `Optional` has no attribute `normalized`.
            self.direct_method.normalized,
        ),
        (
            "CPE/{}/IPS_Reward".format(metric_name),
            self.inverse_propensity.normalized,
        ),
        (
            "CPE/{}/Doubly_Robust_Reward".format(metric_name),
            self.doubly_robust.normalized,
        ),
        (
            "CPE/{}/Sequential_Doubly_Robust".format(metric_name),
            self.sequential_doubly_robust.normalized,
        ),
        (
            "CPE/{}/Weighted_Sequential_Doubly_Robust".format(metric_name),
            self.weighted_doubly_robust.normalized,
        ),
        ("CPE/{}/MAGIC".format(metric_name), self.magic.normalized),
    ]

    for tag, estimate in tagged_estimates:
        unusable = estimate is None or math.isnan(estimate)
        SummaryWriterContext.add_scalar(tag, 0.0 if unusable else estimate)
def test_swallowing_exception(self):
    """An error type listed in ``exceptions_to_ignore`` must not escape ``add_scalar``.

    The test passes when the call below completes without raising.
    """
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        failing_add_scalar = MagicMock(side_effect=NotImplementedError("test"))
        tb_writer.add_scalar = failing_add_scalar
        tb_writer.exceptions_to_ignore = (NotImplementedError, KeyError)
        with summary_writer_context(tb_writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
def test_not_swallowing_exception(self):
    """Without ``exceptions_to_ignore`` set here, the writer's error must propagate."""
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        tb_writer.add_scalar = MagicMock(
            side_effect=NotImplementedError("test"))
        with self.assertRaisesRegex(NotImplementedError, "test"):
            with summary_writer_context(tb_writer):
                SummaryWriterContext.add_scalar("test", torch.ones(1))
def test_writing(self):
    """A scalar added inside the context reaches the writer once, with ``global_step=0``."""
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        recorded = MagicMock()
        tb_writer.add_scalar = recorded
        with summary_writer_context(tb_writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
        recorded.assert_called_once_with("test", torch.ones(1), global_step=0)
def __init__(self, key: str, title: str, actions: List[str]):
    """Store the action names and register one multiline chart for them.

    Args:
        key: Forwarded to the parent initializer.
        title: Chart title; also the suffix of the ``actions/<title>`` log key.
        actions: Action names; each gets a ``<log_key>/<action>`` series.
    """
    super().__init__(key)
    self.log_key = f"actions/{title}"
    self.actions = actions
    series_tags = [f"{self.log_key}/{action_name}" for action_name in actions]
    SummaryWriterContext.add_custom_scalars_multilinechart(
        series_tags,
        category="actions",
        title=title,
    )
def _log_histogram_and_mean(log_key, val):
    """Log ``val`` as a histogram under ``log_key`` and its mean under ``<log_key>/mean``.

    A ``ValueError`` raised while logging is reported with a warning and then
    re-raised.
    """
    try:
        SummaryWriterContext.add_histogram(log_key, val)
        mean_key = f"{log_key}/mean"
        SummaryWriterContext.add_scalar(mean_key, val.mean())
    except ValueError:
        warning_text = (
            f"Cannot create histogram for key: {log_key}; "
            "this is likely because you have NULL value in your input; "
            f"value: {val}")
        logger.warning(warning_text)
        raise
def log_to_tensorboard(self, epoch: int) -> None:
    """Write the recent TD, reward and imitator losses as scalars at step ``epoch``.

    A loss that is ``None`` or NaN is written as ``0.0``.
    """
    # All three losses are fetched before the first scalar is written.
    recent_losses = [
        ("Training/td_loss", self.get_recent_td_loss()),
        ("Training/reward_loss", self.get_recent_reward_loss()),
        ("Training/imitator_loss", self.get_recent_imitator_loss()),
    ]
    for tag, loss in recent_losses:
        unusable = loss is None or math.isnan(loss)
        SummaryWriterContext.add_scalar(tag, 0.0 if unusable else loss, epoch)
def write_summary(self, actions: List[str]):
    """Log the stored fields to the summary writer; fields that are ``None`` are skipped.

    * With a non-empty ``actions``: per-action counts of ``logged_actions`` and
      ``model_action_idxs`` (entries equal to the action's index).
    * Losses and logged quantities: histogram + mean; the value must have
      shape ``(N,)`` or ``(N, 1)``.
    * Model outputs: a single histogram + mean when the value has shape
      ``(N,)`` / ``(N, 1)`` and ``actions`` is empty, or one per action when
      the second dimension equals ``len(actions)``.

    Raises:
        ValueError: If a model output has neither of the accepted shapes.
    """
    action_count_fields = [
        ("logged_actions", "actions/logged"),
        ("model_action_idxs", "actions/model"),
    ]
    single_column_fields = [
        ("td_loss", "td_loss"),
        ("imitator_loss", "imitator_loss"),
        ("reward_loss", "reward_loss"),
        ("logged_propensities", "propensities/logged"),
        ("logged_rewards", "reward/logged"),
        ("logged_values", "value/logged"),
        ("model_values_on_logged_actions", "value/model_logged_action"),
    ]
    model_output_fields = [
        ("model_propensities", "propensities/model"),
        ("model_rewards", "reward/model"),
        ("model_values", "value/model"),
    ]

    if actions:
        for attr_name, prefix in action_count_fields:
            stored = getattr(self, attr_name)
            if stored is not None:
                for action_idx, action_name in enumerate(actions):
                    SummaryWriterContext.add_scalar(
                        "{}/{}".format(prefix, action_name),
                        (stored == action_idx).sum().item(),
                    )

    for attr_name, prefix in single_column_fields:
        stored = getattr(self, attr_name)
        if stored is None:
            continue
        shape = stored.shape
        is_single_column = len(shape) == 1 or (
            len(shape) == 2 and shape[1] == 1
        )
        assert is_single_column, "Unexpected shape for {}: {}".format(
            attr_name, shape
        )
        self._log_histogram_and_mean(prefix, stored)

    for attr_name, prefix in model_output_fields:
        stored = getattr(self, attr_name)
        if stored is None:
            continue
        shape = stored.shape
        is_single_column = len(shape) == 1 or (
            len(shape) == 2 and shape[1] == 1
        )
        if is_single_column and not actions:
            self._log_histogram_and_mean(prefix, stored)
        elif len(shape) == 2 and shape[1] == len(actions):
            # One column per action: log each column on its own key.
            for action_idx, action_name in enumerate(actions):
                self._log_histogram_and_mean(
                    f"{prefix}/{action_name}", stored[:, action_idx]
                )
        else:
            raise ValueError(
                "Unexpected shape for {}: {}; actions: {}".format(
                    attr_name, shape, actions
                )
            )
def get_log_prob(self, state: rlt.FeatureData, squashed_action: torch.Tensor):
    """
    Compute the log-probability of an action that was already squashed with tanh.

    The action is un-squashed with ``atanh``, standardized with the location
    and log-scale obtained for ``state``, and the squash correction is
    subtracted. The result is summed over dim 1 and reshaped to ``(-1, 1)``.
    Histograms are logged whenever the global step is a multiple of 1000.
    """
    if self.use_l2_normalization:
        # TODO: calculate log_prob for l2 normalization
        # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
        # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
        pass

    loc, scale_log = self._get_loc_and_scale_log(state)
    raw_action = torch.atanh(squashed_action)
    standardized = (raw_action - loc) / scale_log.exp()
    log_prob = self._normal_log_prob(standardized, scale_log)
    squash_correction = self._squash_correction(squashed_action)

    if SummaryWriterContext._global_step % 1000 == 0:
        histograms = (
            ("actor/get_log_prob/loc", loc),
            ("actor/get_log_prob/scale_log", scale_log),
            ("actor/get_log_prob/log_prob", log_prob),
            ("actor/get_log_prob/squash_correction", squash_correction),
        )
        for tag, tensor in histograms:
            SummaryWriterContext.add_histogram(tag, tensor.detach().cpu())

    corrected = log_prob - squash_correction
    return torch.sum(corrected, dim=1).reshape(-1, 1)
def add_custom_scalars(action_names: Optional[List[str]]):
    """Register four multiline charts with one series per action name.

    The charts are model/logged propensity means and logged/model action
    series. Nothing is registered when ``action_names`` is ``None`` or empty.
    """
    if not action_names:
        return

    # (series tag template, chart category, chart title), in registration order.
    chart_specs = (
        ("propensities/model/{}/mean", "propensities", "model"),
        ("propensities/logged/{}/mean", "propensities", "logged"),
        ("actions/logged/{}", "actions", "logged"),
        ("actions/model/{}", "actions", "model"),
    )
    for tag_template, category, title in chart_specs:
        series_tags = [
            tag_template.format(action_name) for action_name in action_names
        ]
        SummaryWriterContext.add_custom_scalars_multilinechart(
            series_tags,
            category=category,
            title=title,
        )
def forward(self, state: rlt.FeatureData):
    """Sample an action for ``state`` and return it with its log-probability.

    Standard-normal noise is scaled by ``exp(scale_log)`` and shifted by
    ``loc``, then squashed. The squashed ``loc`` is returned as
    ``squashed_mean``. Histograms of ``loc`` and ``scale_log`` are logged
    whenever the global step is a multiple of 1000.
    """
    loc, scale_log = self._get_loc_and_scale_log(state)
    noise = torch.randn_like(scale_log, device=scale_log.device)
    raw_action = loc + noise * scale_log.exp()
    squashed_action = self._squash_raw_action(raw_action)
    squashed_loc = self._squash_raw_action(loc)

    if SummaryWriterContext._global_step % 1000 == 0:
        for tag, tensor in (
            ("actor/forward/loc", loc),
            ("actor/forward/scale_log", scale_log),
        ):
            SummaryWriterContext.add_histogram(tag, tensor.detach().cpu())

    action_log_prob = self.get_log_prob(state, squashed_action)
    return rlt.ActorOutput(
        action=squashed_action,
        log_prob=action_log_prob,
        squashed_mean=squashed_loc,
    )
def test_writing_stack(self):
    """Nested writer contexts route each scalar to the innermost active writer."""
    with TemporaryDirectory() as outer_dir:
        with TemporaryDirectory() as inner_dir:
            outer_writer = SummaryWriter(outer_dir)
            outer_writer.add_scalar = MagicMock()
            inner_writer = SummaryWriter(inner_dir)
            inner_writer.add_scalar = MagicMock()

            with summary_writer_context(outer_writer):
                with summary_writer_context(inner_writer):
                    SummaryWriterContext.add_scalar("test2", torch.ones(1))
                # Back in the outer context only.
                SummaryWriterContext.add_scalar("test1", torch.zeros(1))

            outer_writer.add_scalar.assert_called_once_with(
                "test1", torch.zeros(1), global_step=0)
            inner_writer.add_scalar.assert_called_once_with(
                "test2", torch.ones(1), global_step=0)
def __init__(
    self,
    key: str,
    category: str,
    title: str,
    actions: List[str],
    log_key_prefix: Optional[str] = None,
):
    """Store the action names and register a multiline chart of per-action means.

    Args:
        key: Forwarded to the parent initializer.
        category: Chart category.
        title: Chart title.
        actions: Action names; each gets a ``<prefix>/<action>/mean`` series.
        log_key_prefix: Series prefix; falls back to ``<category>/<title>``
            when ``None`` or empty.
    """
    super().__init__(key)
    if log_key_prefix:
        self.log_key_prefix = log_key_prefix
    else:
        self.log_key_prefix = f"{category}/{title}"
    self.actions = actions
    mean_tags = [
        f"{self.log_key_prefix}/{action_name}/mean" for action_name in actions
    ]
    SummaryWriterContext.add_custom_scalars_multilinechart(
        mean_tags,
        category=category,
        title=title,
    )
def run_episode(env: EnvWrapper,
                agent: Agent,
                mdp_id: int = 0,
                max_steps: Optional[int] = None) -> Trajectory:
    """
    Run one episode of ``agent`` in ``env`` and return the collected Trajectory.

    Each step is recorded as a Transition tagged with ``mdp_id`` and the step
    number, passed to ``agent.post_step``, and appended to the trajectory; the
    global summary step is increased once per step. If ``max_steps`` is given,
    the step that reaches it is marked terminal. ``agent.post_episode`` is
    called with the trajectory before it is returned.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    num_steps = 0

    while True:
        action, log_prob = agent.act(obs, possible_actions_mask)
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask

        step_limit_hit = max_steps is not None and num_steps >= (max_steps - 1)
        if step_limit_hit:
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()

        obs, possible_actions_mask = next_obs, next_possible_actions_mask
        num_steps += 1
        if terminal:
            break

    agent.post_episode(trajectory)
    return trajectory
def test_global_step(self):
    """Scalars carry the global step in effect when they were added (0, then 1)."""
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        recorded = MagicMock()
        tb_writer.add_scalar = recorded
        with summary_writer_context(tb_writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
            SummaryWriterContext.increase_global_step()
            SummaryWriterContext.add_scalar("test", torch.zeros(1))

        expected_calls = [
            call("test", torch.ones(1), global_step=0),
            call("test", torch.zeros(1), global_step=1),
        ]
        recorded.assert_has_calls(expected_calls)
        self.assertEqual(2, len(recorded.mock_calls))
def _log_histogram_and_mean(name, key, x):
    """Log ``x`` as a histogram and its mean as a scalar under ``dueling_network/<name>/``.

    ``x`` is detached and moved to CPU before being handed to the writer.
    """
    detached = x.detach()
    histogram_tag = f"dueling_network/{name}/{key}"
    SummaryWriterContext.add_histogram(histogram_tag, detached.cpu())
    mean_tag = f"dueling_network/{name}/mean_{key}"
    SummaryWriterContext.add_scalar(mean_tag, detached.mean().cpu())
def tearDown(self):
    """Reset SummaryWriterContext's global state after each test."""
    SummaryWriterContext._reset_globals()
def setUp(self):
    """Reset the writer context, set INFO logging and seed numpy, torch and random with SEED."""
    SummaryWriterContext._reset_globals()
    logging.basicConfig(level=logging.INFO)
    # Same seed for every RNG the tests may touch, in a fixed order.
    for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
        seed_rng(SEED)
def aggregate(self, values):
    """Log, for each action, how many entries of ``values`` equal that action's index."""
    for action_idx, action_name in enumerate(self.actions):
        tag = f"{self.log_key}/{action_name}"
        occurrences = (values == action_idx).sum().item()
        SummaryWriterContext.add_scalar(tag, occurrences)
def test_noop(self):
    """``add_scalar`` outside any writer context returns ``None``."""
    outcome = SummaryWriterContext.add_scalar("test", torch.ones(1))
    self.assertIsNone(outcome)
def test_with_none(self):
    """``add_scalar`` inside a context opened with ``None`` returns ``None``."""
    with summary_writer_context(None):
        outcome = SummaryWriterContext.add_scalar("test", torch.ones(1))
        self.assertIsNone(outcome)
def update(self, key: str, value):
    """Log ``value`` as a scalar under ``self.logging_key``.

    Note: ``key`` is not used here; the scalar is always written under
    ``self.logging_key``.
    """
    SummaryWriterContext.add_scalar(self.logging_key, value)
def test_swallowing_histogram_value_error(self):
    """``add_histogram`` on a constant tensor must complete without raising.

    NOTE(review): per the test name, this input presumably triggers a
    ValueError inside the writer that the context is expected to swallow.
    """
    with TemporaryDirectory() as log_dir:
        tb_writer = SummaryWriter(log_dir)
        constant_values = torch.ones(100, 1)
        with summary_writer_context(tb_writer):
            SummaryWriterContext.add_histogram("bad_histogram", constant_values)