def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
    assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)

    if optimizer_idx == 0:
        # Count batches once per batch, not once per optimizer pass.
        self.batches_processed_this_epoch += 1
        self.all_batches_processed += 1

    if self._training_step_generator is None:
        if self._training_batch_type and isinstance(batch, dict):
            batch = self._training_batch_type.from_dict(batch)
        # Lazily create the generator on the first optimizer pass; it yields
        # one loss per optimizer.
        self._training_step_generator = self.train_step_gen(batch, batch_idx)

    ret = next(self._training_step_generator)

    if optimizer_idx == self._num_optimizing_steps - 1:
        if not self._verified_steps:
            # Verify (once) that the generator is exhausted after the last
            # optimizer pass.
            try:
                next(self._training_step_generator)
            except StopIteration:
                self._verified_steps = True
            if not self._verified_steps:
                raise RuntimeError(
                    "train_step_gen() yields too many times. "
                    "The number of yields should match the number of optimizers,"
                    f" in this case {self._num_optimizing_steps}"
                )
        self._training_step_generator = None
        SummaryWriterContext.increase_global_step()

    return ret
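# Sketch (assumption, not repo code): a minimal trainer showing the generator
# contract training_step() relies on -- train_step_gen() must yield exactly one
# loss per optimizer. TwoOptimizerTrainer and the batch keys are hypothetical.
class TwoOptimizerTrainer:
    _num_optimizing_steps = 2  # one yield per optimizer

    def train_step_gen(self, batch, batch_idx: int):
        # Yield #1: loss stepped by optimizer 0 (e.g. a critic loss).
        yield (batch["value"] - batch["target"]).pow(2).mean()
        # Yield #2 (the last one): loss stepped by optimizer 1 (e.g. an actor loss).
        yield -batch["log_prob"].mean()
        # No further yields: training_step() checks that the generator is
        # exhausted after exactly _num_optimizing_steps yields.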
def test_global_step(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock()
        with summary_writer_context(writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
            SummaryWriterContext.increase_global_step()
            SummaryWriterContext.add_scalar("test", torch.zeros(1))
        # Each scalar is logged with the global step that was current at call time.
        writer.add_scalar.assert_has_calls(
            [
                call("test", torch.ones(1), global_step=0),
                call("test", torch.zeros(1), global_step=1),
            ]
        )
        self.assertEqual(2, len(writer.add_scalar.mock_calls))
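# Illustration only (assumption, not the repo's implementation): a stripped-down
# stand-in for SummaryWriterContext / summary_writer_context showing the behavior
# the test asserts -- scalars are forwarded to the active writer with the current
# global step attached, and the step only advances on increase_global_step().
from contextlib import contextmanager


class _StepCountingContext:
    _writer = None
    _global_step = 0

    @classmethod
    def add_scalar(cls, tag, value):
        if cls._writer is not None:
            cls._writer.add_scalar(tag, value, global_step=cls._global_step)

    @classmethod
    def increase_global_step(cls):
        cls._global_step += 1


@contextmanager
def _writer_context(writer):
    # Activate the writer for the duration of the block, starting at step 0.
    _StepCountingContext._writer = writer
    _StepCountingContext._global_step = 0
    try:
        yield writer
    finally:
        _StepCountingContext._writer = None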
def run_episode(
    env: EnvWrapper, agent: Agent, mdp_id: int = 0, max_steps: Optional[int] = None
) -> Trajectory:
    """
    Run one episode and return the resulting Trajectory.
    After max_steps (if specified), the environment is treated as terminal.
    The mdp_id of the episode can also be specified.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= (max_steps - 1):
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()

        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
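# Hypothetical usage sketch (evaluate() is not repo code): averaging undiscounted
# returns over several rollouts of run_episode(). Assumes Trajectory exposes its
# transitions as `trajectory.transitions`; if the attribute differs, summing the
# transition rewards another way is equivalent.
def evaluate(env: EnvWrapper, agent: Agent, num_episodes: int = 10) -> float:
    returns = []
    for i in range(num_episodes):
        trajectory = run_episode(env, agent, mdp_id=i, max_steps=200)
        returns.append(sum(t.reward for t in trajectory.transitions))
    return sum(returns) / len(returns)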