Example #1
    def test_masked_softmax(self):
        # Positive value case
        x = torch.tensor([[15.0, 6.0, 9.0], [3.0, 2.0, 1.0]])
        temperature = 1
        mask = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.9975, 0.0000, 0.0025],
                                     [0.0000, 0.7311, 0.2689]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # Positive value case (masked value goes to inf)
        x = torch.tensor([[150.0, 2.0]])
        temperature = 0.01
        mask = torch.tensor([[0.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 1.0]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # Negative value case
        x = torch.tensor([[-10.0, -1.0, -5.0]])
        temperature = 0.01
        mask = torch.tensor([[1.0, 1.0, 0.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 1.0, 0.0]])
        npt.assert_array_almost_equal(out, expected_out, 4)

        # All values in a row are masked case
        x = torch.tensor([[-5.0, 4.0, 3.0], [2.0, 1.0, 2.0]])
        temperature = 1
        mask = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
        out = masked_softmax(x, mask, temperature)
        expected_out = torch.tensor([[0.0, 0.0, 0.0], [0.4223, 0.1554,
                                                       0.4223]])
        npt.assert_array_almost_equal(out, expected_out, 4)
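
The tests above pin down the expected contract of masked_softmax: scores are divided by the temperature before the softmax, masked entries receive zero probability, and a fully masked row comes back as all zeros. A minimal sketch consistent with those expectations (not necessarily the library's actual implementation, only the signature from the test is taken as given):

    import torch

    def masked_softmax(x, mask, temperature):
        # Scale by the temperature, then push masked entries to -inf so they
        # contribute zero probability after the softmax.
        scores = (x / temperature).masked_fill(mask == 0.0, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        # A fully masked row softmaxes to NaN (all -inf inputs); map it back
        # to zeros, matching the last test case above.
        return torch.nan_to_num(probs, nan=0.0)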
Example #2
    def _calculate_cpes(
        self,
        training_batch,
        states,
        next_states,
        all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return
        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.reward, training_batch.extras.metrics), dim=1)

        model_propensities_next_states = masked_softmax(
            all_next_action_scores,
            training_batch.possible_next_actions_mask
            if self.maxq_learning else training_batch.next_action,
            self.rl_temperature,
        )

        ######### Train separate reward network for CPE evaluation #############
        reward_estimates = self.reward_network(states)
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                 metrics_reward_concat_real_vals)
        yield reward_loss

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        all_metrics_target_q_values = torch.chunk(
            self.q_network_cpe_target(next_states).detach(),
            len(self.metrics_to_score),
            dim=1,
        )
        target_metric_q_values = []
        for i, per_metric_target_q_values in enumerate(
                all_metrics_target_q_values):
            per_metric_next_q_values = torch.sum(
                per_metric_target_q_values * model_propensities_next_states,
                1,
                keepdim=True,
            )
            per_metric_next_q_values = per_metric_next_q_values * not_done_mask
            per_metric_target_q_values = metrics_reward_concat_real_vals[:, i:i + 1] + (
                discount_tensor * per_metric_next_q_values)
            target_metric_q_values.append(per_metric_target_q_values)

        target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
        metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                  target_metric_q_values)

        # The model_propensities computed below are not used right now. The CPE graphs in the Outputs
        # tab use model_propensities computed in the function create_from_tensors_dqn() in evaluation_data_page.py,
        # which is called on the eval_table_sample in the gather_eval_data() function below.
        model_propensities = masked_softmax(
            all_action_scores,
            training_batch.possible_actions_mask
            if self.maxq_learning else training_batch.action,
            self.rl_temperature,
        )
        # Extract rewards predicted by the reward_network. The other columns will
        # give predicted values for other metrics, if such were specified.
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]

        self.reporter.log(
            reward_loss=reward_loss,
            model_propensities=model_propensities,
            model_rewards=model_rewards,
        )

        yield metric_q_value_loss
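
The gather calls in this example assume that the reward and CPE Q-network outputs are laid out as contiguous per-metric blocks of num_actions columns, with reward_idx_offsets holding each block's starting column. A small illustration of that indexing, with shapes and values made up for the example:

    import torch

    num_actions = 3
    # Two metrics laid out as [m0_a0, m0_a1, m0_a2, m1_a0, m1_a1, m1_a2].
    reward_estimates = torch.arange(12.0).reshape(2, 6)
    reward_idx_offsets = torch.tensor([0, num_actions])  # block starts: [0, 3]
    logged_action_idxs = torch.tensor([[2], [0]])        # logged action per row
    # Broadcasting yields one column index per (row, metric) pair.
    picked = reward_estimates.gather(1, reward_idx_offsets + logged_action_idxs)
    # picked == tensor([[2., 5.], [6., 9.]])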
Example #3
    def _calculate_cpes(
        self,
        training_batch,
        states,
        next_states,
        all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.reward, training_batch.extras.metrics), dim=1)

        model_propensities_next_states = masked_softmax(
            all_next_action_scores,
            training_batch.possible_next_actions_mask
            if self.maxq_learning else training_batch.next_action,
            self.rl_temperature,
        )

        with torch.enable_grad():
            ######### Train separate reward network for CPE evaluation #############
            reward_estimates = self.reward_network(states)
            reward_estimates_for_logged_actions = reward_estimates.gather(
                1, self.reward_idx_offsets + logged_action_idxs)
            reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                     metrics_reward_concat_real_vals)
            reward_loss.backward()
            self._maybe_run_optimizer(self.reward_network_optimizer,
                                      self.minibatches_per_step)

            ######### Train separate q-network for CPE evaluation #############
            metric_q_values = self.q_network_cpe(states).gather(
                1, self.reward_idx_offsets + logged_action_idxs)
            all_metrics_target_q_values = torch.chunk(
                self.q_network_cpe_target(next_states).detach(),
                len(self.metrics_to_score),
                dim=1,
            )
            target_metric_q_values = []
            for i, per_metric_target_q_values in enumerate(
                    all_metrics_target_q_values):
                per_metric_next_q_values = torch.sum(
                    per_metric_target_q_values *
                    model_propensities_next_states,
                    1,
                    keepdim=True,
                )
                per_metric_next_q_values = per_metric_next_q_values * not_done_mask
                per_metric_target_q_values = metrics_reward_concat_real_vals[:, i:i + 1] + (
                    discount_tensor * per_metric_next_q_values)
                target_metric_q_values.append(per_metric_target_q_values)

            target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
            metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                      target_metric_q_values)
            metric_q_value_loss.backward()
            self._maybe_run_optimizer(self.q_network_cpe_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(
            self.q_network_cpe,
            self.q_network_cpe_target,
            self.tau,
            self.minibatches_per_step,
        )

        model_propensities = masked_softmax(
            all_action_scores,
            training_batch.possible_actions_mask
            if self.maxq_learning else training_batch.action,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities
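
Example #3 ends by soft-updating the CPE target network via _maybe_soft_update, whose body is not shown here. A minimal sketch of the conventional Polyak rule, target <- tau * online + (1 - tau) * target, ignoring the minibatches_per_step gating (soft_update is a hypothetical helper name):

    import torch

    def soft_update(online_net, target_net, tau):
        # Polyak averaging: target <- tau * online + (1 - tau) * target.
        with torch.no_grad():
            for online_p, target_p in zip(online_net.parameters(),
                                          target_net.parameters()):
                target_p.mul_(1.0 - tau).add_(tau * online_p)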
Example #4
    def create_from_tensors_dqn(
        cls,
        trainer: DQNTrainer,
        mdp_ids: torch.Tensor,
        sequence_numbers: torch.Tensor,
        states: rlt.FeatureData,
        actions: rlt.FeatureData,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        # pyre-fixme[16]: `DQNTrainer` has no attribute `reward_network`.
        old_reward_train_state = trainer.reward_network.training
        # pyre-fixme[16]: `DQNTrainer` has no attribute `q_network_cpe`.
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)
        trainer.q_network_cpe.train(False)

        num_actions = trainer.num_actions
        action_mask = actions.float()

        rewards = trainer.boost_rewards(rewards, actions)
        model_values = trainer.q_network_cpe(states)[:, 0:num_actions]
        optimal_q_values, _ = trainer.get_detached_q_values(states)
        # Do we ever really use eval_action_idxs?
        eval_action_idxs = trainer.get_max_q_values(optimal_q_values,
                                                    possible_actions_mask)[1]
        model_propensities = masked_softmax(optimal_q_values,
                                            possible_actions_mask,
                                            trainer.rl_temperature)
        assert model_values.shape == actions.shape, ("Invalid shape: " +
                                                     str(model_values.shape) +
                                                     " != " +
                                                     str(actions.shape))
        assert model_values.shape == possible_actions_mask.shape, (
            "Invalid shape: " + str(model_values.shape) + " != " +
            str(possible_actions_mask.shape))
        model_values_for_logged_action = torch.sum(model_values * action_mask,
                                                   dim=1,
                                                   keepdim=True)

        rewards_and_metric_rewards = trainer.reward_network(states)

        # In case we reuse the module for Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (
            "Invalid shape: " + str(model_rewards.shape) + " != " +
            str(actions.shape))
        model_rewards_for_logged_action = torch.sum(model_rewards *
                                                    action_mask,
                                                    dim=1,
                                                    keepdim=True)

        model_metrics = rewards_and_metric_rewards[:, num_actions:]

        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: " + str(model_metrics.shape) + " " +
            str(num_actions))
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(states)
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                "Invalid shape: " + str(model_metrics_values.shape[1]) +
                " != " + str(actions.shape[1] * num_metrics))

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] *
                        action_mask,
                        dim=1,
                        keepdim=True,
                    ))

                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] *
                        action_mask,
                        dim=1,
                        keepdim=True,
                    ))
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1)
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1)

        trainer.q_network_cpe.train(old_q_cpe_train_state)
        trainer.q_network.train(old_q_train_state)
        trainer.reward_network.train(old_reward_train_state)

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
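
Several quantities above (model_values_for_logged_action, model_rewards_for_logged_action, and the per-metric sums) are all computed the same way: multiply by the one-hot action_mask and sum over the action dimension. A tiny illustration with made-up numbers:

    import torch

    model_values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    action_mask = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # one-hot logged actions
    selected = torch.sum(model_values * action_mask, dim=1, keepdim=True)
    # selected == tensor([[2.], [4.]])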
Example #5
    def create_from_tensors_parametric_dqn(
        cls,
        trainer: ParametricDQNTrainer,
        mdp_ids: torch.Tensor,
        sequence_numbers: torch.Tensor,
        states: rlt.FeatureData,
        actions: rlt.FeatureData,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: rlt.FeatureData,
        max_num_actions: int,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
            -1, states.float_features.shape[1])
        assert possible_actions is not None
        # Get Q-value of action taken
        possible_actions_state_concat = (rlt.FeatureData(tiled_state),
                                         possible_actions)

        # FIXME: model_values, model_values_for_logged_action, and model_metrics_values
        # should be calculated using q_network_cpe (as in discrete dqn).
        # q_network_cpe has not been added in parametric dqn yet.
        model_values = trainer.q_network(*possible_actions_state_concat)
        optimal_q_values, _ = trainer.get_detached_q_values(
            *possible_actions_state_concat)
        eval_action_idxs = None

        assert (model_values.shape[1] == 1
                and model_values.shape[0] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_values.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_values = model_values.reshape(possible_actions_mask.shape)
        optimal_q_values = optimal_q_values.reshape(
            possible_actions_mask.shape)
        model_propensities = masked_softmax(optimal_q_values,
                                            possible_actions_mask,
                                            trainer.rl_temperature)

        rewards_and_metric_rewards = trainer.reward_network(
            *possible_actions_state_concat)
        model_rewards = rewards_and_metric_rewards[:, :1]
        assert (model_rewards.shape[0] *
                model_rewards.shape[1] == possible_actions_mask.shape[0] *
                possible_actions_mask.shape[1]), (
                    "Invalid shapes: " + str(model_rewards.shape) + " != " +
                    str(possible_actions_mask.shape))
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_metrics = rewards_and_metric_rewards[:, 1:]
        model_metrics = model_metrics.reshape(possible_actions_mask.shape[0],
                                              -1)

        model_values_for_logged_action = trainer.q_network(states, actions)
        model_rewards_and_metrics_for_logged_action = trainer.reward_network(
            states, actions)
        model_rewards_for_logged_action = (
            model_rewards_and_metrics_for_logged_action[:, :1])

        action_dim = possible_actions.float_features.shape[1]
        action_mask = torch.all(
            possible_actions.float_features.view(
                -1, max_num_actions,
                action_dim) == actions.float_features.unsqueeze(dim=1),
            dim=2,
        ).float()
        assert torch.all(action_mask.sum(dim=1) == 1)
        num_metrics = model_metrics.shape[1] // max_num_actions

        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
        if num_metrics > 0:
            # FIXME: calculate model_metrics_values when q_network_cpe is added
            # to parametric dqn
            model_metrics_values = model_values.repeat(1, num_metrics)

        trainer.q_network.train(old_q_train_state)
        trainer.reward_network.train(old_reward_train_state)

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
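
The repeat/reshape at the top of Example #5 interleaves copies of each state so that row s * max_num_actions + a pairs state s with its a-th possible action, which is why the model outputs can later be reshaped back to possible_actions_mask.shape. An illustration with made-up values:

    import torch

    states = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    max_num_actions = 2
    tiled_state = states.repeat(1, max_num_actions).reshape(-1, states.shape[1])
    # tiled_state == tensor([[1., 2.], [1., 2.], [3., 4.], [3., 4.]])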