def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
    assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)

    # Count each batch once, on the first optimizer pass.
    if optimizer_idx == 0:
        self.batches_processed_this_epoch += 1
        self.all_batches_processed += 1

    # Lazily create the training-step generator for this batch; it is expected
    # to yield once per optimizer.
    if self._training_step_generator is None:
        if self._training_batch_type and isinstance(batch, dict):
            batch = self._training_batch_type.from_dict(batch)
        self._training_step_generator = self.train_step_gen(batch, batch_idx)

    ret = next(self._training_step_generator)

    if optimizer_idx == self._num_optimizing_steps - 1:
        # On the last optimizer pass, verify (once) that the generator is
        # exhausted, then reset it and advance the global step counter.
        if not self._verified_steps:
            try:
                next(self._training_step_generator)
            except StopIteration:
                self._verified_steps = True
            if not self._verified_steps:
                raise RuntimeError(
                    "training_step_gen() yields too many times. "
                    "The number of yields should match the number of optimizers,"
                    f" in this case {self._num_optimizing_steps}"
                )
        self._training_step_generator = None
        SummaryWriterContext.increase_global_step()

    return ret
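For context, training_step above drives a train_step_gen generator, drawing one yielded loss per optimizer and expecting exactly _num_optimizing_steps yields before resetting. A minimal, self-contained sketch of that consumption pattern (the two placeholder losses and the two-optimizer count are assumptions for illustration, not ReAgent's actual trainer code):

from typing import Iterator

import torch


def train_step_gen(batch: torch.Tensor) -> Iterator[torch.Tensor]:
    # Hypothetical generator: one yield per optimizer, mirroring
    # _num_optimizing_steps == 2 in the method above.
    critic_loss = (batch ** 2).mean()  # placeholder loss, illustration only
    yield critic_loss
    actor_loss = batch.abs().mean()    # placeholder loss, illustration only
    yield actor_loss


gen = train_step_gen(torch.randn(4))
for optimizer_idx in range(2):
    loss = next(gen)  # same next() call that training_step makes per optimizer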
Example #2
def test_global_step(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock()
        with summary_writer_context(writer):
            # The scalar logged before increase_global_step() lands at
            # global_step 0, the one logged after at global_step 1.
            SummaryWriterContext.add_scalar("test", torch.ones(1))
            SummaryWriterContext.increase_global_step()
            SummaryWriterContext.add_scalar("test", torch.zeros(1))
        writer.add_scalar.assert_has_calls([
            call("test", torch.ones(1), global_step=0),
            call("test", torch.zeros(1), global_step=1),
        ])
        self.assertEqual(2, len(writer.add_scalar.mock_calls))
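The same pattern outside a unit test might look like the sketch below, assuming SummaryWriterContext and summary_writer_context are importable from ReAgent's tensorboardX module (the import path is an assumption and has varied across ReAgent versions) and that writing to a real SummaryWriter directory is acceptable:

from torch.utils.tensorboard import SummaryWriter

# Import path is an assumption; it has differed between ReAgent versions.
from reagent.core.tensorboardX import SummaryWriterContext, summary_writer_context

writer = SummaryWriter("runs/global_step_demo")
with summary_writer_context(writer):
    for loss in (0.9, 0.7, 0.5):
        # Each scalar is logged at the current shared global step,
        # then the step is advanced.
        SummaryWriterContext.add_scalar("loss", loss)
        SummaryWriterContext.increase_global_step()
writer.close()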
Example #3
def run_episode(env: EnvWrapper,
                agent: Agent,
                mdp_id: int = 0,
                max_steps: Optional[int] = None) -> Trajectory:
    """
    Return sum of rewards from episode.
    After max_steps (if specified), the environment is assumed to be terminal.
    Can also specify the mdp_id and gamma of episode.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= (max_steps - 1):
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
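        # Advance SummaryWriterContext's shared global step once per environment step.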
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
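A possible way to call run_episode, assuming ReAgent's Gym environment wrapper, its random-policy helper, and Agent.create_for_env (these import paths and factory names are assumptions and may differ between ReAgent versions):

# Import paths and factories below are assumptions for illustration;
# they may differ across ReAgent versions.
from reagent.gym.agents.agent import Agent
from reagent.gym.envs import Gym
from reagent.gym.policies.random_policies import make_random_policy_for_env

env = Gym(env_name="CartPole-v1")
policy = make_random_policy_for_env(env)
agent = Agent.create_for_env(env, policy)

# increase_global_step() is called once per environment step inside run_episode.
trajectory = run_episode(env, agent, mdp_id=0, max_steps=200)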