def run_episode(
    env: Env, agent: Agent, mdp_id: int = 0, max_steps: Optional[int] = None
) -> Trajectory:
    """
    Run a single episode and return its Trajectory. After max_steps (if
    specified), the environment is assumed to be terminal. The mdp_id of the
    episode can also be specified.
    """
    trajectory = Trajectory()
    obs = env.reset()
    terminal = False
    num_steps = 0
    while not terminal:
        action = agent.act(obs)
        next_obs, reward, terminal, _ = env.step(action)
        if max_steps is not None and num_steps >= max_steps:
            terminal = True
        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=reward,
            terminal=terminal,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        num_steps += 1
    return trajectory
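# A minimal usage sketch (not from the source): roll out a few episodes and
# sum up rewards from each returned Trajectory. The make_env / make_agent
# factories are hypothetical, and this assumes Trajectory keeps its
# Transition objects in a `transitions` list.
def collect_returns(num_episodes: int = 10):
    env = make_env()      # hypothetical factory
    agent = make_agent()  # hypothetical factory
    returns = []
    for i in range(num_episodes):
        trajectory = run_episode(env, agent, mdp_id=i, max_steps=200)
        returns.append(sum(t.reward for t in trajectory.transitions))
    return returns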
def test_add_custom_scalars(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_custom_scalars = MagicMock()
        with summary_writer_context(writer):
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["a", "b"], category="cat", title="title"
            )
            with self.assertRaisesRegex(
                AssertionError, "Title \\(title\\) is already in category \\(cat\\)"
            ):
                SummaryWriterContext.add_custom_scalars_multilinechart(
                    ["c", "d"], category="cat", title="title"
                )
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["e", "f"], category="cat", title="title2"
            )
            SummaryWriterContext.add_custom_scalars_multilinechart(
                ["g", "h"], category="cat2", title="title"
            )
        SummaryWriterContext.add_custom_scalars(writer)
        writer.add_custom_scalars.assert_called_once_with(
            {
                "cat": {
                    "title": ["Multiline", ["a", "b"]],
                    "title2": ["Multiline", ["e", "f"]],
                },
                "cat2": {"title": ["Multiline", ["g", "h"]]},
            }
        )
def training_step(self, batch, batch_idx: int, optimizer_idx: int = 0):
    assert (optimizer_idx == 0) or (self._num_optimizing_steps > 1)

    if self._training_step_generator is None:
        if self._training_batch_type and isinstance(batch, dict):
            batch = self._training_batch_type.from_dict(batch)
        self._training_step_generator = self.train_step_gen(batch, batch_idx)

    ret = next(self._training_step_generator)

    if optimizer_idx == self._num_optimizing_steps - 1:
        if not self._verified_steps:
            try:
                next(self._training_step_generator)
            except StopIteration:
                self._verified_steps = True
            if not self._verified_steps:
                raise RuntimeError(
                    "training_step_gen() yields too many times. "
                    "The number of yields should match the number of optimizers,"
                    f" in this case {self._num_optimizing_steps}"
                )
        self._training_step_generator = None
        SummaryWriterContext.increase_global_step()

    return ret
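# A minimal sketch (an assumption, not the library's code) of a matching
# train_step_gen(): it must yield exactly self._num_optimizing_steps times,
# once per optimizer, in the order the optimizers are configured. The
# compute_*_loss helpers are hypothetical.
def train_step_gen(self, training_batch, batch_idx: int):
    critic_loss = self.compute_critic_loss(training_batch)  # hypothetical
    # Consumed by the training_step call with optimizer_idx == 0.
    yield critic_loss
    actor_loss = self.compute_actor_loss(training_batch)  # hypothetical
    # Consumed by the training_step call with optimizer_idx == 1; being the
    # last yield, it also passes the StopIteration verification above.
    yield actor_loss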
def log_to_tensorboard(self, metric_name: str) -> None:
    self.check_estimates_exist()

    def none_to_zero(x: Optional[float]) -> float:
        if x is None or math.isnan(x):
            return 0.0
        return x

    for name, value in [
        (
            "CPE/{}/Direct_Method_Reward".format(metric_name),
            # pyre-fixme[16]: `Optional` has no attribute `normalized`.
            self.direct_method.normalized,
        ),
        (
            "CPE/{}/IPS_Reward".format(metric_name),
            self.inverse_propensity.normalized,
        ),
        (
            "CPE/{}/Doubly_Robust_Reward".format(metric_name),
            self.doubly_robust.normalized,
        ),
        (
            "CPE/{}/Sequential_Doubly_Robust".format(metric_name),
            self.sequential_doubly_robust.normalized,
        ),
        (
            "CPE/{}/Weighted_Sequential_Doubly_Robust".format(metric_name),
            self.weighted_doubly_robust.normalized,
        ),
        ("CPE/{}/MAGIC".format(metric_name), self.magic.normalized),
    ]:
        SummaryWriterContext.add_scalar(name, none_to_zero(value))
def test_swallowing_exception(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock(side_effect=NotImplementedError("test"))
        writer.exceptions_to_ignore = (NotImplementedError, KeyError)
        with summary_writer_context(writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
def test_not_swallowing_exception(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock(side_effect=NotImplementedError("test"))
        with self.assertRaisesRegex(
            NotImplementedError, "test"
        ), summary_writer_context(writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
def __init__(self, key: str, title: str, actions: List[str]):
    super().__init__(key)
    self.log_key = f"actions/{title}"
    self.actions = actions
    SummaryWriterContext.add_custom_scalars_multilinechart(
        [f"{self.log_key}/{action_name}" for action_name in actions],
        category="actions",
        title=title,
    )
def test_writing(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock()
        with summary_writer_context(writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
        writer.add_scalar.assert_called_once_with(
            "test", torch.ones(1), global_step=0
        )
def _log_histogram_and_mean(self, log_key, val):
    try:
        SummaryWriterContext.add_histogram(log_key, val)
        SummaryWriterContext.add_scalar(f"{log_key}/mean", val.mean())
    except ValueError:
        logger.warning(
            f"Cannot create histogram for key: {log_key}; "
            "this is likely because you have a NULL value in your input; "
            f"value: {val}"
        )
        raise
def write_summary(self, actions: List[str]):
    if actions:
        for field, log_key in [
            ("logged_actions", "actions/logged"),
            ("model_action_idxs", "actions/model"),
        ]:
            val = getattr(self, field)
            if val is None:
                continue
            for i, action in enumerate(actions):
                # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                #  `add_scalar`.
                SummaryWriterContext.add_scalar(
                    "{}/{}".format(log_key, action), (val == i).sum().item()
                )

    for field, log_key in [
        ("td_loss", "td_loss"),
        ("imitator_loss", "imitator_loss"),
        ("reward_loss", "reward_loss"),
        ("logged_propensities", "propensities/logged"),
        ("logged_rewards", "reward/logged"),
        ("logged_values", "value/logged"),
        ("model_values_on_logged_actions", "value/model_logged_action"),
    ]:
        val = getattr(self, field)
        if val is None:
            continue
        assert len(val.shape) == 1 or (
            len(val.shape) == 2 and val.shape[1] == 1
        ), "Unexpected shape for {}: {}".format(field, val.shape)
        self._log_histogram_and_mean(log_key, val)

    for field, log_key in [
        ("model_propensities", "propensities/model"),
        ("model_rewards", "reward/model"),
        ("model_values", "value/model"),
    ]:
        val = getattr(self, field)
        if val is None:
            continue
        if (
            len(val.shape) == 1 or (len(val.shape) == 2 and val.shape[1] == 1)
        ) and not actions:
            self._log_histogram_and_mean(log_key, val)
        elif len(val.shape) == 2 and val.shape[1] == len(actions):
            for i, action in enumerate(actions):
                self._log_histogram_and_mean(f"{log_key}/{action}", val[:, i])
        else:
            raise ValueError(
                "Unexpected shape for {}: {}; actions: {}".format(
                    field, val.shape, actions
                )
            )
def __iter__(self):
    t = tqdm(total=self.dataloader_size, desc="iterating dataloader")
    for batch in self.dataloader:
        batch_size = get_batch_size(batch)
        yield batch
        t.update(batch_size)
        SummaryWriterContext.increase_global_step()
    # clean up if needed (e.g. Petastorm DataLoader)
    if hasattr(self.dataloader, "__exit__"):
        self.dataloader.__exit__(None, None, None)
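# A plausible sketch of the get_batch_size() helper used above (an assumption;
# the real implementation may differ): infer the batch size from a tensor
# batch directly, or from the first tensor value of a dict batch.
import torch

def get_batch_size(batch):
    if isinstance(batch, torch.Tensor):
        return batch.shape[0]
    if isinstance(batch, dict):
        return next(iter(batch.values())).shape[0]
    raise TypeError(f"Unsupported batch type: {type(batch)}")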
def log_to_tensorboard(self, epoch: int) -> None:
    def none_to_zero(x: Optional[float]) -> float:
        if x is None or math.isnan(x):
            return 0.0
        return x

    for name, value in [
        ("Training/td_loss", self.get_recent_td_loss()),
        ("Training/reward_loss", self.get_recent_reward_loss()),
        ("Training/imitator_loss", self.get_recent_imitator_loss()),
    ]:
        SummaryWriterContext.add_scalar(name, none_to_zero(value), epoch)
async def async_run_episode(
    env: EnvWrapper,
    agent: Agent,
    mdp_id: int = 0,
    max_steps: Optional[int] = None,
    fill_info: bool = False,
) -> Trajectory:
    """
    NOTE: this function is an async coroutine in order to support async
    env.step(). If you are using it with a regular env.step() method, use the
    non-async run_episode(), which wraps this function.

    Run a single episode and return its Trajectory. After max_steps (if
    specified), the environment is assumed to be terminal. The mdp_id of the
    episode can also be specified.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    step_is_coroutine = asyncio.iscoroutinefunction(env.step)
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        if step_is_coroutine:
            next_obs, reward, terminal, info = await env.step(action)
        else:
            next_obs, reward, terminal, info = env.step(action)
        if not fill_info:
            info = None
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= max_steps:
            terminal = True
        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
            info=info,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
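# A minimal sketch (an assumption, not verbatim from the source) of the
# non-async run_episode() wrapper referred to in the docstring above. It
# drives the coroutine to completion on the current thread's event loop,
# assuming no loop is already running.
def run_episode(
    env: EnvWrapper,
    agent: Agent,
    mdp_id: int = 0,
    max_steps: Optional[int] = None,
) -> Trajectory:
    return asyncio.get_event_loop().run_until_complete(
        async_run_episode(env=env, agent=agent, mdp_id=mdp_id, max_steps=max_steps)
    )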
def test_minibatches_per_step(self):
    _epochs = self.epochs
    self.epochs = 2
    rl_parameters = RLParameters(
        gamma=0.95, target_update_rate=0.9, maxq_learning=True
    )
    rainbow_parameters = RainbowDQNParameters(
        double_q_learning=True, dueling_architecture=False
    )
    training_parameters1 = TrainingParameters(
        layers=self.layers,
        activations=self.activations,
        minibatch_size=1024,
        minibatches_per_step=1,
        learning_rate=0.25,
        optimizer="ADAM",
    )
    training_parameters2 = TrainingParameters(
        layers=self.layers,
        activations=self.activations,
        minibatch_size=128,
        minibatches_per_step=8,
        learning_rate=0.25,
        optimizer="ADAM",
    )
    env1 = Env(self.state_dims, self.action_dims)
    env2 = Env(self.state_dims, self.action_dims)
    model_parameters1 = DiscreteActionModelParameters(
        actions=env1.actions,
        rl=rl_parameters,
        rainbow=rainbow_parameters,
        training=training_parameters1,
    )
    model_parameters2 = DiscreteActionModelParameters(
        actions=env2.actions,
        rl=rl_parameters,
        rainbow=rainbow_parameters,
        training=training_parameters2,
    )

    # minibatch_size / 8, minibatches_per_step * 8 should give the same result
    logger.info("Training model 1")
    trainer1 = self._train(model_parameters1, env1)
    SummaryWriterContext._reset_globals()
    logger.info("Training model 2")
    trainer2 = self._train(model_parameters2, env2)

    weight1 = trainer1.q_network.fc.dnn[-2].weight.detach().numpy()
    weight2 = trainer2.q_network.fc.dnn[-2].weight.detach().numpy()

    # Due to numerical stability this tolerance has to be fairly high
    self.assertTrue(np.allclose(weight1, weight2, rtol=0.0, atol=1e-3))
    self.epochs = _epochs
def add_custom_scalars(action_names: Optional[List[str]]):
    if not action_names:
        return

    SummaryWriterContext.add_custom_scalars_multilinechart(
        [
            "propensities/model/{}/mean".format(action_name)
            for action_name in action_names
        ],
        category="propensities",
        title="model",
    )
    SummaryWriterContext.add_custom_scalars_multilinechart(
        [
            "propensities/logged/{}/mean".format(action_name)
            for action_name in action_names
        ],
        category="propensities",
        title="logged",
    )
    SummaryWriterContext.add_custom_scalars_multilinechart(
        ["actions/logged/{}".format(action_name) for action_name in action_names],
        category="actions",
        title="logged",
    )
    SummaryWriterContext.add_custom_scalars_multilinechart(
        ["actions/model/{}".format(action_name) for action_name in action_names],
        category="actions",
        title="model",
    )
def get_log_prob(self, state, squashed_action):
    """
    Action is expected to be squashed with tanh
    """
    loc, scale_log = self._get_loc_and_scale_log(state)
    # This is not getting exported; we can use it
    n = Normal(loc, scale_log.exp())
    raw_action = self._atanh(squashed_action)
    log_prob = n.log_prob(raw_action)
    squash_correction = self._squash_correction(squashed_action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/loc", loc.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/squash_correction", squash_correction.detach().cpu()
        )
    log_prob = torch.sum(log_prob - squash_correction, dim=1).reshape(-1, 1)
    return log_prob
def get_log_prob(self, state: rlt.FeatureData, squashed_action: torch.Tensor):
    """
    Action is expected to be squashed with tanh
    """
    if self.use_l2_normalization:
        # TODO: calculate log_prob for l2 normalization
        # https://math.stackexchange.com/questions/3120506/on-the-distribution-of-a-normalized-gaussian-vector
        # http://proceedings.mlr.press/v100/mazoure20a/mazoure20a.pdf
        pass

    loc, scale_log = self._get_loc_and_scale_log(state)
    raw_action = torch.atanh(squashed_action)
    r = (raw_action - loc) / scale_log.exp()
    log_prob = self._normal_log_prob(r, scale_log)
    squash_correction = self._squash_correction(squashed_action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/loc", loc.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/squash_correction", squash_correction.detach().cpu()
        )
    return torch.sum(log_prob - squash_correction, dim=1).reshape(-1, 1)
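# A self-contained sanity check (not from the source) of the tanh squash
# correction used above: for a = tanh(u) with u ~ Normal(loc, scale),
# log pi(a) = log N(u; loc, scale) - sum_i log(1 - a_i^2), which should agree
# with PyTorch's built-in TanhTransform change of variables.
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

loc, scale = torch.zeros(4), torch.ones(4)
u = torch.tensor([0.5, -1.0, 0.3, 2.0])  # pre-tanh ("raw") action
a = torch.tanh(u)  # squashed action

# Manual computation mirroring get_log_prob: base log-prob minus correction.
manual = (Normal(loc, scale).log_prob(u) - torch.log1p(-a.pow(2))).sum()
# Reference computation via the transform machinery.
reference = (
    TransformedDistribution(Normal(loc, scale), TanhTransform()).log_prob(a).sum()
)
assert torch.allclose(manual, reference, atol=1e-4)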
def forward(self, input):
    loc, scale_log = self._get_loc_and_scale_log(input.state)
    r = torch.randn_like(scale_log, device=scale_log.device)
    action = torch.tanh(loc + r * scale_log.exp())
    if not self.training:
        # ONNX doesn't like reshape either...
        return rlt.ActorOutput(action=action)
    # Since each dim is independent, the log-prob is simply the sum
    log_prob = self._log_prob(r, scale_log)
    squash_correction = self._squash_correction(action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
        SummaryWriterContext.add_histogram(
            "actor/forward/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/squash_correction", squash_correction.detach().cpu()
        )
    log_prob = torch.sum(log_prob - squash_correction, dim=1)

    return rlt.ActorOutput(
        action=action, log_prob=log_prob.reshape(-1, 1), action_mean=loc
    )
def __init__(
    self,
    key: str,
    category: str,
    title: str,
    actions: List[str],
    log_key_prefix: Optional[str] = None,
):
    super().__init__(key)
    self.log_key_prefix = log_key_prefix or f"{category}/{title}"
    self.actions = actions
    SummaryWriterContext.add_custom_scalars_multilinechart(
        [f"{self.log_key_prefix}/{action_name}/mean" for action_name in actions],
        category=category,
        title=title,
    )
def forward(self, state: rlt.FeatureData):
    loc, scale_log = self._get_loc_and_scale_log(state)
    r = torch.randn_like(scale_log, device=scale_log.device)
    raw_action = loc + r * scale_log.exp()
    squashed_action = self._squash_raw_action(raw_action)
    squashed_loc = self._squash_raw_action(loc)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
        SummaryWriterContext.add_histogram(
            "actor/forward/scale_log", scale_log.detach().cpu()
        )
    return rlt.ActorOutput(
        action=squashed_action,
        log_prob=self.get_log_prob(state, squashed_action),
        squashed_mean=squashed_loc,
    )
def test_writing_stack(self):
    with TemporaryDirectory() as tmp_dir1, TemporaryDirectory() as tmp_dir2:
        writer1 = SummaryWriter(tmp_dir1)
        writer1.add_scalar = MagicMock()
        writer2 = SummaryWriter(tmp_dir2)
        writer2.add_scalar = MagicMock()
        with summary_writer_context(writer1):
            with summary_writer_context(writer2):
                SummaryWriterContext.add_scalar("test2", torch.ones(1))
            SummaryWriterContext.add_scalar("test1", torch.zeros(1))
        writer1.add_scalar.assert_called_once_with(
            "test1", torch.zeros(1), global_step=0
        )
        writer2.add_scalar.assert_called_once_with(
            "test2", torch.ones(1), global_step=0
        )
def train_network(self, train_dataset, eval_dataset, epochs: int):
    num_batches = int(len(train_dataset) / self.minibatch_size)
    logger.info(
        "Read in batch data set of size {} examples. Data split "
        "into {} batches of size {}.".format(
            len(train_dataset), num_batches, self.minibatch_size
        )
    )

    start_time = time.time()
    for epoch in range(epochs):
        train_dataset.reset_iterator()
        data_streamer = DataStreamer(train_dataset, pin_memory=self.trainer.use_gpu)
        feed_pages(
            data_streamer,
            len(train_dataset),
            epoch,
            self.minibatch_size,
            self.trainer.use_gpu,
            TrainingPageHandler(self.trainer),
            batch_preprocessor=self.batch_preprocessor,
        )

        if hasattr(self.trainer, "q_network_cpe"):
            # TODO: Add CPE support to SAC
            eval_dataset.reset_iterator()
            data_streamer = DataStreamer(
                eval_dataset, pin_memory=self.trainer.use_gpu
            )
            eval_page_handler = EvaluationPageHandler(
                self.trainer, self.evaluator, self
            )
            feed_pages(
                data_streamer,
                len(eval_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                eval_page_handler,
                batch_preprocessor=self.batch_preprocessor,
            )

        SummaryWriterContext.increase_global_step()

    throughput = (len(train_dataset) * epochs) / (time.time() - start_time)
    logger.info(
        "Training finished. Processed ~{} examples / s.".format(round(throughput))
    )
def training_step(self, batch, batch_idx: int, optimizer_idx: int):
    if self._training_step_generator is None:
        self._training_step_generator = self.train_step_gen(batch, batch_idx)

    ret = next(self._training_step_generator)

    if optimizer_idx == self._num_optimizing_steps - 1:
        if not self._verified_steps:
            try:
                next(self._training_step_generator)
            except StopIteration:
                self._verified_steps = True
            if not self._verified_steps:
                raise RuntimeError("training_step_gen() yields too many times")
        self._training_step_generator = None
        SummaryWriterContext.increase_global_step()

    return ret
def run_episode(
    env: EnvWrapper, agent: Agent, mdp_id: int = 0, max_steps: Optional[int] = None
) -> Trajectory:
    """
    Run a single episode and return its Trajectory. After max_steps (if
    specified), the environment is assumed to be terminal. The mdp_id of the
    episode can also be specified.
    """
    trajectory = Trajectory()
    # pyre-fixme[16]: `EnvWrapper` has no attribute `reset`.
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        # pyre-fixme[16]: `EnvWrapper` has no attribute `step`.
        next_obs, reward, terminal, _ = env.step(action)
        next_possible_actions_mask = env.possible_actions_mask
        if max_steps is not None and num_steps >= max_steps:
            terminal = True
        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory
def run_episode(env: Env, agent: Agent, max_steps: Optional[int] = None) -> float:
    """
    Return the sum of rewards from an episode. After max_steps (if specified),
    the environment is assumed to be terminal.
    """
    ep_reward = 0.0
    obs = env.reset()
    terminal = False
    num_steps = 0
    while not terminal:
        action = agent.act(obs)
        next_obs, reward, terminal, _ = env.step(action)
        obs = next_obs
        ep_reward += reward
        num_steps += 1
        if max_steps is not None and num_steps > max_steps:
            terminal = True
        agent.post_step(reward, terminal)
        SummaryWriterContext.increase_global_step()
    return ep_reward
def _sample_action(self, loc: torch.Tensor, scale_log: torch.Tensor):
    r = torch.randn_like(scale_log, device=scale_log.device)
    action = torch.tanh(loc + r * scale_log.exp())
    # Since each dim is independent, the log-prob is simply the sum
    log_prob = self.actor_network._log_prob(r, scale_log)
    squash_correction = self.actor_network._squash_correction(action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram("actor/forward/loc", loc.detach().cpu())
        SummaryWriterContext.add_histogram(
            "actor/forward/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/forward/squash_correction", squash_correction.detach().cpu()
        )
    log_prob = torch.sum(log_prob - squash_correction, dim=1)
    return action, log_prob
def test_global_step(self):
    with TemporaryDirectory() as tmp_dir:
        writer = SummaryWriter(tmp_dir)
        writer.add_scalar = MagicMock()
        with summary_writer_context(writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
            SummaryWriterContext.increase_global_step()
            SummaryWriterContext.add_scalar("test", torch.zeros(1))
        writer.add_scalar.assert_has_calls(
            [
                call("test", torch.ones(1), global_step=0),
                call("test", torch.zeros(1), global_step=1),
            ]
        )
        self.assertEqual(2, len(writer.add_scalar.mock_calls))
def _log_prob(
    self, loc: torch.Tensor, scale_log: torch.Tensor, squashed_action: torch.Tensor
):
    # This is not getting exported; we can use it
    n = torch.distributions.Normal(loc, scale_log.exp())
    raw_action = self.actor_network._atanh(squashed_action)
    log_prob = n.log_prob(raw_action)
    squash_correction = self.actor_network._squash_correction(squashed_action)
    if SummaryWriterContext._global_step % 1000 == 0:
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/loc", loc.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/scale_log", scale_log.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/log_prob", log_prob.detach().cpu()
        )
        SummaryWriterContext.add_histogram(
            "actor/get_log_prob/squash_correction", squash_correction.detach().cpu()
        )
    log_prob = torch.sum(log_prob - squash_correction, dim=1)
    return log_prob
def train(self, training_batch: rlt.PolicyNetworkInput) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match
    the range of the output of the actor.
    """
    assert isinstance(training_batch, rlt.PolicyNetworkInput)

    self.minibatch += 1

    state = training_batch.state
    action = training_batch.action
    next_state = training_batch.next_state
    reward = training_batch.reward
    not_terminal = training_batch.not_terminal

    # Generate target = r + y * min (Q1(s', pi(s')), Q2(s', pi(s')))
    with torch.no_grad():
        next_actor = self.actor_network_target(next_state).action
        noise = torch.randn_like(next_actor) * self.noise_variance
        next_actor = (next_actor + noise.clamp(*self.noise_clip_range)).clamp(
            *CONTINUOUS_TRAINING_ACTION_RANGE
        )
        next_state_actor = (next_state, rlt.FeatureData(next_actor))
        next_q_value = self.q1_network_target(*next_state_actor)

        if self.q2_network is not None:
            next_q_value = torch.min(
                next_q_value, self.q2_network_target(*next_state_actor)
            )

        target_q_value = reward + self.gamma * next_q_value * not_terminal.float()

    # Optimize Q1 and Q2
    # NOTE: important to zero here (instead of using _maybe_update)
    # since q1 may have accumulated gradients from actor network update
    self.q1_network_optimizer.zero_grad()
    q1_value = self.q1_network(state, action)
    q1_loss = self.q_network_loss(q1_value, target_q_value)
    q1_loss.backward()
    self.q1_network_optimizer.step()

    if self.q2_network:
        self.q2_network_optimizer.zero_grad()
        q2_value = self.q2_network(state, action)
        q2_loss = self.q_network_loss(q2_value, target_q_value)
        q2_loss.backward()
        self.q2_network_optimizer.step()

    # Only update actor and target networks after a fixed number of Q updates
    if self.minibatch % self.delayed_policy_update == 0:
        self.actor_network_optimizer.zero_grad()
        actor_action = self.actor_network(state).action
        actor_q1_value = self.q1_network(state, rlt.FeatureData(actor_action))
        actor_loss = -(actor_q1_value.mean())
        actor_loss.backward()
        self.actor_network_optimizer.step()

        self._soft_update(self.q1_network, self.q1_network_target, self.tau)
        self._soft_update(self.q2_network, self.q2_network_target, self.tau)
        self._soft_update(self.actor_network, self.actor_network_target, self.tau)

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq != 0
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        logs = {
            "loss/q1_loss": q1_loss,
            "loss/actor_loss": actor_loss,
            "q_value/q1_value": q1_value,
            "q_value/next_q_value": next_q_value,
            "q_value/target_q_value": target_q_value,
            "q_value/actor_q1_value": actor_q1_value,
        }
        if self.q2_network:
            logs.update({"loss/q2_loss": q2_loss, "q_value/q2_value": q2_value})

        for k, v in logs.items():
            v = v.detach().cpu()
            if v.dim() == 0:
                # pyre-fixme[16]: `SummaryWriterContext` has no attribute
                #  `add_scalar`.
                SummaryWriterContext.add_scalar(k, v.item())
                continue
            elif v.dim() == 2:
                v = v.squeeze(1)
            assert v.dim() == 1
            SummaryWriterContext.add_histogram(k, v.numpy())
            SummaryWriterContext.add_scalar(f"{k}_mean", v.mean().item())

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
    )
def train(self, training_batch: rlt.PolicyNetworkInput) -> None:
    """
    IMPORTANT: the input action here is assumed to match the range of the
    output of the actor.
    """
    if isinstance(training_batch, TrainingDataPage):
        training_batch = training_batch.as_policy_network_training_batch()

    assert isinstance(training_batch, rlt.PolicyNetworkInput)

    self.minibatch += 1

    state = training_batch.state
    action = training_batch.action
    reward = training_batch.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = training_batch.not_terminal

    # We need to zero out grad here because gradient from actor update
    # should not be used in Q-network update
    self.actor_network_optimizer.zero_grad()
    self.q1_network_optimizer.zero_grad()
    if self.q2_network is not None:
        self.q2_network_optimizer.zero_grad()
    if self.value_network is not None:
        self.value_network_optimizer.zero_grad()

    with torch.enable_grad():
        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        q1_value = self.q1_network(state, action)
        if self.q2_network:
            q2_value = self.q2_network(state, action)
        actor_output = self.actor_network(state)

        # Optimize Alpha
        if self.alpha_optimizer is not None:
            alpha_loss = -(
                (
                    self.log_alpha
                    * (actor_output.log_prob + self.target_entropy).detach()
                ).mean()
            )
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.entropy_temperature = self.log_alpha.exp()

        with torch.no_grad():
            if self.value_network is not None:
                next_state_value = self.value_network_target(
                    training_batch.next_state.float_features
                )
            else:
                next_state_actor_output = self.actor_network(
                    training_batch.next_state
                )
                next_state_actor_action = (
                    training_batch.next_state,
                    rlt.FeatureData(next_state_actor_output.action),
                )
                next_state_value = self.q1_network_target(*next_state_actor_action)

                if self.q2_network is not None:
                    target_q2_value = self.q2_network_target(
                        *next_state_actor_action
                    )
                    next_state_value = torch.min(next_state_value, target_q2_value)

                log_prob_a = self.actor_network.get_log_prob(
                    training_batch.next_state, next_state_actor_output.action
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                next_state_value -= self.entropy_temperature * log_prob_a

            if self.gamma > 0.0:
                target_q_value = (
                    reward + discount * next_state_value * not_done_mask.float()
                )
            else:
                # This is useful in debugging instability issues
                target_q_value = reward

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(
            self.q1_network_optimizer, self.minibatches_per_step
        )
        if self.q2_network:
            # pyre-fixme[18]: Global name `q2_value` is undefined.
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        #
        # Second, optimize the actor; minimizing KL-divergence between
        # propensity & softmax of value. Due to reparameterization trick,
        # it ends up being log_prob(actor_action) - Q(s, actor_action)
        #

        state_actor_action = (state, rlt.FeatureData(actor_output.action))
        q1_actor_value = self.q1_network(*state_actor_action)
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(*state_actor_action)
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()

        if self.add_kld_to_loss:
            if self.apply_kld_on_mean:
                action_batch_m = torch.mean(actor_output.action_mean, axis=0)
                action_batch_v = torch.var(actor_output.action_mean, axis=0)
            else:
                action_batch_m = torch.mean(actor_output.action, axis=0)
                action_batch_v = torch.var(actor_output.action, axis=0)
            kld = (
                0.5
                # pyre-fixme[16]: `int` has no attribute `sum`.
                * (
                    (action_batch_v + (action_batch_m - self.action_emb_mean) ** 2)
                    / self.action_emb_variance
                    - 1
                    + self.action_emb_variance.log()
                    - action_batch_v.log()
                ).sum()
            )
            actor_loss_mean += self.kld_weight * kld

        actor_loss_mean.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        #
        # Lastly, if applicable, optimize value network; minimizing MSE between
        # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
        #

        if self.value_network is not None:
            state_value = self.value_network(state.float_features)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                with torch.no_grad():
                    log_prob_a = actor_output.log_prob
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    target_value = (
                        min_q_actor_value - self.entropy_temperature * log_prob_a
                    )

            value_loss = F.mse_loss(state_value, target_value.detach())
            value_loss.backward()
            self._maybe_run_optimizer(
                self.value_network_optimizer, self.minibatches_per_step
            )

    # Use the soft update rule to update the target networks
    if self.value_network is not None:
        self._maybe_soft_update(
            self.value_network,
            self.value_network_target,
            self.tau,
            self.minibatches_per_step,
        )
    else:
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq != 0
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

        # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`.
        SummaryWriterContext.add_scalar(
            "entropy_temperature", self.entropy_temperature
        )
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        if self.value_network:
            SummaryWriterContext.add_histogram("value_network/target", target_value)

        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram(
            "q_network/target_q_value", target_q_value
        )
        SummaryWriterContext.add_histogram(
            "actor/min_q_actor_value", min_q_actor_value
        )
        SummaryWriterContext.add_histogram(
            "actor/action_log_prob", actor_output.log_prob
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)
        if self.add_kld_to_loss:
            SummaryWriterContext.add_histogram("kld/mean", action_batch_m)
            SummaryWriterContext.add_histogram("kld/var", action_batch_v)
            SummaryWriterContext.add_scalar("kld/kld", kld)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
        model_propensities=actor_output.log_prob.exp(),
        model_values=min_q_actor_value,
    )