Example #1
    def test_rescale_torch_tensor(self):
        rows, cols = 3, 5
        original_tensor = torch.randint(low=10, high=40,
                                        size=(rows, cols)).float()
        prev_max_tensor = torch.ones(1, 5) * 40.0
        prev_min_tensor = torch.ones(1, 5) * 10.0
        new_min_tensor = torch.ones(1, 5) * -1.0
        new_max_tensor = torch.ones(1, 5).float()

        print("Original tensor: ", original_tensor)
        rescaled_tensor = rescale_torch_tensor(
            original_tensor,
            new_min_tensor,
            new_max_tensor,
            prev_min_tensor,
            prev_max_tensor,
        )
        print("Rescaled tensor: ", rescaled_tensor)
        reconstructed_original_tensor = rescale_torch_tensor(
            rescaled_tensor,
            prev_min_tensor,
            prev_max_tensor,
            new_min_tensor,
            new_max_tensor,
        )
        print("Reconstructed Original tensor: ", reconstructed_original_tensor)

        comparison_tensor = torch.eq(original_tensor,
                                     reconstructed_original_tensor)
        # assertTrue would treat the second argument as a failure message; assert equality explicitly
        self.assertEqual(torch.sum(comparison_tensor).item(), rows * cols)
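The rescale_torch_tensor helper under test is not shown on this page. Judging from the round trip above (scale into a new range, scale back, recover the original values), it is presumably a per-element linear min-max rescale; the following is a minimal sketch under that assumption, not the library's actual implementation:

import torch


def rescale_torch_tensor(tensor, new_min, new_max, prev_min, prev_max):
    # Assumed behaviour: map each value from [prev_min, prev_max] to
    # [new_min, new_max]; the min/max tensors broadcast across columns.
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return ((tensor - prev_min) / prev_range) * new_range + new_min

With this definition, rescaling to a new range and back recovers the original values up to floating-point error, which is what the test above checks.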
Example #2
    def internal_prediction(self, states, noisy=False) -> np.ndarray:
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor.eval()
        # TODO: Handle states being sequences
        state_examples = rlt.FeatureVector(
            float_features=torch.from_numpy(np.array(states)).type(self.dtype)
        )
        action = self.actor(rlt.StateAction(state=state_examples, action=None)).action

        self.actor.train()

        action = rescale_torch_tensor(
            action,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        action = action.cpu().data.numpy()
        if noisy:
            action = [x + (self.noise.get_noise()) for x in action]

        return np.array(action, dtype=np.float32)
Example #3
    def internal_prediction(self, states, test=False):
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor_network.eval()
        with torch.no_grad():
            actions = self.actor_network(
                rlt.PreprocessedState.from_tensor(states)).action

        if not test:
            if self.minibatch < self.initial_exploration_ts:
                actions = (torch.rand_like(actions) *
                           (self.max_action_range_tensor_training -
                            self.min_action_range_tensor_training) +
                           self.min_action_range_tensor_training)
            else:
                actions += torch.randn_like(actions) * self.exploration_noise

        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
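The nested torch.max(torch.min(...)) call used above (and in several later examples) is simply an element-wise clamp against per-dimension tensor bounds. A small self-contained illustration:

import torch

actions = torch.tensor([[-2.0, 0.3, 5.0]])
low = torch.tensor([[-1.0, -1.0, -1.0]])
high = torch.tensor([[1.0, 1.0, 1.0]])

# Element-wise minimum against the upper bound, then maximum against the lower bound.
clamped = torch.max(torch.min(actions, high), low)
print(clamped)  # -> [[-1.0, 0.3, 1.0]]

Recent PyTorch releases also accept tensor bounds directly in torch.clamp, which expresses the same operation.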
Example #4
    def internal_prediction(self, states, noisy=False) -> np.ndarray:
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor.eval()
        with torch.no_grad():
            state_examples = Variable(
                torch.from_numpy(np.array(states)).type(self.dtype))
            actions = self.actor(state_examples)

        self.actor.train()

        actions = rescale_torch_tensor(
            actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        actions = actions.cpu().data.numpy()
        if noisy:
            actions = [x + (self.noise.get_noise()) for x in actions]

        return np.array(actions, dtype=np.float32)
Example #5
 def _maybe_scale_action_in_train(self, action):
     if (self.min_action_range_tensor_training is not None
             and self.max_action_range_tensor_training is not None
             and self.min_action_range_tensor_serving is not None
             and self.max_action_range_tensor_serving is not None):
         action = rescale_torch_tensor(
             action,
             new_min=self.min_action_range_tensor_training,
             new_max=self.max_action_range_tensor_training,
             prev_min=self.min_action_range_tensor_serving,
             prev_max=self.max_action_range_tensor_serving,
         )
     return action
Example #6
 def _maybe_scale_action_in_train(self, action):
     if (self.min_action_range_tensor_training is not None
             and self.max_action_range_tensor_training is not None
             and self.min_action_range_tensor_serving is not None
             and self.max_action_range_tensor_serving is not None):
         action = rlt.FeatureVector(
             rescale_torch_tensor(
                 action.float_features,
                 new_min=self.min_action_range_tensor_training,
                 new_max=self.max_action_range_tensor_training,
                 prev_min=self.min_action_range_tensor_serving,
                 prev_max=self.max_action_range_tensor_serving,
             ))
     return action
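Several of the training examples further down call self._should_scale_action_in_train(), which is not shown on this page. A plausible sketch of that helper, hypothetical but consistent with the None checks in the two _maybe_scale_action_in_train variants above, would be a method on the same trainer class:

def _should_scale_action_in_train(self):
    # Hypothetical helper: scale only when every training/serving range tensor is set.
    return (
        self.min_action_range_tensor_training is not None
        and self.max_action_range_tensor_training is not None
        and self.min_action_range_tensor_serving is not None
        and self.max_action_range_tensor_serving is not None
    )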
Example #7
    def internal_prediction(self, states):
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor_network.eval()
        actions = self.actor_network(
            rlt.StateInput(rlt.FeatureVector(float_features=states)))
        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions.action, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
Example #8
    def internal_prediction(self, states, test=False):
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor_network.eval()
        with torch.no_grad():
            actions = self.actor_network(
                rlt.PreprocessedState.from_tensor(states))
        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions.action, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
Example #9
    def internal_prediction(self, states, noisy=False) -> np.ndarray:
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor.eval()
        state_examples = torch.from_numpy(np.array(states)).type(self.dtype)
        actions = self.actor(state_examples)

        self.actor.train()

        actions = rescale_torch_tensor(
            actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        actions = actions.cpu().data.numpy()
        if noisy:
            actions = [x + (self.noise.get_noise()) for x in actions]

        return np.array(actions, dtype=np.float32)
Example #10
    def internal_prediction(self, states):
        """ Returns list of actions output from actor network
        :param states states as list of states to produce actions for
        """
        self.actor_network.eval()
        state_examples = torch.from_numpy(np.array(states)).type(self.dtype)
        actions = self.actor_network(
            rlt.StateInput(rlt.FeatureVector(float_features=state_examples))
        )
        # clamp actions to make sure actions are in the range
        clamped_actions = torch.max(
            torch.min(actions.action, self.max_action_range_tensor_training),
            self.min_action_range_tensor_training,
        )
        rescaled_actions = rescale_torch_tensor(
            clamped_actions,
            new_min=self.min_action_range_tensor_serving,
            new_max=self.max_action_range_tensor_serving,
            prev_min=self.min_action_range_tensor_training,
            prev_max=self.max_action_range_tensor_training,
        )

        self.actor_network.train()
        return rescaled_actions
Example #11
    def train(self,
              training_samples: TrainingDataPage,
              evaluator=None,
              episode_values=None) -> None:

        self.minibatch += 1
        states = Variable(
            torch.from_numpy(training_samples.states).type(self.dtype))
        actions = Variable(
            torch.from_numpy(training_samples.actions).type(self.dtype))
        # As far as DDPG is concerned, all actions lie in [-1, 1] due to the actor's tanh

        actions = rescale_torch_tensor(
            actions,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
        rewards = Variable(
            torch.from_numpy(training_samples.rewards).type(self.dtype))
        next_states = Variable(
            torch.from_numpy(training_samples.next_states).type(self.dtype))
        time_diffs = torch.tensor(training_samples.time_diffs).type(self.dtype)
        discount_tensor = torch.tensor(np.full(len(rewards),
                                               self.gamma)).type(self.dtype)
        not_done_mask = Variable(
            torch.from_numpy(training_samples.not_terminals.astype(int))).type(
                self.dtype)

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic(torch.cat((states, actions), dim=1))
        next_actions = self.actor_target(next_states)

        next_state_actions = torch.cat((next_states, next_actions), dim=1)
        q_s2_a2 = self.critic_target(next_state_actions).detach().squeeze()
        filtered_q_s2_a2 = not_done_mask * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1.squeeze()
        loss_critic = self.q_network_loss(critic_predictions, target_q_values)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max sum(Q(s1, a1)) or min -sum(Q(s1, a1))
        loss_actor = -self.critic(
            torch.cat((states, self.actor(states)), dim=1)).sum()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.actor, self.actor_target, 1.0)
            self._soft_update(self.critic, self.critic_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.actor, self.actor_target, self.tau)
            self._soft_update(self.critic, self.critic_target, self.tau)

        if evaluator is not None:
            evaluator.report(
                loss_critic.cpu().data.numpy(),
                None,
                None,
                None,
                episode_values,
                None,
                None,
                critic_predictions.cpu().data.numpy(),
                None,
            )
Example #12
    def train(self, training_batch: rlt.TrainingBatch) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state

        # As far as DDPG is concerned, all actions lie in [-1, 1] due to the actor's tanh
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                learning_input.action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

        rewards = learning_input.reward
        next_state = learning_input.next_state
        time_diffs = learning_input.time_diff
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic.forward(
            rlt.StateAction(state=state, action=action)
        ).q_value
        next_action = rlt.FeatureVector(
            float_features=self.actor_target(
                rlt.StateAction(state=next_state, action=None)
            ).action
        )

        q_s2_a2 = self.critic_target.forward(
            rlt.StateAction(state=next_state, action=next_action)
        ).q_value
        filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        actor_output = self.actor(rlt.StateAction(state=state, action=None))
        loss_actor = -(
            self.critic.forward(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(float_features=actor_output.action),
                )
            ).q_value.mean()
        )

        # Zero out both the actor and critic gradients because we need
        #   to backprop through the critic to get to the actor
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

        self.loss_reporter.report(
            td_loss=float(loss_critic_for_eval),
            reward_loss=None,
            model_values_on_logged_actions=critic_predictions,
        )
Example #13
    def train(self, training_batch, evaluator=None) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                ))

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        if self.logged_action_uniform_prior:
            log_prob_a = torch.zeros_like(min_q_value)
            target_value = min_q_value
        else:
            with torch.no_grad():
                log_prob_a = self.actor_network.get_log_prob(
                    state, action.float_features)
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (self.value_network_target(
                learning_input.next_state.float_features) * not_done_mask)

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state,
            action=rlt.FeatureVector(float_features=actor_output.action))
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (self.entropy_temperature * actor_output.log_prob -
                      min_q_actor_value)
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target,
                              1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target,
                              self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            SummaryWriterContext.add_histogram("value_network/target",
                                               target_value)
            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(
                td_loss=q1_loss.detach().cpu().numpy(),
                logged_rewards=reward.detach().cpu().numpy(),
                model_values_on_logged_actions=q1_value.detach().cpu().numpy(),
                model_propensities=actor_output.log_prob.exp().detach().cpu().numpy(),
                model_values=min_q_actor_value.detach().cpu().numpy(),
            )
            evaluator.report(cpe_stats)
Example #14
    def train(self, training_samples: TrainingDataPage) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert (
                training_samples.actions.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size, 1]
            ), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminal.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminal.shape)
            if self.use_seq_num_diff_as_time_diff:
                assert (
                    training_samples.time_diffs.shape == training_samples.rewards.shape
                ), "Invalid shape: " + str(training_samples.time_diffs.shape)

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions.detach().requires_grad_(True)

        # As far as DDPG is concerned, all actions lie in [-1, 1] due to the actor's tanh
        actions = rescale_torch_tensor(
            actions,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
        rewards = training_samples.rewards
        next_states = training_samples.next_states
        time_diffs = training_samples.time_diffs
        discount_tensor = torch.tensor(np.full(rewards.shape, self.gamma)).type(
            self.dtype
        )
        not_done_mask = training_samples.not_terminal

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic.forward([states, actions])
        next_actions = self.actor_target(next_states)

        q_s2_a2 = self.critic_target.forward([next_states, next_actions])
        filtered_q_s2_a2 = not_done_mask * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        loss_actor = -(
            self.critic.forward([states.detach(), self.actor(states)]).mean()
        )

        # Zero out both the actor and critic gradients because we need
        #   to backprop through the critic to get to the actor
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.actor, self.actor_target, 1.0)
            self._soft_update(self.critic, self.critic_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.actor, self.actor_target, self.tau)
            self._soft_update(self.critic, self.critic_target, self.tau)

        self.loss_reporter.report(
            td_loss=float(loss_critic_for_eval),
            reward_loss=None,
            model_values_on_logged_actions=critic_predictions,
        )
Example #15
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                ))

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            current_state_action = rlt.StateAction(state=state, action=action)
            q1_value = self.q1_network(current_state_action).q_value
            if self.q2_network:
                q2_value = self.q2_network(current_state_action).q_value
            actor_output = self.actor_network(rlt.StateInput(state=state))

            # Optimize Alpha
            if self.alpha_optimizer is not None:
                alpha_loss = -(self.log_alpha *
                               (actor_output.log_prob +
                                self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.entropy_temperature = self.log_alpha.exp()

            with torch.no_grad():
                if self.value_network is not None:
                    next_state_value = self.value_network_target(
                        learning_input.next_state.float_features)
                else:
                    next_state_actor_output = self.actor_network(
                        rlt.StateInput(state=learning_input.next_state))
                    next_state_actor_action = rlt.StateAction(
                        state=learning_input.next_state,
                        action=rlt.FeatureVector(
                            float_features=next_state_actor_output.action),
                    )
                    next_state_value = self.q1_network_target(
                        next_state_actor_action).q_value

                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            next_state_actor_action).q_value
                        next_state_value = torch.min(next_state_value,
                                                     target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        learning_input.next_state,
                        next_state_actor_output.action)
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    next_state_value -= self.entropy_temperature * log_prob_a

                target_q_value = (
                    reward +
                    discount * next_state_value * not_done_mask.float())

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(self.q1_network_optimizer,
                                      self.minibatches_per_step)
            if self.q2_network:
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(self.q2_network_optimizer,
                                          self.minibatches_per_step)

            #
            # Second, optimize the actor; minimizing KL-divergence between action propensity
            # & softmax of value. Due to reparameterization trick, it ends up being
            # log_prob(actor_action) - Q(s, actor_action)
            #

            state_actor_action = rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
            q1_actor_value = self.q1_network(state_actor_action).q_value
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(state_actor_action).q_value
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (self.entropy_temperature * actor_output.log_prob -
                          min_q_actor_value)
            # Do this in 2 steps so we can log histogram of actor loss
            actor_loss_mean = actor_loss.mean()
            actor_loss_mean.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (min_q_actor_value -
                                        self.entropy_temperature * log_prob_a)

                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(self.value_network_optimizer,
                                          self.minibatches_per_step)

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target",
                                                   target_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
Example #16
    def train(self, training_samples: TrainingDataPage, evaluator=None) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert (
                training_samples.actions.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size, 1]
            ), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminals.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminals.shape)
            if self.use_seq_num_diff_as_time_diff:
                assert (
                    training_samples.time_diffs.shape == training_samples.rewards.shape
                ), "Invalid shape: " + str(training_samples.time_diffs.shape)

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions.detach().requires_grad_(True)

        # As far as DDPG is concerned, all actions lie in [-1, 1] due to the actor's tanh
        actions = rescale_torch_tensor(
            actions,
            new_min=self.min_action_range_tensor_training,
            new_max=self.max_action_range_tensor_training,
            prev_min=self.min_action_range_tensor_serving,
            prev_max=self.max_action_range_tensor_serving,
        )
        rewards = training_samples.rewards
        next_states = training_samples.next_states
        time_diffs = training_samples.time_diffs
        discount_tensor = torch.tensor(np.full(rewards.shape, self.gamma)).type(
            self.dtype
        )
        not_done_mask = training_samples.not_terminals

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic(torch.cat((states, actions), dim=1))
        next_actions = self.actor_target(next_states)

        next_state_actions = torch.cat((next_states, next_actions), dim=1)
        q_s2_a2 = self.critic_target(next_state_actions)
        filtered_q_s2_a2 = not_done_mask * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        loss_actor = -self.critic(torch.cat((states, self.actor(states)), dim=1)).mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.actor, self.actor_target, 1.0)
            self._soft_update(self.critic, self.critic_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.actor, self.actor_target, self.tau)
            self._soft_update(self.critic, self.critic_target, self.tau)

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(td_loss=loss_critic_for_eval.cpu().numpy())
            evaluator.report(cpe_stats)
Example #17
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                )
            )

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        if self.logged_action_uniform_prior:
            log_prob_a = torch.zeros_like(min_q_value)
            target_value = min_q_value
        else:
            with torch.no_grad():
                log_prob_a = self.actor_network.get_log_prob(
                    state, action.float_features
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask.float()
            )

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target, self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq is not None
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            SummaryWriterContext.add_histogram("value_network/target", target_value)
            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            SummaryWriterContext.add_histogram(
                "actor/min_q_actor_value", min_q_actor_value
            )
            SummaryWriterContext.add_histogram(
                "actor/action_log_prob", actor_output.log_prob
            )
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )