Code Example #1
    def _get_model(self, training_parameters, dueling_architecture=False):
        if dueling_architecture:
            return DuelingArchitectureQNetwork(
                training_parameters.layers,
                training_parameters.activations,
                action_dim=self.num_action_features,
            )
        elif training_parameters.factorization_parameters is None:
            return GenericFeedForwardNetwork(training_parameters.layers,
                                             training_parameters.activations)
        else:
            return ParametricInnerProduct(
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.activations,
                ),
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )
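
The else branch above factorizes the Q-function into a state network and an action network combined by ParametricInnerProduct. That class is not shown in these examples; judging from how it is constructed, it presumably computes something like Q(s, a) = <f(s), g(a)>. A minimal sketch of that idea in plain PyTorch; the class name InnerProductQ and its signature are illustrative, not the library's actual API:

import torch
import torch.nn as nn


class InnerProductQ(nn.Module):
    """Hypothetical stand-in for ParametricInnerProduct: Q(s, a) = <f(s), g(a)>."""

    def __init__(self, state_net: nn.Module, action_net: nn.Module):
        super().__init__()
        self.state_net = state_net    # embeds the state features
        self.action_net = action_net  # embeds the action features

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # Both embeddings must end in the same dimension for the dot product.
        state_emb = self.state_net(state)
        action_emb = self.action_net(action)
        # Row-wise inner product -> one scalar Q-value per (state, action) pair.
        return (state_emb * action_emb).sum(dim=1, keepdim=True)

In the constructor above, GenericFeedForwardNetwork plays the role of state_net and action_net, with their input sizes taken from num_state_features and num_action_features.
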
Code Example #2
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        self.reward_shape = {}  # type: Dict[int, float]
        if parameters.rl.reward_boost is not None and self._actions is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_shape[i] = parameters.rl.reward_boost[k]

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[Dict[
                int, NormalizationParameters]] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingArchitectureQNetwork(
                parameters.training.layers, parameters.training.activations)
        else:
            self.q_network = GenericFeedForwardNetwork(
                parameters.training.layers, parameters.training.activations)
        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()
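
The reward_boost block above converts a mapping keyed by action name into self.reward_shape, keyed by action index, which the trainer later uses to boost logged rewards. A standalone sketch of that translation and of how the boost is applied to one-hot logged actions; the action names and boost values are made up:

import numpy as np

# Hypothetical configuration, normally carried in `parameters`.
actions = ["up", "down", "noop"]
reward_boost = {"down": 0.5}  # keyed by action name

# Same translation as in __init__: action name -> action index.
reward_shape = {actions.index(k): v for k, v in reward_boost.items()}

# Logged actions arrive one-hot encoded, one row per transition.
logged_actions = np.array([[0, 1, 0], [1, 0, 0]], dtype=np.float32)
rewards = np.array([1.0, 1.0], dtype=np.float32)

# Boost each reward according to the action that was actually taken
# (actions without a boost default to 0 in this sketch).
taken = np.argmax(logged_actions, axis=1)
boosted_rewards = rewards + np.array([reward_shape.get(i, 0.0) for i in taken])
# boosted_rewards -> [1.5, 1.0]
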
Code Example #3
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
    ) -> None:

        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_features = get_num_output_features(
            state_normalization_parameters) + get_num_output_features(
                action_normalization_parameters)

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        parameters.training.layers[0] = self.num_features
        parameters.training.layers[-1] = 1

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types)

        self.q_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()
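
Because parameters.training.layers[0] is set to the combined state + action width and layers[-1] to 1, the Q-network built above scores one concatenated (state, action) row and returns a single scalar for it. A small sketch of how such a network is queried for several candidate actions of the same state, with a plain nn.Sequential standing in for GenericFeedForwardNetwork (whose definition is not part of these examples):

import torch
import torch.nn as nn

state_dim, action_dim = 4, 2

# Stand-in for GenericFeedForwardNetwork with layers [state_dim + action_dim, 64, 1].
q_network = nn.Sequential(
    nn.Linear(state_dim + action_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

state = torch.randn(1, state_dim)
candidate_actions = torch.randn(3, action_dim)  # three possible actions

# Tile the state so each candidate action gets its own (state, action) row,
# mirroring the np.concatenate([states, next_actions], 1) calls in the trainer.
pairs = torch.cat([state.expand(3, -1), candidate_actions], dim=1)
q_values = q_network(pairs).squeeze(1)          # shape: (3,)
best_action = candidate_actions[q_values.argmax()]
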
Code Example #4
    def _get_model(self, training_parameters):
        if training_parameters.factorization_parameters is None:
            return GenericFeedForwardNetwork(training_parameters.layers,
                                             training_parameters.activations)
        else:
            return ParametricInnerProduct(
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.activations,
                ),
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )
Code Example #5
class DQNTrainer(RLTrainer):
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        self.reward_shape = {}  # type: Dict[int, float]
        if parameters.rl.reward_boost is not None and self._actions is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_shape[i] = parameters.rl.reward_boost[k]

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[Dict[
                int, NormalizationParameters]] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingArchitectureQNetwork(
                parameters.training.layers, parameters.training.activations)
        else:
            self.q_network = GenericFeedForwardNetwork(
                parameters.training.layers, parameters.training.activations)
        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

    @property
    def num_actions(self) -> int:
        return len(self._actions)

    def calculate_q_values(self, states):
        is_numpy = False
        if isinstance(states, np.ndarray):
            is_numpy = True
            states = torch.tensor(states).type(self.dtype)
        result = self.q_network(states).detach()
        if is_numpy:
            return result.cpu().numpy()
        else:
            return result

    def get_max_q_values(self, states, possible_actions, double_q_learning):
        """
        Used in Q-learning update.
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param possible_actions: Numpy array with shape (batch_size, action_dim).
            possible_actions[i][j] = 1 iff the agent can take action j from
            state i.
        :param double_q_learning: bool to use double q-learning
        """
        if double_q_learning:
            q_values = self.q_network(states).detach()
            q_values_target = self.q_network_target(states).detach()
            # Set q-values of impossible actions to a very large negative number.
            inverse_pna = 1 - possible_actions
            impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
            q_values += impossible_action_penalty
            # Select max_q action after scoring with online network
            max_q_values, max_indices = torch.max(q_values, 1)
            # Use q_values from target network for max_q action from online q_network
            # to decouple selection & scoring, preventing overestimation of q-values
            q_values = torch.gather(q_values_target, 1,
                                    max_indices.unsqueeze(1))
            return Variable(q_values.squeeze())
        else:
            q_values = self.q_network_target(states).detach()
            # Set q-values of impossible actions to a very large negative number.
            inverse_pna = 1 - possible_actions
            impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
            q_values += impossible_action_penalty
            return Variable(torch.max(q_values, 1)[0])

    def get_next_action_q_values(self, states, next_actions):
        """
        Used in SARSA update.
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Numpy array with shape (batch_size, action_dim).
        """
        q_values = self.q_network_target(states).detach()
        return Variable(torch.sum(q_values * next_actions, 1))

    def train(self,
              training_samples: TrainingDataPage,
              evaluator: Optional[Evaluator] = None) -> None:

        # Apply reward boost if specified
        boosted_rewards = np.copy(training_samples.rewards)
        if len(self.reward_shape) > 0:
            boost_idxs = np.argmax(training_samples.actions, 1)
            boosts = np.array([self.reward_shape[x] for x in boost_idxs])
            boosted_rewards += boosts

        self.minibatch += 1
        if isinstance(training_samples.states, torch.Tensor):
            states = training_samples.states.type(self.dtype)
        else:
            states = torch.from_numpy(training_samples.states).type(self.dtype)
        states = Variable(states)
        actions = Variable(
            torch.from_numpy(training_samples.actions).type(self.dtype))
        rewards = Variable(torch.from_numpy(boosted_rewards).type(self.dtype))
        if isinstance(training_samples.next_states, torch.Tensor):
            next_states = training_samples.next_states.type(self.dtype)
        else:
            next_states = torch.from_numpy(training_samples.next_states).type(
                self.dtype)
        next_states = Variable(next_states)
        time_diffs = torch.tensor(training_samples.time_diffs).type(self.dtype)
        discount_tensor = torch.tensor(np.full(len(rewards),
                                               self.gamma)).type(self.dtype)
        not_done_mask = Variable(
            torch.from_numpy(training_samples.not_terminals.astype(int))).type(
                self.dtype)

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions = Variable(
                torch.from_numpy(training_samples.possible_next_actions).type(
                    self.dtype))
            next_q_values = self.get_max_q_values(next_states,
                                                  possible_next_actions,
                                                  self.double_q_learning)
        else:
            # SARSA
            next_actions = Variable(
                torch.from_numpy(training_samples.next_actions).type(
                    self.dtype))
            next_q_values = self.get_next_action_q_values(
                next_states, next_actions)

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = deepcopy(all_q_values.detach())
        q_values = torch.sum(all_q_values * actions, 1)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        self.q_network_optimizer.step()

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            # Reward burn-in: force the target network to match the online network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = (self.reward_network(states).gather(
            1,
            actions.argmax(1).unsqueeze(1)).squeeze())
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        if evaluator is not None:
            self.evaluate(
                evaluator,
                training_samples.actions,
                training_samples.propensities,
                np.expand_dims(boosted_rewards, axis=1),
                training_samples.episode_values,
            )

    def evaluate(
        self,
        evaluator: Evaluator,
        logged_actions: Optional[np.ndarray],
        logged_propensities: Optional[np.ndarray],
        logged_rewards: Optional[np.ndarray],
        logged_values: Optional[np.ndarray],
    ):
        self.model_propensities, model_values_on_logged_actions, maxq_action_idxs = (
            None,
            None,
            None,
        )
        if self.all_action_scores is not None:
            self.all_action_scores = self.all_action_scores.cpu().numpy()
            self.model_propensities = Evaluator.softmax(
                self.all_action_scores, self.rl_temperature)
            maxq_action_idxs = self.all_action_scores.argmax(axis=1)
            if logged_actions is not None:
                model_values_on_logged_actions = np.sum(
                    (logged_actions * self.all_action_scores),
                    axis=1,
                    keepdims=True)

        evaluator.report(
            self.loss.cpu().numpy(),
            logged_actions,
            logged_propensities,
            logged_rewards,
            logged_values,
            self.model_propensities,
            self.all_action_scores,
            model_values_on_logged_actions,
            maxq_action_idxs,
        )

    def predictor(self) -> DQNPredictor:
        """Builds a DQNPredictor."""
        return DQNPredictor.export(
            self,
            self._actions,
            self.state_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )
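
get_max_q_values above is where double Q-learning happens: the online q_network picks the argmax action, and q_network_target supplies the value that goes into the TD target, counteracting the overestimation that comes from taking a max over a single noisy estimator. A self-contained sketch of just that selection/evaluation split, with small linear layers standing in for the two networks and an illustrative masking constant:

import torch
import torch.nn as nn

torch.manual_seed(0)
state_dim, num_actions = 4, 3
ACTION_NOT_POSSIBLE_VAL = -1e9  # illustrative value

q_network = nn.Linear(state_dim, num_actions)         # online network
q_network_target = nn.Linear(state_dim, num_actions)  # target network

states = torch.randn(5, state_dim)
possible_actions = torch.ones(5, num_actions)          # 1 = action allowed

with torch.no_grad():
    q_online = q_network(states)
    q_target = q_network_target(states)

# Penalize impossible actions so they never win the argmax.
q_online = q_online + (1 - possible_actions) * ACTION_NOT_POSSIBLE_VAL

# Double Q-learning: select the action with the online network ...
max_indices = q_online.argmax(dim=1)
# ... but take its value from the target network.
next_q_values = q_target.gather(1, max_indices.unsqueeze(1)).squeeze(1)
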
Code Example #6
class ParametricDQNTrainer(RLTrainer):
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
    ) -> None:

        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_features = get_num_output_features(
            state_normalization_parameters) + get_num_output_features(
                action_normalization_parameters)

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        parameters.training.layers[0] = self.num_features
        parameters.training.layers[-1] = 1

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types)

        self.q_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            parameters.training.layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

    def get_max_q_values(self, next_state_pnas_concat,
                         possible_actions_lengths):
        """
        :param next_state_pnas_concat: Numpy array with shape
            (sum(possible_actions_lengths), state_dim + action_dim). Each row
            contains a representation of a state + possible next action pair.
        :param possible_actions_lengths: Numpy array giving the number of
            possible actions per item in the minibatch
        """
        q_network_input = torch.from_numpy(next_state_pnas_concat).type(
            self.dtype)
        q_values = self.q_network_target(q_network_input).detach()

        pnas_lens = torch.from_numpy(possible_actions_lengths).type(self.dtype)
        pna_len_cumsum = pnas_lens.cumsum(0)
        zero_first_cumsum = torch.cat((torch.zeros(1).type(self.dtype),
                                       pna_len_cumsum)).type(self.dtypelong)
        idxs = torch.arange(0, q_values.size(0)).type(self.dtypelong)

        # Hacky way to do Caffe2's LengthsMax in PyTorch.
        bag = WeightedEmbeddingBag(q_values.unsqueeze(1), mode="max")
        max_q_values = bag(idxs, zero_first_cumsum).squeeze().detach()

        # EmbeddingBag adds a 0 entry to the end of the tensor, so slice it off
        return Variable(max_q_values[:-1])

    def get_next_action_q_values(self, states, next_actions):
        """
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Numpy array with shape (batch_size, action_dim). Each row
            contains a representation of an action.
        """
        q_network_input = np.concatenate([states, next_actions], 1)
        q_network_input = torch.from_numpy(q_network_input).type(self.dtype)
        return Variable(
            self.q_network_target(q_network_input).detach().squeeze())

    def train(self,
              training_samples: TrainingDataPage,
              evaluator=None,
              episode_values=None) -> None:

        self.minibatch += 1
        states = Variable(
            torch.from_numpy(training_samples.states).type(self.dtype))
        actions = Variable(
            torch.from_numpy(training_samples.actions).type(self.dtype))
        state_action_pairs = torch.cat((states, actions), dim=1)
        rewards = Variable(
            torch.from_numpy(training_samples.rewards).type(self.dtype))
        time_diffs = torch.tensor(training_samples.time_diffs).type(self.dtype)
        discount_tensor = torch.tensor(np.full(len(rewards),
                                               self.gamma)).type(self.dtype)
        not_done_mask = Variable(
            torch.from_numpy(training_samples.not_terminals.astype(int))).type(
                self.dtype)

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values = self.get_max_q_values(
                training_samples.next_state_pnas_concat,
                training_samples.possible_next_actions_lengths,
            )
        else:
            # SARSA
            next_q_values = self.get_next_action_q_values(
                training_samples.next_states, training_samples.next_actions)

        filtered_max_q_vals = next_q_values * not_done_mask

        if self.minibatch >= self.reward_burnin:
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)
        else:
            target_q_values = rewards

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        self.all_action_scores = deepcopy(q_values.detach())

        value_loss = F.mse_loss(q_values.squeeze(), target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        self.q_network_optimizer.step()

        if self.minibatch >= self.reward_burnin:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)
        else:
            # Reward burn-in: force the target network to match the online network
            self._soft_update(self.q_network, self.q_network_target, 1.0)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs).squeeze()
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        if evaluator is not None:
            self.evaluate(
                evaluator,
                training_samples.actions,
                training_samples.propensities,
                training_samples.episode_values,
            )

    def evaluate(
        self,
        evaluator: Evaluator,
        logged_actions: Optional[np.ndarray],
        logged_propensities: Optional[np.ndarray],
        logged_values: Optional[np.ndarray],
    ):
        evaluator.report(
            self.loss.cpu().numpy(),
            None,
            None,
            None,
            logged_values,
            None,
            None,
            self.all_action_scores.cpu().numpy(),
            None,
        )

    def predictor(self) -> ParametricDQNPredictor:
        """Builds a ParametricDQNPredictor."""
        return ParametricDQNPredictor.export(
            self,
            self.state_normalization_parameters,
            self.action_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )
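
_soft_update comes from RLTrainer and is not listed in these examples. The way it is called, with self.tau in normal training and 1.0 during reward burn-in to make the target network match the online network exactly, suggests the usual Polyak averaging rule. A minimal sketch under that assumption; the actual helper may differ in detail:

import torch


def soft_update(online_net, target_net, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    with torch.no_grad():
        for online_p, target_p in zip(online_net.parameters(),
                                      target_net.parameters()):
            target_p.mul_(1.0 - tau).add_(tau * online_p)

# With tau = 1.0 the target becomes an exact copy of the online network, which
# is what the burn-in branch relies on; a small tau gives a slowly moving
# target and therefore more stable TD targets.
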
Code Example #7
class ParametricDQNTrainer(RLTrainer):
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu=False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           None)

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(), lr=parameters.training.learning_rate)

        self.reward_network = GenericFeedForwardNetwork(
            reward_network_layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

    def _get_model(self, training_parameters, dueling_architecture=False):
        if dueling_architecture:
            return DuelingArchitectureQNetwork(
                training_parameters.layers,
                training_parameters.activations,
                action_dim=self.num_action_features,
            )
        elif training_parameters.factorization_parameters is None:
            return GenericFeedForwardNetwork(training_parameters.layers,
                                             training_parameters.activations)
        else:
            return ParametricInnerProduct(
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.activations,
                ),
                GenericFeedForwardNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )

    def calculate_q_values(self, state_pas_concats, pas_lens):
        row_nums = np.arange(len(pas_lens))
        row_idxs = np.repeat(row_nums, pas_lens)
        col_idxs = arange_expand(pas_lens)

        dense_idxs = torch.LongTensor(
            (row_idxs, col_idxs)).type(self.dtypelong)
        q_network_input = torch.from_numpy(state_pas_concats).type(self.dtype)

        q_values = self.q_network(q_network_input).detach().squeeze()

        dense_dim = [len(pas_lens), max(pas_lens)]
        # Add a fingerprint offset to the q-values so that, after the sparse ->
        # dense conversion, the zeros introduced as padding can be identified
        # and replaced with ACTION_NOT_POSSIBLE_VAL below
        q_values.add_(self.FINGERPRINT)
        sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values, dense_dim)
        dense_q = sparse_q.to_dense()
        dense_q.add_(self.FINGERPRINT * -1)
        dense_q[dense_q == self.FINGERPRINT *
                -1] = self.ACTION_NOT_POSSIBLE_VAL

        return dense_q.cpu().numpy()

    def get_max_q_values(self, next_state_pnas_concat, pnas_lens,
                         double_q_learning):
        """
        :param next_state_pnas_concat: Numpy array with shape
            (sum(pnas_lens), state_dim + action_dim). Each row
            contains a representation of a state + possible next action pair.
        :param pnas_lens: Numpy array giving the number of possible actions
            per item in the minibatch
        :param double_q_learning: bool to use double q-learning
        """
        row_nums = np.arange(len(pnas_lens))
        row_idxs = np.repeat(row_nums, pnas_lens)
        col_idxs = arange_expand(pnas_lens)

        dense_idxs = torch.LongTensor(
            (row_idxs, col_idxs)).type(self.dtypelong)
        q_network_input = torch.from_numpy(next_state_pnas_concat).type(
            self.dtype)

        if double_q_learning:
            q_values = self.q_network(q_network_input).detach().squeeze()
            q_values_target = self.q_network_target(
                q_network_input).detach().squeeze()
        else:
            q_values = self.q_network_target(
                q_network_input).detach().squeeze()

        dense_dim = [len(pnas_lens), max(pnas_lens)]
        # Add a fingerprint offset to the q-values so that, after the sparse ->
        # dense conversion, the zeros introduced as padding can be identified
        # and replaced with ACTION_NOT_POSSIBLE_VAL below
        q_values.add_(self.FINGERPRINT)
        sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values, dense_dim)
        dense_q = sparse_q.to_dense()
        dense_q.add_(self.FINGERPRINT * -1)
        dense_q[dense_q == self.FINGERPRINT *
                -1] = self.ACTION_NOT_POSSIBLE_VAL
        max_q_values, max_indexes = torch.max(dense_q, dim=1)

        if double_q_learning:
            sparse_q_target = torch.sparse_coo_tensor(dense_idxs,
                                                      q_values_target,
                                                      dense_dim)
            dense_q_values_target = sparse_q_target.to_dense()
            max_q_values = torch.gather(dense_q_values_target, 1,
                                        max_indexes.unsqueeze(1))

        return Variable(max_q_values.squeeze())

    def get_next_action_q_values(self, states, next_actions):
        """
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Numpy array with shape (batch_size, action_dim). Each row
            contains a representation of an action.
        """
        q_network_input = np.concatenate([states, next_actions], 1)
        q_network_input = torch.from_numpy(q_network_input).type(self.dtype)
        return Variable(
            self.q_network_target(q_network_input).detach().squeeze())

    def train(self,
              training_samples: TrainingDataPage,
              evaluator=None,
              episode_values=None) -> None:

        self.minibatch += 1
        states = Variable(
            torch.from_numpy(training_samples.states).type(self.dtype))
        actions = Variable(
            torch.from_numpy(training_samples.actions).type(self.dtype))
        state_action_pairs = torch.cat((states, actions), dim=1)
        rewards = Variable(
            torch.from_numpy(training_samples.rewards).type(self.dtype))
        time_diffs = torch.tensor(training_samples.time_diffs).type(self.dtype)
        discount_tensor = torch.tensor(np.full(len(rewards),
                                               self.gamma)).type(self.dtype)
        not_done_mask = Variable(
            torch.from_numpy(training_samples.not_terminals.astype(int))).type(
                self.dtype)

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values = self.get_max_q_values(
                training_samples.next_state_pnas_concat,
                training_samples.possible_next_actions_lengths,
                self.double_q_learning,
            )
        else:
            # SARSA
            next_q_values = self.get_next_action_q_values(
                training_samples.next_states, training_samples.next_actions)

        filtered_max_q_vals = next_q_values * not_done_mask

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        self.all_action_scores = deepcopy(q_values.detach())

        value_loss = self.q_network_loss(q_values.squeeze(), target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        self.q_network_optimizer.step()

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            # Reward burn-in: force the target network to match the online network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs).squeeze()
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        if evaluator is not None:
            self.evaluate(
                evaluator,
                training_samples.actions,
                training_samples.propensities,
                training_samples.episode_values,
            )

    def evaluate(
        self,
        evaluator: Evaluator,
        logged_actions: Optional[np.ndarray],
        logged_propensities: Optional[np.ndarray],
        logged_values: Optional[np.ndarray],
    ):
        evaluator.report(
            self.loss.cpu().numpy(),
            None,
            None,
            None,
            logged_values,
            None,
            None,
            self.all_action_scores.cpu().numpy(),
            None,
        )

    def predictor(self) -> ParametricDQNPredictor:
        """Builds a ParametricDQNPredictor."""
        return ParametricDQNPredictor.export(
            self,
            self.state_normalization_parameters,
            self.action_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )
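
The fingerprint trick in calculate_q_values and get_max_q_values exists because each state can have a different number of possible actions: the per-pair Q-values are scattered into a dense (batch, max_actions) matrix via a sparse tensor, and the FINGERPRINT offset makes the zeros padded in by to_dense() distinguishable from genuine zero Q-values, so the padding can be replaced with ACTION_NOT_POSSIBLE_VAL before the row-wise max. A standalone sketch of the same idea; the values chosen here for FINGERPRINT and ACTION_NOT_POSSIBLE_VAL are made up, the trainer defines its own:

import numpy as np
import torch

FINGERPRINT = 12345.0            # illustrative constants
ACTION_NOT_POSSIBLE_VAL = -1e9

# Two states with 3 and 1 possible actions respectively.
pnas_lens = np.array([3, 1])
q_values = torch.tensor([0.2, 0.0, -0.5, 1.0])  # one Q-value per (s, a) pair

row_idxs = np.repeat(np.arange(len(pnas_lens)), pnas_lens)    # [0, 0, 0, 1]
col_idxs = np.concatenate([np.arange(n) for n in pnas_lens])  # [0, 1, 2, 0]
dense_idxs = torch.tensor(np.stack([row_idxs, col_idxs]), dtype=torch.long)

# Shift by the fingerprint so the zeros padded in by to_dense() can be told
# apart from real Q-values (including the genuine 0.0 above).
dense_q = torch.sparse_coo_tensor(
    dense_idxs, q_values + FINGERPRINT,
    (len(pnas_lens), int(pnas_lens.max()))).to_dense()
dense_q -= FINGERPRINT
dense_q[dense_q == -FINGERPRINT] = ACTION_NOT_POSSIBLE_VAL

max_q_values, _ = dense_q.max(dim=1)  # tensor([0.2, 1.0])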