def test_rescale_torch_tensor(self):
    rows, cols = 3, 5
    original_tensor = torch.randint(low=10, high=40, size=(rows, cols)).float()
    prev_max_tensor = torch.ones(1, 5) * 40.0
    prev_min_tensor = torch.ones(1, 5) * 10.0
    new_min_tensor = torch.ones(1, 5) * -1.0
    new_max_tensor = torch.ones(1, 5).float()
    print("Original tensor: ", original_tensor)
    rescaled_tensor = rescale_torch_tensor(
        original_tensor,
        new_min_tensor,
        new_max_tensor,
        prev_min_tensor,
        prev_max_tensor,
    )
    print("Rescaled tensor: ", rescaled_tensor)
    reconstructed_original_tensor = rescale_torch_tensor(
        rescaled_tensor,
        prev_min_tensor,
        prev_max_tensor,
        new_min_tensor,
        new_max_tensor,
    )
    print("Reconstructed Original tensor: ", reconstructed_original_tensor)
    comparison_tensor = torch.eq(original_tensor, reconstructed_original_tensor)
    # assertEqual (not assertTrue) so the element count is actually checked;
    # assertTrue would treat the second argument as a failure message.
    self.assertEqual(int(torch.sum(comparison_tensor)), rows * cols)
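# For reference, a minimal sketch of the elementwise linear min-max rescaling that
# rescale_torch_tensor is assumed to perform in the snippets above and below; the
# name rescale_torch_tensor_sketch and its body are illustrative, not the library's
# actual implementation.
import torch


def rescale_torch_tensor_sketch(tensor, new_min, new_max, prev_min, prev_max):
    # Map each column from [prev_min, prev_max] to [new_min, new_max] with a
    # linear transform; the min/max tensors broadcast across the batch dimension.
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return (tensor - prev_min) / prev_range * new_range + new_min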
def internal_prediction(self, states, noisy=False) -> np.ndarray:
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor.eval()
    # TODO: Handle states being sequences
    state_examples = rlt.FeatureVector(
        float_features=torch.from_numpy(np.array(states)).type(self.dtype)
    )
    action = self.actor(rlt.StateAction(state=state_examples, action=None)).action
    self.actor.train()
    action = rescale_torch_tensor(
        action,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )
    action = action.cpu().data.numpy()
    if noisy:
        action = [x + (self.noise.get_noise()) for x in action]
    return np.array(action, dtype=np.float32)
def internal_prediction(self, states, test=False):
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor_network.eval()
    with torch.no_grad():
        actions = self.actor_network(
            rlt.PreprocessedState.from_tensor(states)
        ).action

    if not test:
        if self.minibatch < self.initial_exploration_ts:
            actions = (
                torch.rand_like(actions)
                * (
                    self.max_action_range_tensor_training
                    - self.min_action_range_tensor_training
                )
                + self.min_action_range_tensor_training
            )
        else:
            actions += torch.randn_like(actions) * self.exploration_noise

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def internal_prediction(self, states, noisy=False) -> np.ndarray:
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor.eval()
    with torch.no_grad():
        state_examples = Variable(
            torch.from_numpy(np.array(states)).type(self.dtype)
        )
        actions = self.actor(state_examples)
    self.actor.train()
    actions = rescale_torch_tensor(
        actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )
    actions = actions.cpu().data.numpy()
    if noisy:
        actions = [x + (self.noise.get_noise()) for x in actions]
    return np.array(actions, dtype=np.float32)
def _maybe_scale_action_in_train(self, action):
    if (
        self.min_action_range_tensor_training is not None
        and self.max_action_range_tensor_training is not None
        and self.min_action_range_tensor_serving is not None
        and self.max_action_range_tensor_serving is not None
    ):
        action = rescale_torch_tensor(
            action,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
    return action
def _maybe_scale_action_in_train(self, action):
    if (
        self.min_action_range_tensor_training is not None
        and self.max_action_range_tensor_training is not None
        and self.min_action_range_tensor_serving is not None
        and self.max_action_range_tensor_serving is not None
    ):
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )
    return action
def internal_prediction(self, states):
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor_network.eval()
    actions = self.actor_network(
        rlt.StateInput(rlt.FeatureVector(float_features=states))
    )

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions.action, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def internal_prediction(self, states, test=False):
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor_network.eval()
    with torch.no_grad():
        actions = self.actor_network(rlt.PreprocessedState.from_tensor(states))

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions.action, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def internal_prediction(self, states, noisy=False) -> np.ndarray:
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor.eval()
    state_examples = torch.from_numpy(np.array(states)).type(self.dtype)
    actions = self.actor(state_examples)
    self.actor.train()
    actions = rescale_torch_tensor(
        actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )
    actions = actions.cpu().data.numpy()
    if noisy:
        actions = [x + (self.noise.get_noise()) for x in actions]
    return np.array(actions, dtype=np.float32)
def internal_prediction(self, states):
    """ Returns list of actions output from actor network
    :param states states as list of states to produce actions for
    """
    self.actor_network.eval()
    state_examples = torch.from_numpy(np.array(states)).type(self.dtype)
    actions = self.actor_network(
        rlt.StateInput(rlt.FeatureVector(float_features=state_examples))
    )

    # clamp actions to make sure actions are in the range
    clamped_actions = torch.max(
        torch.min(actions.action, self.max_action_range_tensor_training),
        self.min_action_range_tensor_training,
    )
    rescaled_actions = rescale_torch_tensor(
        clamped_actions,
        new_min=self.min_action_range_tensor_serving,
        new_max=self.max_action_range_tensor_serving,
        prev_min=self.min_action_range_tensor_training,
        prev_max=self.max_action_range_tensor_training,
    )

    self.actor_network.train()
    return rescaled_actions
def train(
    self, training_samples: TrainingDataPage, evaluator=None, episode_values=None
) -> None:
    self.minibatch += 1

    states = Variable(torch.from_numpy(training_samples.states).type(self.dtype))
    actions = Variable(torch.from_numpy(training_samples.actions).type(self.dtype))
    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    actions = rescale_torch_tensor(
        actions,
        new_min=self.min_action_range_tensor_training,
        new_max=self.max_action_range_tensor_training,
        prev_min=self.min_action_range_tensor_serving,
        prev_max=self.max_action_range_tensor_serving,
    )
    rewards = Variable(torch.from_numpy(training_samples.rewards).type(self.dtype))
    next_states = Variable(
        torch.from_numpy(training_samples.next_states).type(self.dtype)
    )
    time_diffs = torch.tensor(training_samples.time_diffs).type(self.dtype)
    discount_tensor = torch.tensor(np.full(len(rewards), self.gamma)).type(self.dtype)
    not_done_mask = Variable(
        torch.from_numpy(training_samples.not_terminals.astype(int))
    ).type(self.dtype)

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic(torch.cat((states, actions), dim=1))
    next_actions = self.actor_target(next_states)

    next_state_actions = torch.cat((next_states, next_actions), dim=1)
    q_s2_a2 = self.critic_target(next_state_actions).detach().squeeze()
    filtered_q_s2_a2 = not_done_mask * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    if self.use_reward_burnin and self.minibatch < self.reward_burnin:
        target_q_values = rewards
    else:
        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1.squeeze()
    loss_critic = self.q_network_loss(critic_predictions, target_q_values)
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max sum(Q(s1, a1)) or min -sum(Q(s1, a1))
    loss_actor = -self.critic(torch.cat((states, self.actor(states)), dim=1)).sum()
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    if self.use_reward_burnin and self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.actor, self.actor_target, 1.0)
        self._soft_update(self.critic, self.critic_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

    if evaluator is not None:
        evaluator.report(
            loss_critic.cpu().data.numpy(),
            None,
            None,
            None,
            episode_values,
            None,
            None,
            critic_predictions.cpu().data.numpy(),
            None,
        )
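# The _soft_update helper called by the trainers above and below is not shown in this
# section. A minimal sketch of the Polyak-averaging rule it presumably applies, under
# the assumed signature _soft_update(source, target, tau); illustrative only.
import torch


def _soft_update(source_network, target_network, tau):
    # Blend target parameters toward source parameters:
    # target <- tau * source + (1 - tau) * target.
    # tau == 1.0 copies the source outright, matching the reward-burnin branches above.
    with torch.no_grad():
        for target_param, source_param in zip(
            target_network.parameters(), source_network.parameters()
        ):
            target_param.mul_(1.0 - tau).add_(source_param, alpha=tau)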
def train(self, training_batch: rlt.TrainingBatch) -> None:
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    action = rlt.FeatureVector(
        rescale_torch_tensor(
            learning_input.action.float_features,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
    )
    rewards = learning_input.reward
    next_state = learning_input.next_state
    time_diffs = learning_input.time_diff
    discount_tensor = torch.full_like(rewards, self.gamma)
    not_done_mask = learning_input.not_terminal

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic.forward(
        rlt.StateAction(state=state, action=action)
    ).q_value
    next_action = rlt.FeatureVector(
        float_features=self.actor_target(
            rlt.StateAction(state=next_state, action=None)
        ).action
    )
    q_s2_a2 = self.critic_target.forward(
        rlt.StateAction(state=next_state, action=next_action)
    ).q_value
    filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    actor_output = self.actor(rlt.StateAction(state=state, action=None))
    loss_actor = -(
        self.critic.forward(
            rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
        ).q_value.mean()
    )

    # Zero out both the actor and critic gradients because we need
    # to backprop through the critic to get to the actor
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    # Use the soft update rule to update both target networks
    self._soft_update(self.actor, self.actor_target, self.tau)
    self._soft_update(self.critic, self.critic_target, self.tau)

    self.loss_reporter.report(
        td_loss=float(loss_critic_for_eval),
        reward_loss=None,
        model_values_on_logged_actions=critic_predictions,
    )
def train(self, training_batch, evaluator=None) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    current_state_action = rlt.StateAction(state=state, action=action)

    q1_value = self.q1_network(current_state_action).q_value
    min_q_value = q1_value

    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
        min_q_value = torch.min(q1_value, q2_value)

    # Use the minimum as target, ensure no gradient going through
    min_q_value = min_q_value.detach()

    #
    # First, optimize value network; minimizing MSE between
    # V(s) & Q(s, a) - log(pi(a|s))
    #

    state_value = self.value_network(state.float_features)  # .q_value

    if self.logged_action_uniform_prior:
        log_prob_a = torch.zeros_like(min_q_value)
        target_value = min_q_value
    else:
        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(
                state, action.float_features
            )
            log_prob_a = log_prob_a.clamp(-20.0, 20.0)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

    value_loss = F.mse_loss(state_value, target_value)
    self.value_network_optimizer.zero_grad()
    value_loss.backward()
    self.value_network_optimizer.step()

    #
    # Second, optimize Q networks; minimizing MSE between
    # Q(s, a) & r + discount * V'(next_s)
    #

    with torch.no_grad():
        next_state_value = (
            self.value_network_target(learning_input.next_state.float_features)
            * not_done_mask
        )

        if self.minibatch < self.reward_burnin:
            target_q_value = reward
        else:
            target_q_value = reward + discount * next_state_value

    q1_loss = F.mse_loss(q1_value, target_q_value)
    self.q1_network_optimizer.zero_grad()
    q1_loss.backward()
    self.q1_network_optimizer.step()
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        self.q2_network_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_network_optimizer.step()

    #
    # Lastly, optimize the actor; minimizing KL-divergence between action propensity
    # & softmax of value. Due to reparameterization trick, it ends up being
    # log_prob(actor_action) - Q(s, actor_action)
    #

    actor_output = self.actor_network(rlt.StateInput(state=state))

    state_actor_action = rlt.StateAction(
        state=state, action=rlt.FeatureVector(float_features=actor_output.action)
    )
    q1_actor_value = self.q1_network(state_actor_action).q_value
    min_q_actor_value = q1_actor_value
    if self.q2_network:
        q2_actor_value = self.q2_network(state_actor_action).q_value
        min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

    actor_loss = self.entropy_temperature * actor_output.log_prob - min_q_actor_value
    # Do this in 2 steps so we can log histogram of actor loss
    actor_loss_mean = actor_loss.mean()
    self.actor_network_optimizer.zero_grad()
    actor_loss_mean.backward()
    self.actor_network_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.value_network, self.value_network_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        SummaryWriterContext.add_histogram("value_network/target", target_value)
        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram(
            "actor/min_q_actor_value", min_q_actor_value
        )
        SummaryWriterContext.add_histogram(
            "actor/action_log_prob", actor_output.log_prob
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    if evaluator is not None:
        cpe_stats = BatchStatsForCPE(
            td_loss=q1_loss.detach().cpu().numpy(),
            logged_rewards=reward.detach().cpu().numpy(),
            model_values_on_logged_actions=q1_value.detach().cpu().numpy(),
            model_propensities=actor_output.log_prob.exp().detach().cpu().numpy(),
            model_values=min_q_actor_value.detach().cpu().numpy(),
        )
        evaluator.report(cpe_stats)
def train(self, training_samples: TrainingDataPage) -> None:
    if self.minibatch == 0:
        # Assume that the tensors are the right shape after the first minibatch
        assert (
            training_samples.states.shape[0] == self.minibatch_size
        ), "Invalid shape: " + str(training_samples.states.shape)
        assert (
            training_samples.actions.shape[0] == self.minibatch_size
        ), "Invalid shape: " + str(training_samples.actions.shape)
        assert training_samples.rewards.shape == torch.Size(
            [self.minibatch_size, 1]
        ), "Invalid shape: " + str(training_samples.rewards.shape)
        assert (
            training_samples.next_states.shape == training_samples.states.shape
        ), "Invalid shape: " + str(training_samples.next_states.shape)
        assert (
            training_samples.not_terminal.shape == training_samples.rewards.shape
        ), "Invalid shape: " + str(training_samples.not_terminal.shape)
        if self.use_seq_num_diff_as_time_diff:
            assert (
                training_samples.time_diffs.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.time_diffs.shape)

    self.minibatch += 1

    states = training_samples.states.detach().requires_grad_(True)
    actions = training_samples.actions.detach().requires_grad_(True)

    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    actions = rescale_torch_tensor(
        actions,
        new_min=self.min_action_range_tensor_training,
        new_max=self.max_action_range_tensor_training,
        prev_min=self.min_action_range_tensor_serving,
        prev_max=self.max_action_range_tensor_serving,
    )
    rewards = training_samples.rewards
    next_states = training_samples.next_states
    time_diffs = training_samples.time_diffs
    discount_tensor = torch.tensor(np.full(rewards.shape, self.gamma)).type(self.dtype)
    not_done_mask = training_samples.not_terminal

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic.forward([states, actions])
    next_actions = self.actor_target(next_states)

    q_s2_a2 = self.critic_target.forward([next_states, next_actions])
    filtered_q_s2_a2 = not_done_mask * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    if self.minibatch < self.reward_burnin:
        target_q_values = rewards
    else:
        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    loss_actor = -(self.critic.forward([states.detach(), self.actor(states)]).mean())

    # Zero out both the actor and critic gradients because we need
    # to backprop through the critic to get to the actor
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.actor, self.actor_target, 1.0)
        self._soft_update(self.critic, self.critic_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

    self.loss_reporter.report(
        td_loss=float(loss_critic_for_eval),
        reward_loss=None,
        model_values_on_logged_actions=critic_predictions,
    )
def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    with torch.enable_grad():
        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_output = self.actor_network(rlt.StateInput(state=state))

        # Optimize Alpha
        if self.alpha_optimizer is not None:
            alpha_loss = -(
                self.log_alpha
                * (actor_output.log_prob + self.target_entropy).detach()
            ).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.entropy_temperature = self.log_alpha.exp()

        with torch.no_grad():
            if self.value_network is not None:
                next_state_value = self.value_network_target(
                    learning_input.next_state.float_features
                )
            else:
                next_state_actor_output = self.actor_network(
                    rlt.StateInput(state=learning_input.next_state)
                )
                next_state_actor_action = rlt.StateAction(
                    state=learning_input.next_state,
                    action=rlt.FeatureVector(
                        float_features=next_state_actor_output.action
                    ),
                )
                next_state_value = self.q1_network_target(
                    next_state_actor_action
                ).q_value

                if self.q2_network is not None:
                    target_q2_value = self.q2_network_target(
                        next_state_actor_action
                    ).q_value
                    next_state_value = torch.min(next_state_value, target_q2_value)

                log_prob_a = self.actor_network.get_log_prob(
                    learning_input.next_state, next_state_actor_output.action
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                next_state_value -= self.entropy_temperature * log_prob_a

            target_q_value = (
                reward + discount * next_state_value * not_done_mask.float()
            )

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        #
        # Second, optimize the actor; minimizing KL-divergence between action
        # propensity & softmax of value. Due to reparameterization trick, it ends
        # up being log_prob(actor_action) - Q(s, actor_action)
        #

        state_actor_action = rlt.StateAction(
            state=state,
            action=rlt.FeatureVector(float_features=actor_output.action),
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        actor_loss_mean.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        #
        # Lastly, if applicable, optimize value network; minimizing MSE between
        # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
        #

        if self.value_network is not None:
            state_value = self.value_network(state.float_features)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                with torch.no_grad():
                    log_prob_a = actor_output.log_prob
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    target_value = (
                        min_q_actor_value - self.entropy_temperature * log_prob_a
                    )

            value_loss = F.mse_loss(state_value, target_value.detach())
            value_loss.backward()
            self._maybe_run_optimizer(
                self.value_network_optimizer, self.minibatches_per_step
            )

    # Use the soft update rule to update the target networks
    if self.value_network is not None:
        self._maybe_soft_update(
            self.value_network,
            self.value_network_target,
            self.tau,
            self.minibatches_per_step,
        )
    else:
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        if self.value_network:
            SummaryWriterContext.add_histogram("value_network/target", target_value)
        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram(
            "actor/min_q_actor_value", min_q_actor_value
        )
        SummaryWriterContext.add_histogram(
            "actor/action_log_prob", actor_output.log_prob
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
        model_propensities=actor_output.log_prob.exp(),
        model_values=min_q_actor_value,
    )
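# _maybe_run_optimizer and _maybe_soft_update, used by the trainer above, are not
# defined in this section. A hypothetical sketch of the gradient-accumulation
# behaviour they presumably implement (stepping and syncing targets only every
# minibatches_per_step minibatches), written as methods of the same trainer class;
# the signatures and bodies are assumptions, not the library's actual code.
def _maybe_run_optimizer(self, optimizer, minibatches_per_step):
    # Step (then clear) the optimizer only when enough gradients have accumulated.
    if self.minibatch % minibatches_per_step == 0:
        optimizer.step()
        optimizer.zero_grad()


def _maybe_soft_update(self, source, target, tau, minibatches_per_step):
    # Apply the soft target update only on minibatches where the optimizers stepped.
    if self.minibatch % minibatches_per_step == 0:
        self._soft_update(source, target, tau)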
def train(self, training_samples: TrainingDataPage, evaluator=None) -> None:
    if self.minibatch == 0:
        # Assume that the tensors are the right shape after the first minibatch
        assert (
            training_samples.states.shape[0] == self.minibatch_size
        ), "Invalid shape: " + str(training_samples.states.shape)
        assert (
            training_samples.actions.shape[0] == self.minibatch_size
        ), "Invalid shape: " + str(training_samples.actions.shape)
        assert training_samples.rewards.shape == torch.Size(
            [self.minibatch_size, 1]
        ), "Invalid shape: " + str(training_samples.rewards.shape)
        assert (
            training_samples.next_states.shape == training_samples.states.shape
        ), "Invalid shape: " + str(training_samples.next_states.shape)
        assert (
            training_samples.not_terminals.shape == training_samples.rewards.shape
        ), "Invalid shape: " + str(training_samples.not_terminals.shape)
        if self.use_seq_num_diff_as_time_diff:
            assert (
                training_samples.time_diffs.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.time_diffs.shape)

    self.minibatch += 1

    states = training_samples.states.detach().requires_grad_(True)
    actions = training_samples.actions.detach().requires_grad_(True)

    # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
    actions = rescale_torch_tensor(
        actions,
        new_min=self.min_action_range_tensor_training,
        new_max=self.max_action_range_tensor_training,
        prev_min=self.min_action_range_tensor_serving,
        prev_max=self.max_action_range_tensor_serving,
    )
    rewards = training_samples.rewards
    next_states = training_samples.next_states
    time_diffs = training_samples.time_diffs
    discount_tensor = torch.tensor(np.full(rewards.shape, self.gamma)).type(self.dtype)
    not_done_mask = training_samples.not_terminals

    # Optimize the critic network subject to mean squared error:
    # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
    q_s1_a1 = self.critic(torch.cat((states, actions), dim=1))
    next_actions = self.actor_target(next_states)

    next_state_actions = torch.cat((next_states, next_actions), dim=1)
    q_s2_a2 = self.critic_target(next_state_actions)
    filtered_q_s2_a2 = not_done_mask * q_s2_a2

    if self.use_seq_num_diff_as_time_diff:
        discount_tensor = discount_tensor.pow(time_diffs)

    if self.minibatch < self.reward_burnin:
        target_q_values = rewards
    else:
        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

    # compute loss and update the critic network
    critic_predictions = q_s1_a1
    loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
    loss_critic_for_eval = loss_critic.detach()
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    self.critic_optimizer.step()

    # Optimize the actor network subject to the following:
    # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
    loss_actor = -self.critic(torch.cat((states, self.actor(states)), dim=1)).mean()
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    self.actor_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.actor, self.actor_target, 1.0)
        self._soft_update(self.critic, self.critic_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

    if evaluator is not None:
        cpe_stats = BatchStatsForCPE(td_loss=loss_critic_for_eval.cpu().numpy())
        evaluator.report(cpe_stats)
def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match the
    range of the output of the actor.
    """
    if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
        training_batch = training_batch.as_parametric_sarsa_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    current_state_action = rlt.StateAction(state=state, action=action)

    q1_value = self.q1_network(current_state_action).q_value
    min_q_value = q1_value

    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
        min_q_value = torch.min(q1_value, q2_value)

    # Use the minimum as target, ensure no gradient going through
    min_q_value = min_q_value.detach()

    #
    # First, optimize value network; minimizing MSE between
    # V(s) & Q(s, a) - log(pi(a|s))
    #

    state_value = self.value_network(state.float_features)  # .q_value

    if self.logged_action_uniform_prior:
        log_prob_a = torch.zeros_like(min_q_value)
        target_value = min_q_value
    else:
        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(state, action.float_features)
            log_prob_a = log_prob_a.clamp(-20.0, 20.0)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

    value_loss = F.mse_loss(state_value, target_value)
    self.value_network_optimizer.zero_grad()
    value_loss.backward()
    self.value_network_optimizer.step()

    #
    # Second, optimize Q networks; minimizing MSE between
    # Q(s, a) & r + discount * V'(next_s)
    #

    with torch.no_grad():
        next_state_value = (
            self.value_network_target(learning_input.next_state.float_features)
            * not_done_mask.float()
        )

        if self.minibatch < self.reward_burnin:
            target_q_value = reward
        else:
            target_q_value = reward + discount * next_state_value

    q1_loss = F.mse_loss(q1_value, target_q_value)
    self.q1_network_optimizer.zero_grad()
    q1_loss.backward()
    self.q1_network_optimizer.step()
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        self.q2_network_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_network_optimizer.step()

    #
    # Lastly, optimize the actor; minimizing KL-divergence between action propensity
    # & softmax of value. Due to reparameterization trick, it ends up being
    # log_prob(actor_action) - Q(s, actor_action)
    #

    actor_output = self.actor_network(rlt.StateInput(state=state))

    state_actor_action = rlt.StateAction(
        state=state, action=rlt.FeatureVector(float_features=actor_output.action)
    )
    q1_actor_value = self.q1_network(state_actor_action).q_value
    min_q_actor_value = q1_actor_value
    if self.q2_network:
        q2_actor_value = self.q2_network(state_actor_action).q_value
        min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

    actor_loss = self.entropy_temperature * actor_output.log_prob - min_q_actor_value
    # Do this in 2 steps so we can log histogram of actor loss
    actor_loss_mean = actor_loss.mean()
    self.actor_network_optimizer.zero_grad()
    actor_loss_mean.backward()
    self.actor_network_optimizer.step()

    if self.minibatch < self.reward_burnin:
        # Reward burnin: force target network
        self._soft_update(self.value_network, self.value_network_target, 1.0)
    else:
        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        SummaryWriterContext.add_histogram("value_network/target", target_value)
        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram("q_network/target_q_value", target_q_value)
        SummaryWriterContext.add_histogram(
            "actor/min_q_actor_value", min_q_actor_value
        )
        SummaryWriterContext.add_histogram(
            "actor/action_log_prob", actor_output.log_prob
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
        model_propensities=actor_output.log_prob.exp(),
        model_values=min_q_actor_value,
    )