Example #1
    def _init_cpe_networks(self, parameters, use_all_avail_gpus):
        if not self.calc_cpe_in_training:
            return

        reward_network_layers = deepcopy(parameters.training.layers)
        if self.metrics_to_score:
            num_output_nodes = len(self.metrics_to_score) * self.num_actions
        else:
            num_output_nodes = self.num_actions

        reward_network_layers[-1] = num_output_nodes
        self.reward_idx_offsets = torch.arange(
            0, num_output_nodes, self.num_actions
        ).type(self.dtypelong)
        logger.info(
            "Reward network for CPE will have {} output nodes.".format(num_output_nodes)
        )

        if parameters.training.cnn_parameters is None:
            self.reward_network = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations
            )
            self.q_network_cpe = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations
            )
        else:
            self.reward_network = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
            self.q_network_cpe = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
        self.q_network_cpe_target = deepcopy(self.q_network_cpe)
        self.q_network_cpe_optimizer = self.optimizer_func(
            self.q_network_cpe.parameters(), lr=parameters.training.learning_rate
        )
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(), lr=parameters.training.learning_rate
        )
        if self.use_gpu:
            self.reward_network.cuda()
            self.q_network_cpe.cuda()
            self.q_network_cpe_target.cuda()
            if use_all_avail_gpus:
                self.reward_network = torch.nn.DataParallel(self.reward_network)
                self.q_network_cpe = torch.nn.DataParallel(self.q_network_cpe)
                self.q_network_cpe_target = torch.nn.DataParallel(
                    self.q_network_cpe_target
                )
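
A minimal sketch (with made-up sizes) of the output layout this method sets up: the CPE reward and Q networks emit one block of num_actions outputs per metric, and reward_idx_offsets marks where each metric's block starts so a single gather can pull out the column for a logged action. All names and numbers below are illustrative only.

import torch

# Hypothetical sizes: 3 actions and 2 metrics give 6 output nodes, metric-major.
num_actions, num_metrics = 3, 2
num_output_nodes = num_metrics * num_actions
reward_idx_offsets = torch.arange(0, num_output_nodes, num_actions)  # tensor([0, 3])

# Fake network outputs for a batch of 2 states, plus the logged action per row.
reward_estimates = torch.arange(12.0).view(2, num_output_nodes)
logged_action_idxs = torch.tensor([[1], [2]])

# Same gather pattern used later by calculate_cpes: one column per metric per row.
per_metric = reward_estimates.gather(1, reward_idx_offsets + logged_action_idxs)
print(per_metric)  # row 0 -> 1.0 and 4.0, row 1 -> 8.0 and 11.0
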
Example #2
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[
                Dict[int, NormalizationParameters]
            ] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters
            )
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
            actions=self._actions,
        )

        self.reward_boosts = torch.zeros([1, len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(
                parameters.training.layers,
                parameters.training.activations,
                use_batch_norm=parameters.training.use_batch_norm,
            )
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )

        self.q_network_target = deepcopy(self.q_network)
        self.q_network._name = "training"
        self.q_network_target._name = "target"
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self._init_cpe_networks(parameters, use_all_avail_gpus)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
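
A short sketch (with a made-up action set and boost dictionary) of what the reward_boosts loop above produces and how the trainer later applies it: the boost configured for the logged action is simply added to the observed reward. Action names and values here are hypothetical.

import torch

actions = ["left", "right", "stay"]  # hypothetical parameters.actions
reward_boost = {"right": 0.5}        # hypothetical parameters.rl.reward_boost

reward_boosts = torch.zeros([1, len(actions)])
for k, v in reward_boost.items():
    reward_boosts[0, actions.index(k)] = v  # -> [[0.0, 0.5, 0.0]]

# One-hot logged actions for a batch of two transitions, applied as in boost_rewards.
logged_actions = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
rewards = torch.tensor([[1.0], [2.0]])
boosted = rewards + torch.sum(logged_actions * reward_boosts, dim=1, keepdim=True)
print(boosted)  # rows 1.5 and 2.0
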
Example #3
class DQNTrainer(DQNTrainerBase):
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[
                Dict[int, NormalizationParameters]
            ] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters
            )
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
            actions=self._actions,
        )

        self.reward_boosts = torch.zeros([1, len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(
                parameters.training.layers,
                parameters.training.activations,
                use_batch_norm=parameters.training.use_batch_norm,
            )
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )

        self.q_network_target = deepcopy(self.q_network)
        self.q_network._name = "training"
        self.q_network_target._name = "target"
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self._init_cpe_networks(parameters, use_all_avail_gpus)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)

    def _init_cpe_networks(self, parameters, use_all_avail_gpus):
        if not self.calc_cpe_in_training:
            return

        reward_network_layers = deepcopy(parameters.training.layers)
        if self.metrics_to_score:
            num_output_nodes = len(self.metrics_to_score) * self.num_actions
        else:
            num_output_nodes = self.num_actions

        reward_network_layers[-1] = num_output_nodes
        self.reward_idx_offsets = torch.arange(
            0, num_output_nodes, self.num_actions).type(self.dtypelong)
        logger.info("Reward network for CPE will have {} output nodes.".format(
            num_output_nodes))

        if parameters.training.cnn_parameters is None:
            self.reward_network = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations)
            self.q_network_cpe = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations)
        else:
            self.reward_network = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
            self.q_network_cpe = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
        self.q_network_cpe_target = deepcopy(self.q_network_cpe)
        self.q_network_cpe_optimizer = self.optimizer_func(
            self.q_network_cpe.parameters(),
            lr=parameters.training.learning_rate)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)
        if self.use_gpu:
            self.reward_network.cuda()
            self.q_network_cpe.cuda()
            self.q_network_cpe_target.cuda()
            if use_all_avail_gpus:
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)
                self.q_network_cpe = torch.nn.DataParallel(self.q_network_cpe)
                self.q_network_cpe_target = torch.nn.DataParallel(
                    self.q_network_cpe_target)

    @property
    def num_actions(self) -> int:
        return len(self._actions)

    def calculate_q_values(self, states):
        return self.q_network(states).detach()

    def calculate_metric_q_values(self, states):
        return self.q_network_cpe(states).detach()

    def get_detached_q_values(self, states) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            q_values = self.q_network(states)
            q_values_target = self.q_network_target(states)
        return q_values, q_values_target

    def get_next_action_q_values(self, states, next_actions):
        """
        Used in SARSA update.
        :param states: Tensor with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Tensor with shape (batch_size, action_dim) of
            one-hot next actions.
        :return: SARSA Q-values for the taken next actions, plus the max-Q
            action indices used for CPE.
        """
        q_values = self.q_network_target(states).detach()
        # Max-q action indexes used in CPE
        max_q_values, max_indicies = torch.max(q_values, dim=1, keepdim=True)
        return (
            torch.sum(q_values * next_actions, dim=1, keepdim=True),
            max_indicies,
        )

    def train(self, training_samples: TrainingDataPage):

        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert training_samples.actions.shape == torch.Size(
                [self.minibatch_size, len(self._actions)]
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size, 1]
            ), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminal.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminal.shape)
            if training_samples.possible_next_actions_mask is not None:
                assert (
                    training_samples.possible_next_actions_mask.shape
                    == training_samples.actions.shape
                ), "Invalid shape: " + str(
                    training_samples.possible_next_actions_mask.shape
                )
            if training_samples.propensities is not None:
                assert (
                    training_samples.propensities.shape
                    == training_samples.rewards.shape
                ), "Invalid shape: " + str(training_samples.propensities.shape)
            if training_samples.metrics is not None:
                assert (
                    training_samples.metrics.shape[0] == self.minibatch_size
                ), "Invalid shape: " + str(training_samples.metrics.shape)

        boosted_rewards = self.boost_rewards(training_samples.rewards,
                                             training_samples.actions)

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        rewards = boosted_rewards
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            time_diff = training_samples.time_diffs / self.time_diff_unit_length
            discount_tensor = discount_tensor.pow(time_diff)

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_samples.next_states)
        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.next_actions,
            )

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)
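        # i.e. during reward burnin the target is the immediate reward only; afterwards
        # it is the usual TD target r + discount * next_q, with next_q already zeroed
        # for terminal transitions by not_done_mask.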

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * actions, 1, keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)
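        # _soft_update (defined on the base trainer, not shown here) presumably blends
        # parameters Polyak-style, target <- tau * online + (1 - tau) * target, so the
        # tau=1.0 call above amounts to a hard copy.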

        logged_action_idxs = actions.argmax(dim=1, keepdim=True)
        reward_loss, model_rewards, model_propensities = self.calculate_cpes(
            training_samples,
            states,
            logged_action_idxs,
            max_q_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_samples.propensities,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=self.get_max_q_values(
                self.all_action_scores, training_samples.possible_actions_mask
            )[1],
        )

    def calculate_cpes(
        self,
        training_samples,
        states,
        logged_action_idxs,
        max_q_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_samples.metrics is None:
            metrics_reward_concat_real_vals = training_samples.rewards
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_samples.rewards, training_samples.metrics), dim=1)

        ######### Train separate reward network for CPE evaluation #############
        reward_estimates = self.reward_network(states)
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                 metrics_reward_concat_real_vals)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        metric_target_q_values = self.q_network_cpe_target(states).detach()
        max_q_values_metrics = metric_target_q_values.gather(
            1, self.reward_idx_offsets + max_q_action_idxs)
        filtered_max_q_values_metrics = max_q_values_metrics * not_done_mask
        if self.minibatch < self.reward_burnin:
            target_metric_q_values = metrics_reward_concat_real_vals
        else:
            target_metric_q_values = metrics_reward_concat_real_vals + (
                discount_tensor * filtered_max_q_values_metrics)
        metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                  target_metric_q_values)
        self.q_network_cpe.zero_grad()
        metric_q_value_loss.backward()
        self.q_network_cpe_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              self.tau)

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_samples.possible_actions_mask,
            self.rl_temperature,
        )
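        # masked_softmax (not shown here) presumably zeroes the probability of
        # impossible actions before normalizing; see the standalone sketch after this
        # class. The slice below keeps only the first num_actions columns of
        # reward_estimates, i.e. the per-action estimates of the reward metric itself
        # (reward_idx_offsets[0] is 0).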
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities

    def boost_rewards(
        self, rewards: torch.Tensor, actions: torch.Tensor
    ) -> torch.Tensor:
        # Apply reward boost if specified
        reward_boosts = torch.sum(
            actions.float() * self.reward_boosts, dim=1, keepdim=True
        )
        return rewards + reward_boosts

    def predictor(self, set_missing_value_to_zero=False) -> DQNPredictor:
        """Builds a DQNPredictor."""
        return DQNPredictor.export(
            self,
            self._actions,
            self.state_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
            set_missing_value_to_zero=set_missing_value_to_zero,
        )

    def export(self) -> DQNPredictor:
        return self.predictor()
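
The model_propensities computed in calculate_cpes come from masked_softmax, whose implementation is not included in these examples. Below is a minimal sketch of a standard masked softmax with the same apparent semantics (impossible actions get probability zero); the function name and the large negative constant are illustrative assumptions, not the library's actual code.

import torch
import torch.nn.functional as F

def masked_softmax_sketch(scores, mask, temperature):
    # Push masked-out (mask == 0) scores to a huge negative value so they vanish
    # after the softmax; divide by temperature as the rl_temperature argument suggests.
    masked_scores = scores / temperature + (1 - mask) * -1e20
    return F.softmax(masked_scores, dim=1)

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # third action impossible
print(masked_softmax_sketch(scores, mask, temperature=1.0))
# approximately [[0.2689, 0.7311, 0.0]]
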
Example #4
class DQNTrainer(RLTrainer):
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[
                Dict[int, NormalizationParameters]
            ] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters
            )
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        self.reward_boosts = torch.zeros([1, len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(parameters.training.layers,
                                             parameters.training.activations)
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                )

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        if parameters.training.cnn_parameters is None:
            self.reward_network = FullyConnectedNetwork(
                parameters.training.layers, parameters.training.activations)
        else:
            self.reward_network = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                parameters.training.layers,
                parameters.training.activations,
            )
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)

    @property
    def num_actions(self) -> int:
        return len(self._actions)

    def calculate_q_values(self, states):
        return self.q_network(states).detach()

    def get_max_q_values(self, states, possible_actions, double_q_learning):
        """
        Used in Q-learning update.
        :param states: Tensor with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param possible_actions: Tensor with shape (batch_size, action_dim).
            possible_actions[i][j] = 1 iff the agent can take action j from
            state i.
        :param double_q_learning: bool, whether to use double Q-learning.
        """
        if double_q_learning:
            q_values = self.q_network(states).detach()
            q_values_target = self.q_network_target(states).detach()
            # Set q-values of impossible actions to a very large negative number.
            inverse_pna = 1 - possible_actions
            impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
            q_values += impossible_action_penalty
            # Select max_q action after scoring with online network
            max_q_values, max_indicies = torch.max(q_values, dim=1, keepdim=True)
            # Use q_values from the target network for the max_q action chosen by the
            # online q_network, decoupling selection from scoring and preventing
            # overestimation of q-values
            q_values = torch.gather(q_values_target, 1, max_indicies)
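            # e.g. with online Q(s', .) = [[1.0, 5.0, 2.0]] the argmax picks action 1,
            # but the value backed up is the target network's estimate for action 1
            # (say 3.5 rather than 5.0), which is what curbs overestimation.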
            return q_values
        else:
            q_values = self.q_network_target(states).detach()
            # Set q-values of impossible actions to a very large negative number.
            inverse_pna = 1 - possible_actions
            impossible_action_penalty = self.ACTION_NOT_POSSIBLE_VAL * inverse_pna
            q_values += impossible_action_penalty
            return torch.max(q_values, dim=1, keepdim=True)[0]

    def get_next_action_q_values(self, states, next_actions):
        """
        Used in SARSA update.
        :param states: Tensor with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Tensor with shape (batch_size, action_dim) of
            one-hot next actions.
        """
        q_values = self.q_network_target(states).detach()
        return torch.sum(q_values * next_actions, dim=1, keepdim=True)

    def train(
        self,
        training_samples: TrainingDataPage,
        evaluator: Optional[Evaluator] = None,
    ):

        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert training_samples.actions.shape == torch.Size(
                [self.minibatch_size, len(self._actions)]
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size, 1]
            ), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminals.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminals.shape)
            if training_samples.possible_next_actions is not None:
                assert (
                    training_samples.possible_next_actions.shape
                    == training_samples.actions.shape
                ), "Invalid shape: " + str(
                    training_samples.possible_next_actions.shape
                )
            if training_samples.propensities is not None:
                assert (
                    training_samples.propensities.shape
                    == training_samples.rewards.shape
                ), "Invalid shape: " + str(training_samples.propensities.shape)

        # Apply reward boost if specified
        reward_boosts = torch.sum(
            training_samples.actions.float() * self.reward_boosts, dim=1, keepdim=True
        )
        boosted_rewards = training_samples.rewards + reward_boosts

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        rewards = boosted_rewards
        next_states = training_samples.next_states
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminals

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions = training_samples.possible_next_actions
            next_q_values = self.get_max_q_values(next_states,
                                                  possible_next_actions,
                                                  self.double_q_learning)
        else:
            # SARSA
            next_actions = training_samples.next_actions
            next_q_values = self.get_next_action_q_values(
                next_states, next_actions)

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * actions, 1, keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(states)
        self.reward_estimates = reward_estimates.detach()
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, actions.argmax(dim=1, keepdim=True))
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(td_loss=float(self.loss),
                                  reward_loss=float(reward_loss))

        training_metadata = {}
        if evaluator is not None:

            model_propensities = torch.from_numpy(
                Evaluator.softmax(
                    self.all_action_scores.cpu().numpy(), self.rl_temperature
                )
            )

            cpe_stats = BatchStatsForCPE(
                logged_actions=training_samples.actions,
                logged_propensities=training_samples.propensities,
                logged_rewards=rewards,
                logged_values=None,  # Compute at end of each epoch for CPE
                model_propensities=model_propensities,
                model_rewards=self.reward_estimates,
                model_values=self.all_action_scores,
                model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
                model_action_idxs=self.all_action_scores.argmax(dim=1, keepdim=True),
            )
            evaluator.report(cpe_stats)
            training_metadata["model_rewards"] = self.reward_estimates.cpu().numpy()

        return training_metadata

    def predictor(self) -> DQNPredictor:
        """Builds a DQNPredictor."""
        return DQNPredictor.export(
            self,
            self._actions,
            self.state_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )

    def export(self) -> DQNPredictor:
        return self.predictor()
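
A small runnable sketch of the impossible-action masking used by get_max_q_values above. ACTION_NOT_POSSIBLE_VAL lives on the trainer base class and its value is not shown in these examples, so the constant below is only a stand-in with the same intent.

import torch

ACTION_NOT_POSSIBLE_VAL = -1e9  # stand-in for the real class constant

q_values = torch.tensor([[1.0, 5.0, 2.0]])
possible_actions = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is not allowed here

# Same masking as get_max_q_values: push disallowed actions far below any real Q-value.
impossible_action_penalty = ACTION_NOT_POSSIBLE_VAL * (1 - possible_actions)
masked_q = q_values + impossible_action_penalty
max_q, max_idx = torch.max(masked_q, dim=1, keepdim=True)
print(max_q, max_idx)  # 2.0 at index 2 -> the best allowed action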