Code Example #1
File: gymrunner.py Project: lwzbuaa/ReAgent
def run_episode(env: Env,
                agent: Agent,
                mdp_id: int = 0,
                max_steps: Optional[int] = None) -> Trajectory:
    """
    Run a single episode and return its Trajectory.
    After max_steps (if specified), the environment is assumed to be terminal.
    The mdp_id of the episode can also be specified.
    """
    trajectory = Trajectory()
    obs = env.reset()
    terminal = False
    num_steps = 0
    while not terminal:
        action = agent.act(obs)
        next_obs, reward, terminal, _ = env.step(action)
        if max_steps is not None and num_steps >= max_steps:
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=reward,
            terminal=terminal,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        num_steps += 1
    return trajectory
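
The SummaryWriterContext.increase_global_step() call inside run_episode only has an effect while a SummaryWriter is active; the test snippets later in this section (e.g. Code Example #8 and #27) activate one with summary_writer_context. The sketch below ties the two together. It is not part of the ReAgent sources: the import path for summary_writer_context and the make_env/make_agent helpers are assumptions and may differ between ReAgent versions.

# Minimal sketch (assumed imports, hypothetical helpers) of running an
# episode with TensorBoard logging enabled via summary_writer_context.
from tempfile import TemporaryDirectory

from torch.utils.tensorboard import SummaryWriter

from reagent.core.tensorboardX import summary_writer_context  # path is an assumption

env = make_env()      # hypothetical helper returning an Env
agent = make_agent()  # hypothetical helper returning an Agent

with TemporaryDirectory() as tmp_dir:
    writer = SummaryWriter(tmp_dir)
    with summary_writer_context(writer):
        # Every increase_global_step() inside run_episode now advances the
        # global_step attached to subsequent add_scalar/add_histogram calls.
        trajectory = run_episode(env, agent, mdp_id=0, max_steps=200)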
Code Example #2
    def test_add_custom_scalars(self):
        with TemporaryDirectory() as tmp_dir:
            writer = SummaryWriter(tmp_dir)
            writer.add_custom_scalars = MagicMock()
            with summary_writer_context(writer):
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["a", "b"], category="cat", title="title")
                with self.assertRaisesRegex(
                        AssertionError,
                        "Title \\(title\\) is already in category \\(cat\\)"):
                    SummaryWriterContext.add_custom_scalars_multilinechart(
                        ["c", "d"], category="cat", title="title")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["e", "f"], category="cat", title="title2")
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["g", "h"], category="cat2", title="title")

            SummaryWriterContext.add_custom_scalars(writer)
            writer.add_custom_scalars.assert_called_once_with({
                "cat": {
                    "title": ["Multiline", ["a", "b"]],
                    "title2": ["Multiline", ["e", "f"]],
                },
                "cat2": {
                    "title": ["Multiline", ["g", "h"]]
                },
            })
Code Example #3
    def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
        assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)
        if self._training_step_generator is None:
            if self._training_batch_type and isinstance(batch, dict):
                batch = self._training_batch_type.from_dict(batch)
            self._training_step_generator = self.train_step_gen(
                batch, batch_idx)

        ret = next(self._training_step_generator)

        if optimizer_idx == self._num_optimizing_steps - 1:
            if not self._verified_steps:
                try:
                    next(self._training_step_generator)
                except StopIteration:
                    self._verified_steps = True
                if not self._verified_steps:
                    raise RuntimeError(
                        "training_step_gen() yields too many times."
                        "The number of yields should match the number of optimizers,"
                        f" in this case {self._num_optimizing_steps}")
            self._training_step_generator = None
            SummaryWriterContext.increase_global_step()

        return ret
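
training_step above pulls one value from a generator created by train_step_gen for each optimizer_idx, and after the last optimizer it verifies that the generator is exhausted, so the number of yields must equal self._num_optimizing_steps. The generator below is only a hypothetical illustration of that contract; the loss helpers are made up and not part of ReAgent.

    # Hypothetical train_step_gen() for a trainer with two optimizers:
    # it must yield exactly one loss per optimizer, in optimizer order.
    def train_step_gen(self, batch, batch_idx: int):
        critic_loss = self.compute_critic_loss(batch)  # hypothetical helper
        yield critic_loss                              # consumed at optimizer_idx == 0

        actor_loss = self.compute_actor_loss(batch)    # hypothetical helper
        yield actor_loss                               # consumed at optimizer_idx == 1
        # A third yield would trigger the RuntimeError raised in training_step().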
Code Example #4
File: cpe.py Project: wall-ed-coder/ReAgent
    def log_to_tensorboard(self, metric_name: str) -> None:
        self.check_estimates_exist()

        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            (
                "CPE/{}/Direct_Method_Reward".format(metric_name),
                # pyre-fixme[16]: `Optional` has no attribute `normalized`.
                self.direct_method.normalized,
            ),
            (
                "CPE/{}/IPS_Reward".format(metric_name),
                self.inverse_propensity.normalized,
            ),
            (
                "CPE/{}/Doubly_Robust_Reward".format(metric_name),
                self.doubly_robust.normalized,
            ),
            (
                "CPE/{}/Sequential_Doubly_Robust".format(metric_name),
                self.sequential_doubly_robust.normalized,
            ),
            (
                "CPE/{}/Weighted_Sequential_Doubly_Robust".format(metric_name),
                self.weighted_doubly_robust.normalized,
            ),
            ("CPE/{}/MAGIC".format(metric_name), self.magic.normalized),
        ]:
            SummaryWriterContext.add_scalar(name, none_to_zero(value))
Code Example #5
 def test_swallowing_exception(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         writer.exceptions_to_ignore = (NotImplementedError, KeyError)
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
Code Example #6
 def test_not_swallowing_exception(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock(
             side_effect=NotImplementedError("test"))
         with self.assertRaisesRegex(
                 NotImplementedError,
                 "test"), summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
Code Example #7
 def __init__(self, key: str, title: str, actions: List[str]):
     super().__init__(key)
     self.log_key = f"actions/{title}"
     self.actions = actions
     SummaryWriterContext.add_custom_scalars_multilinechart(
         [f"{self.log_key}/{action_name}" for action_name in actions],
         category="actions",
         title=title,
     )
Code Example #8
 def test_writing(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock()
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
         writer.add_scalar.assert_called_once_with("test",
                                                   torch.ones(1),
                                                   global_step=0)
Code Example #9
 def _log_histogram_and_mean(self, log_key, val):
     try:
         SummaryWriterContext.add_histogram(log_key, val)
         SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
     except ValueError:
         logger.warning(
             f"Cannot create histogram for key: {log_key}; "
             "this is likely because you have NULL value in your input; "
             f"value: {val}")
         raise
Code Example #10
    def write_summary(self, actions: List[str]):
        if actions:
            for field, log_key in [
                ("logged_actions", "actions/logged"),
                ("model_action_idxs", "actions/model"),
            ]:
                val = getattr(self, field)
                if val is None:
                    continue
                for i, action in enumerate(actions):
                    # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                    #  `add_scalar`.
                    SummaryWriterContext.add_scalar(
                        "{}/{}".format(log_key, action), (val == i).sum().item()
                    )

        for field, log_key in [
            ("td_loss", "td_loss"),
            ("imitator_loss", "imitator_loss"),
            ("reward_loss", "reward_loss"),
            ("logged_propensities", "propensities/logged"),
            ("logged_rewards", "reward/logged"),
            ("logged_values", "value/logged"),
            ("model_values_on_logged_actions", "value/model_logged_action"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            assert len(val.shape) == 1 or (
                len(val.shape) == 2 and val.shape[1] == 1
            ), "Unexpected shape for {}: {}".format(field, val.shape)
            self._log_histogram_and_mean(log_key, val)

        for field, log_key in [
            ("model_propensities", "propensities/model"),
            ("model_rewards", "reward/model"),
            ("model_values", "value/model"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            if (
                len(val.shape) == 1 or (len(val.shape) == 2 and val.shape[1] == 1)
            ) and not actions:
                self._log_histogram_and_mean(log_key, val)
            elif len(val.shape) == 2 and val.shape[1] == len(actions):
                for i, action in enumerate(actions):
                    self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
            else:
                raise ValueError(
                    "Unexpected shape for {}: {}; actions: {}".format(
                        field, val.shape, actions
                    )
                )
Code Example #11
    def __iter__(self):
        t = tqdm(total=self.dataloader_size, desc="iterating dataloader")
        for batch in self.dataloader:
            batch_size = get_batch_size(batch)
            yield batch
            t.update(batch_size)
            SummaryWriterContext.increase_global_step()

        # clean up if needed (e.g. for the Petastorm DataLoader)
        if hasattr(self.dataloader, "__exit__"):
            self.dataloader.__exit__(None, None, None)
Code Example #12
    def log_to_tensorboard(self, epoch: int) -> None:
        def none_to_zero(x: Optional[float]) -> float:
            if x is None or math.isnan(x):
                return 0.0
            return x

        for name, value in [
            ("Training/td_loss", self.get_recent_td_loss()),
            ("Training/reward_loss", self.get_recent_reward_loss()),
            ("Training/imitator_loss", self.get_recent_imitator_loss()),
        ]:
            SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch)
Code Example #13
async def async_run_episode(
    env: EnvWrapper,
    agent: Agent,
    mdp_id: int = 0,
    max_steps: Optional[int] = None,
    fill_info: bool = False,
) -> Trajectory:
    """
    NOTE: this function is an async coroutine in order to support async env.step(). If you are
        using a regular (synchronous) env.step(), use the non-async run_episode(), which wraps
        this function.
    Run a single episode and return its Trajectory.
    After max_steps (if specified), the environment is assumed to be terminal.
    The mdp_id of the episode can also be specified.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    step_is_coroutine = asyncio.iscoroutinefunction(env.step)
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        if step_is_coroutine:
            next_obs, reward, terminal, info = await env.step(action)
        else:
            next_obs, reward, terminal, info = env.step(action)
        if not fill_info:
            info = None
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= max_steps:
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
            info=info,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
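
The docstring notes that a non-async run_episode() wraps this coroutine for environments whose env.step() is a plain function. A minimal way to drive it synchronously, assuming no event loop is already running, is sketched below; the name run_episode_sync is illustrative and is not the actual ReAgent wrapper.

import asyncio

def run_episode_sync(env, agent, mdp_id=0, max_steps=None):
    # asyncio.run() creates an event loop, runs the coroutine above to
    # completion, and closes the loop again.
    return asyncio.run(
        async_run_episode(env, agent, mdp_id=mdp_id, max_steps=max_steps))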
Code Example #14
    def test_minibatches_per_step(self):
        _epochs = self.epochs
        self.epochs = 2
        rl_parameters = RLParameters(gamma=0.95,
                                     target_update_rate=0.9,
                                     maxq_learning=True)
        rainbow_parameters = RainbowDQNParameters(double_q_learning=True,
                                                  dueling_architecture=False)
        training_parameters1 = TrainingParameters(
            layers=self.layers,
            activations=self.activations,
            minibatch_size=1024,
            minibatches_per_step=1,
            learning_rate=0.25,
            optimizer="ADAM",
        )
        training_parameters2 = TrainingParameters(
            layers=self.layers,
            activations=self.activations,
            minibatch_size=128,
            minibatches_per_step=8,
            learning_rate=0.25,
            optimizer="ADAM",
        )
        env1 = Env(self.state_dims, self.action_dims)
        env2 = Env(self.state_dims, self.action_dims)
        model_parameters1 = DiscreteActionModelParameters(
            actions=env1.actions,
            rl=rl_parameters,
            rainbow=rainbow_parameters,
            training=training_parameters1,
        )
        model_parameters2 = DiscreteActionModelParameters(
            actions=env2.actions,
            rl=rl_parameters,
            rainbow=rainbow_parameters,
            training=training_parameters2,
        )
        # minibatch_size / 8, minibatches_per_step * 8 should give the same result
        logger.info("Training model 1")
        trainer1 = self._train(model_parameters1, env1)
        SummaryWriterContext._reset_globals()
        logger.info("Training model 2")
        trainer2 = self._train(model_parameters2, env2)

        weight1 = trainer1.q_network.fc.dnn[-2].weight.detach().numpy()
        weight2 = trainer2.q_network.fc.dnn[-2].weight.detach().numpy()

        # Due to numerical instability, this tolerance has to be fairly high
        self.assertTrue(np.allclose(weight1, weight2, rtol=0.0, atol=1e-3))
        self.epochs = _epochs
Code Example #15
    def add_custom_scalars(action_names: Optional[List[str]]):
        if not action_names:
            return

        SummaryWriterContext.add_custom_scalars_multilinechart(
            [
                "propensities/model/{}/mean".format(action_name)
                for action_name in action_names
            ],
            category="propensities",
            title="model",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            [
                "propensities/logged/{}/mean".format(action_name)
                for action_name in action_names
            ],
            category="propensities",
            title="logged",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            ["actions/logged/{}".format(action_name) for action_name in action_names],
            category="actions",
            title="logged",
        )
        SummaryWriterContext.add_custom_scalars_multilinechart(
            ["actions/model/{}".format(action_name) for action_name in action_names],
            category="actions",
            title="model",
        )
Code Example #16
File: actor.py Project: zhaonann/ReAgent
    def get_log_prob(self, state, squashed_action):
        """
        Action is expected to be squashed with tanh
        """
        loc, scale_log = self._get_loc_and_scale_log(state)
        # This is not getting exported; we can use it
        n = Normal(loc, scale_log.exp())
        raw_action = self._atanh(squashed_action)

        log_prob = n.log_prob(raw_action)
        squash_correction = self._squash_correction(squashed_action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/get_log_prob/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/scale_log",
                                               scale_log.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/log_prob",
                                               log_prob.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/get_log_prob/squash_correction",
                squash_correction.detach().cpu())
        log_prob = torch.sum(log_prob - squash_correction,
                             dim=1).reshape(-1, 1)

        return log_prob
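
The log_prob - squash_correction pattern in this and the surrounding actor snippets is the standard change-of-variables for a tanh-squashed Gaussian policy: if a = tanh(u) with u ~ Normal(loc, scale), then log p(a) = log N(u; loc, scale) - log(1 - tanh(u)^2), summed over action dimensions. The sketch below shows what _squash_correction plausibly computes; the epsilon and the exact ReAgent implementation are assumptions.

import torch

def tanh_squash_correction(squashed_action: torch.Tensor,
                           eps: float = 1e-6) -> torch.Tensor:
    # log|da/du| = log(1 - tanh(u)^2), evaluated at a = tanh(u);
    # eps guards against log(0) when |a| is numerically 1.
    return torch.log(1 - squashed_action ** 2 + eps)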
Code Example #17
File: actor.py Project: zrion/ReAgent
    def get_log_prob(self, state: rlt.FeatureData,
                     squashed_action: torch.Tensor):
        """
        Action is expected to be squashed with tanh
        """
        if self.use_l2_normalization:
            # TODO: calculate log_prob for l2 normalization
            # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
            # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
            pass

        loc, scale_log = self._get_loc_and_scale_log(state)
        raw_action = torch.atanh(squashed_action)
        r = (raw_action - loc) / scale_log.exp()
        log_prob = self._normal_log_prob(r, scale_log)
        squash_correction = self._squash_correction(squashed_action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/get_log_prob/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/scale_log",
                                               scale_log.detach().cpu())
            SummaryWriterContext.add_histogram("actor/get_log_prob/log_prob",
                                               log_prob.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/get_log_prob/squash_correction",
                squash_correction.detach().cpu())
        return torch.sum(log_prob - squash_correction, dim=1).reshape(-1, 1)
Code Example #18
File: actor.py Project: zwcdp/ReAgent
    def forward(self, input):
        loc, scale_log = self._get_loc_and_scale_log(input.state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        action = torch.tanh(loc + r * scale_log.exp())
        if not self.training:
            # ONNX doesn't like reshape either..
            return rlt.ActorOutput(action=action)
        # Since each dim is independent, the log-prob is simply the sum
        log_prob = self._log_prob(r, scale_log)
        squash_correction = self._squash_correction(action)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
            SummaryWriterContext.add_histogram(
                "actor/forward/scale_log", scale_log.detach().cpu()
            )
            SummaryWriterContext.add_histogram(
                "actor/forward/log_prob", log_prob.detach().cpu()
            )
            SummaryWriterContext.add_histogram(
                "actor/forward/squash_correction", squash_correction.detach().cpu()
            )
        log_prob = torch.sum(log_prob - squash_correction, dim=1)

        return rlt.ActorOutput(
            action=action, log_prob=log_prob.reshape(-1, 1), action_mean=loc
        )
Code Example #19
 def __init__(
     self,
     key: str,
     category: str,
     title: str,
     actions: List[str],
     log_key_prefix: Optional[str] = None,
 ):
     super().__init__(key)
     self.log_key_prefix = log_key_prefix or f"{category}/{title}"
     self.actions = actions
     SummaryWriterContext.add_custom_scalars_multilinechart(
         [f"{self.log_key_prefix}/{action_name}/mean" for action_name in actions],
         category=category,
         title=title,
     )
Code Example #20
File: actor.py Project: zrion/ReAgent
    def forward(self, state: rlt.FeatureData):
        loc, scale_log = self._get_loc_and_scale_log(state)
        r = torch.randn_like(scale_log, device=scale_log.device)
        raw_action = loc + r * scale_log.exp()
        squashed_action = self._squash_raw_action(raw_action)
        squashed_loc = self._squash_raw_action(loc)
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram("actor/forward/loc",
                                               loc.detach().cpu())
            SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                               scale_log.detach().cpu())

        return rlt.ActorOutput(
            action=squashed_action,
            log_prob=self.get_log_prob(state, squashed_action),
            squashed_mean=squashed_loc,
        )
Code Example #21
 def test_writing_stack(self):
     with TemporaryDirectory() as tmp_dir1, TemporaryDirectory() as tmp_dir2:
         writer1 = SummaryWriter(tmp_dir1)
         writer1.add_scalar = MagicMock()
         writer2 = SummaryWriter(tmp_dir2)
         writer2.add_scalar = MagicMock()
         with summary_writer_context(writer1):
             with summary_writer_context(writer2):
                 SummaryWriterContext.add_scalar("test2", torch.ones(1))
             SummaryWriterContext.add_scalar("test1", torch.zeros(1))
         writer1.add_scalar.assert_called_once_with("test1",
                                                    torch.zeros(1),
                                                    global_step=0)
         writer2.add_scalar.assert_called_once_with("test2",
                                                    torch.ones(1),
                                                    global_step=0)
Code Example #22
File: base_workflow.py Project: zwcdp/ReAgent
    def train_network(self, train_dataset, eval_dataset, epochs: int):
        num_batches = int(len(train_dataset) / self.minibatch_size)
        logger.info("Read in batch data set of size {} examples. Data split "
                    "into {} batches of size {}.".format(
                        len(train_dataset), num_batches, self.minibatch_size))

        start_time = time.time()
        for epoch in range(epochs):
            train_dataset.reset_iterator()
            data_streamer = DataStreamer(train_dataset,
                                         pin_memory=self.trainer.use_gpu)

            feed_pages(
                data_streamer,
                len(train_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                TrainingPageHandler(self.trainer),
                batch_preprocessor=self.batch_preprocessor,
            )

            if hasattr(self.trainer, "q_network_cpe"):
                # TODO: Add CPE support to SAC
                eval_dataset.reset_iterator()
                data_streamer = DataStreamer(eval_dataset,
                                             pin_memory=self.trainer.use_gpu)
                eval_page_handler = EvaluationPageHandler(
                    self.trainer, self.evaluator, self)
                feed_pages(
                    data_streamer,
                    len(eval_dataset),
                    epoch,
                    self.minibatch_size,
                    self.trainer.use_gpu,
                    eval_page_handler,
                    batch_preprocessor=self.batch_preprocessor,
                )

                SummaryWriterContext.increase_global_step()

        through_put = (len(train_dataset) * epochs) / (time.time() -
                                                       start_time)
        logger.info("Training finished. Processed ~{} examples / s.".format(
            round(through_put)))
Code Example #23
    def training_step(self, batch, batch_idx: int, optimizer_idx: int):
        if self._training_step_generator is None:
            self._training_step_generator = self.train_step_gen(batch, batch_idx)

        ret = next(self._training_step_generator)

        if optimizer_idx == self._num_optimizing_steps - 1:
            if not self._verified_steps:
                try:
                    next(self._training_step_generator)
                except StopIteration:
                    self._verified_steps = True
                if not self._verified_steps:
                    raise RuntimeError("training_step_gen() yields too many times")
            self._training_step_generator = None
            SummaryWriterContext.increase_global_step()

        return ret
Code Example #24
def run_episode(env: EnvWrapper,
                agent: Agent,
                mdp_id: int = 0,
                max_steps: Optional[int] = None) -> Trajectory:
    """
    Run a single episode and return its Trajectory.
    After max_steps (if specified), the environment is assumed to be terminal.
    The mdp_id of the episode can also be specified.
    """
    trajectory = Trajectory()
    # pyre-fixme[16]: `EnvWrapper` has no attribute `reset`.
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        # pyre-fixme[16]: `EnvWrapper` has no attribute `step`.
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= max_steps:
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
Code Example #25
def run_episode(env: Env,
                agent: Agent,
                max_steps: Optional[int] = None) -> float:
    """
    Return sum of rewards from episode.
    After max_steps (if specified), the environment is assumed to be terminal.
    """
    ep_reward = 0.0
    obs = env.reset()
    terminal = False
    num_steps = 0
    while not terminal:
        action = agent.act(obs)
        next_obs, reward, terminal, _ = env.step(action)
        obs = next_obs
        ep_reward += reward
        num_steps += 1
        if max_steps is not None and num_steps > max_steps:
            terminal = True

        agent.post_step(reward, terminal)
        SummaryWriterContext.increase_global_step()
    return ep_reward
Code Example #26
 def _sample_action(self, loc: torch.Tensor, scale_log: torch.Tensor):
     r = torch.randn_like(scale_log, device=scale_log.device)
     action = torch.tanh(loc + r * scale_log.exp())
     # Since each dim is independent, the log-prob is simply the sum
     log_prob = self.actor_network._log_prob(r, scale_log)
     squash_correction = self.actor_network._squash_correction(action)
     if SummaryWriterContext._global_step % 1000 == 0:
         SummaryWriterContext.add_histogram("actor/forward/loc",
                                            loc.detach().cpu())
         SummaryWriterContext.add_histogram("actor/forward/scale_log",
                                            scale_log.detach().cpu())
         SummaryWriterContext.add_histogram("actor/forward/log_prob",
                                            log_prob.detach().cpu())
         SummaryWriterContext.add_histogram(
             "actor/forward/squash_correction",
             squash_correction.detach().cpu())
     log_prob = torch.sum(log_prob - squash_correction, dim=1)
     return action, log_prob
Code Example #27
 def test_global_step(self):
     with TemporaryDirectory() as tmp_dir:
         writer = SummaryWriter(tmp_dir)
         writer.add_scalar = MagicMock()
         with summary_writer_context(writer):
             SummaryWriterContext.add_scalar("test", torch.ones(1))
             SummaryWriterContext.increase_global_step()
             SummaryWriterContext.add_scalar("test", torch.zeros(1))
         writer.add_scalar.assert_has_calls([
             call("test", torch.ones(1), global_step=0),
             call("test", torch.zeros(1), global_step=1),
         ])
         self.assertEqual(2, len(writer.add_scalar.mock_calls))
Code Example #28
 def _log_prob(self, loc: torch.Tensor, scale_log: torch.Tensor,
               squashed_action: torch.Tensor):
     # This is not getting exported; we can use it
     n = torch.distributions.Normal(loc, scale_log.exp())
     raw_action = self.actor_network._atanh(squashed_action)
     log_prob = n.log_prob(raw_action)
     squash_correction = self.actor_network._squash_correction(
         squashed_action)
     if SummaryWriterContext._global_step % 1000 == 0:
         SummaryWriterContext.add_histogram("actor/get_log_prob/loc",
                                            loc.detach().cpu())
         SummaryWriterContext.add_histogram("actor/get_log_prob/scale_log",
                                            scale_log.detach().cpu())
         SummaryWriterContext.add_histogram("actor/get_log_prob/log_prob",
                                            log_prob.detach().cpu())
         SummaryWriterContext.add_histogram(
             "actor/get_log_prob/squash_correction",
             squash_correction.detach().cpu())
     log_prob = torch.sum(log_prob - squash_correction, dim=1)
     return log_prob
Code Example #29
File: td3_trainer.py Project: zachkeer/ReAgent
    def train(self, training_batch: rlt.PolicyNetworkInput) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        assert isinstance(training_batch, rlt.PolicyNetworkInput)

        self.minibatch += 1

        state = training_batch.state
        action = training_batch.action
        next_state = training_batch.next_state
        reward = training_batch.reward
        not_terminal = training_batch.not_terminal

        # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(next_state).action
            noise = torch.randn_like(next_actor) * self.noise_variance
            next_actor = (next_actor +
                          noise.clamp(*self.noise_clip_range)).clamp(
                              *CONTINUOUS_TRAINING_ACTION_RANGE)
            next_state_actor = (next_state, rlt.FeatureData(next_actor))
            next_q_value = self.q1_network_target(*next_state_actor)

            if self.q2_network is not None:
                next_q_value = torch.min(
                    next_q_value, self.q2_network_target(*next_state_actor))

            target_q_value = (
                reward + self.gamma * next_q_value * not_terminal.float())

        # Optimize Q1 and Q2
        # NOTE: important to zero here (instead of using _maybe_update)
        # since q1 may have accumulated gradients from actor network update
        self.q1_network_optimizer.zero_grad()
        q1_value = self.q1_network(state, action)
        q1_loss = self.q_network_loss(q1_value, target_q_value)
        q1_loss.backward()
        self.q1_network_optimizer.step()

        if self.q2_network:
            self.q2_network_optimizer.zero_grad()
            q2_value = self.q2_network(state, action)
            q2_loss = self.q_network_loss(q2_value, target_q_value)
            q2_loss.backward()
            self.q2_network_optimizer.step()

        # Only update actor and target networks after a fixed number of Q updates
        if self.minibatch % self.delayed_policy_update == 0:
            self.actor_network_optimizer.zero_grad()
            actor_action = self.actor_network(state).action
            actor_q1_value = self.q1_network(state,
                                             rlt.FeatureData(actor_action))
            actor_loss = -(actor_q1_value.mean())
            actor_loss.backward()
            self.actor_network_optimizer.step()

            self._soft_update(self.q1_network, self.q1_network_target,
                              self.tau)
            self._soft_update(self.q2_network, self.q2_network_target,
                              self.tau)
            self._soft_update(self.actor_network, self.actor_network_target,
                              self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq != 0
                and self.minibatch % self.tensorboard_logging_freq == 0):
            logs = {
                "loss/q1_loss": q1_loss,
                "loss/actor_loss": actor_loss,
                "q_value/q1_value": q1_value,
                "q_value/next_q_value": next_q_value,
                "q_value/target_q_value": target_q_value,
                "q_value/actor_q1_value": actor_q1_value,
            }
            if self.q2_network:
                logs.update({
                    "loss/q2_loss": q2_loss,
                    "q_value/q2_value": q2_value
                })

            for k, v in logs.items():
                v = v.detach().cpu()
                if v.dim() == 0:
                    # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                    #  `add_scalar`.
                    SummaryWriterContext.add_scalar(k, v.item())
                    continue

                elif v.dim() == 2:
                    v = v.squeeze(1)
                assert v.dim() == 1
                SummaryWriterContext.add_histogram(k, v.numpy())
                SummaryWriterContext.add_scalar(f"{k}_mean", v.mean().item())

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
        )
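
For reference, the clipped-noise target that the TD3 update above builds inside torch.no_grad() can be written compactly (this only restates the code, with sigma corresponding to noise_variance, c to noise_clip_range, and d the terminal indicator):

\begin{aligned}
a' &= \operatorname{clip}\!\big(\pi_{\bar\phi}(s') + \operatorname{clip}(\epsilon, -c, c),\; a_{\min}, a_{\max}\big),
  \qquad \epsilon \sim \mathcal{N}(0, \sigma^2) \\
y  &= r + \gamma\,(1 - d)\,\min\big(Q_{\bar\theta_1}(s', a'),\, Q_{\bar\theta_2}(s', a')\big)
\end{aligned}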
Code Example #30
    def train(self, training_batch: rlt.PolicyNetworkInput) -> None:
        """
        IMPORTANT: the input action here is assumed to match the
        range of the output of the actor.
        """
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_policy_network_training_batch()

        assert isinstance(training_batch, rlt.PolicyNetworkInput)

        self.minibatch += 1

        state = training_batch.state
        action = training_batch.action
        reward = training_batch.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = training_batch.not_terminal

        # We need to zero out grad here because gradient from actor update
        # should not be used in Q-network update
        self.actor_network_optimizer.zero_grad()
        self.q1_network_optimizer.zero_grad()
        if self.q2_network is not None:
            self.q2_network_optimizer.zero_grad()
        if self.value_network is not None:
            self.value_network_optimizer.zero_grad()

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            q1_value = self.q1_network(state, action)
            if self.q2_network:
                q2_value = self.q2_network(state, action)
            actor_output = self.actor_network(state)

            # Optimize Alpha
            if self.alpha_optimizer is not None:
                alpha_loss = -((self.log_alpha *
                                (actor_output.log_prob +
                                 self.target_entropy).detach()).mean())
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.entropy_temperature = self.log_alpha.exp()

            with torch.no_grad():
                if self.value_network is not None:
                    next_state_value = self.value_network_target(
                        training_batch.next_state.float_features)
                else:
                    next_state_actor_output = self.actor_network(
                        training_batch.next_state)
                    next_state_actor_action = (
                        training_batch.next_state,
                        rlt.FeatureData(next_state_actor_output.action),
                    )
                    next_state_value = self.q1_network_target(
                        *next_state_actor_action)

                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            *next_state_actor_action)
                        next_state_value = torch.min(next_state_value,
                                                     target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        training_batch.next_state,
                        next_state_actor_output.action)
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    next_state_value -= self.entropy_temperature * log_prob_a

                if self.gamma > 0.0:
                    target_q_value = (
                        reward +
                        discount * next_state_value * not_done_mask.float())
                else:
                    # This is useful in debugging instability issues
                    target_q_value = reward

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(self.q1_network_optimizer,
                                      self.minibatches_per_step)
            if self.q2_network:
                # pyre-fixme[18]: Global name `q2_value` is undefined.
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(self.q2_network_optimizer,
                                          self.minibatches_per_step)

            # Second, optimize the actor; minimizing KL-divergence between
            # propensity & softmax of value.  Due to reparameterization trick,
            # it ends up being log_prob(actor_action) - Q(s, actor_action)

            state_actor_action = (state, rlt.FeatureData(actor_output.action))
            q1_actor_value = self.q1_network(*state_actor_action)
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(*state_actor_action)
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (self.entropy_temperature * actor_output.log_prob -
                          min_q_actor_value)
            # Do this in 2 steps so we can log histogram of actor loss
            actor_loss_mean = actor_loss.mean()

            if self.add_kld_to_loss:
                if self.apply_kld_on_mean:
                    action_batch_m = torch.mean(actor_output.action_mean,
                                                axis=0)
                    action_batch_v = torch.var(actor_output.action_mean,
                                               axis=0)
                else:
                    action_batch_m = torch.mean(actor_output.action, axis=0)
                    action_batch_v = torch.var(actor_output.action, axis=0)
                kld = (
                    0.5
                    # pyre-fixme[16]: `int` has no attribute `sum`.
                    * ((action_batch_v +
                        (action_batch_m - self.action_emb_mean)**2) /
                       self.action_emb_variance - 1 +
                       self.action_emb_variance.log() -
                       action_batch_v.log()).sum())

                actor_loss_mean += self.kld_weight * kld

            actor_loss_mean.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (min_q_actor_value -
                                        self.entropy_temperature * log_prob_a)

                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(self.value_network_optimizer,
                                          self.minibatches_per_step)

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq != 0
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`.
            SummaryWriterContext.add_scalar("entropy_temperature",
                                            self.entropy_temperature)
            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target",
                                                   target_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)
            if self.add_kld_to_loss:
                SummaryWriterContext.add_histogram("kld/mean", action_batch_m)
                SummaryWriterContext.add_histogram("kld/var", action_batch_v)
                SummaryWriterContext.add_scalar("kld/kld", kld)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
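
Restating the SAC update above for the branch without a separate value network (alpha is entropy_temperature, d the terminal indicator), the critic target, critic loss, and actor loss are

\begin{aligned}
y &= r + \gamma\,(1 - d)\Big(\min_{i=1,2} Q_{\bar\theta_i}(s', a') - \alpha \log \pi_\phi(a' \mid s')\Big),
  \qquad a' \sim \pi_\phi(\cdot \mid s') \\
L_{Q_i} &= \big(Q_{\theta_i}(s, a) - y\big)^2 \\
L_\pi &= \mathbb{E}_{a \sim \pi_\phi(\cdot \mid s)}\big[\alpha \log \pi_\phi(a \mid s) - \min_{i=1,2} Q_{\theta_i}(s, a)\big]
\end{aligned}

When self.gamma is 0.0 the code falls back to y = r for debugging, and the optional KLD penalty is added to the actor loss afterwards.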