Example no. 1
def _get_sac_trainer_params(env, sac_model_params, use_gpu):
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)
    q1_network = FullyConnectedParametricDQN(
        state_dim,
        action_dim,
        sac_model_params.q_network.layers,
        sac_model_params.q_network.activations,
    )
    q2_network = None
    if sac_model_params.training.use_2_q_functions:
        q2_network = FullyConnectedParametricDQN(
            state_dim,
            action_dim,
            sac_model_params.q_network.layers,
            sac_model_params.q_network.activations,
        )
    value_network = FullyConnectedNetwork(
        [state_dim] + sac_model_params.value_network.layers + [1],
        sac_model_params.value_network.activations + ["linear"],
    )
    actor_network = GaussianFullyConnectedActor(
        state_dim,
        action_dim,
        sac_model_params.actor_network.layers,
        sac_model_params.actor_network.activations,
    )
    if use_gpu:
        q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        value_network.cuda()
        actor_network.cuda()
    value_network_target = deepcopy(value_network)
    min_action_range_tensor_training = torch.full((1, action_dim), -1 + 1e-6)
    max_action_range_tensor_training = torch.full((1, action_dim), 1 - 1e-6)
    action_range_low = env.action_space.low.astype(np.float32)
    action_range_high = env.action_space.high.astype(np.float32)
    min_action_range_tensor_serving = torch.from_numpy(action_range_low).unsqueeze(
        dim=0
    )
    max_action_range_tensor_serving = torch.from_numpy(action_range_high).unsqueeze(
        dim=0
    )

    trainer_args = [
        q1_network,
        value_network,
        value_network_target,
        actor_network,
        sac_model_params,
    ]
    trainer_kwargs = {
        "q2_network": q2_network,
        "min_action_range_tensor_training": min_action_range_tensor_training,
        "max_action_range_tensor_training": max_action_range_tensor_training,
        "min_action_range_tensor_serving": min_action_range_tensor_serving,
        "max_action_range_tensor_serving": max_action_range_tensor_serving,
    }
    return trainer_args, trainer_kwargs
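
The four range tensors above exist so that actions produced in the tanh-squashed training range (-1, 1) can be mapped to the environment's native bounds at serving time. Below is a minimal sketch of that linear rescaling, assuming a 2-dimensional action space with made-up bounds; the exact helper SACTrainer uses may differ.

import torch

def rescale(action, from_min, from_max, to_min, to_max):
    # Linear map from one closed interval to another, element-wise.
    ratio = (to_max - to_min) / (from_max - from_min)
    return to_min + (action - from_min) * ratio

min_train = torch.full((1, 2), -1 + 1e-6)
max_train = torch.full((1, 2), 1 - 1e-6)
min_serve = torch.tensor([[-2.0, 0.0]])  # stand-in for env.action_space.low
max_serve = torch.tensor([[2.0, 1.0]])   # stand-in for env.action_space.high

raw_action = torch.tanh(torch.randn(1, 2))  # actor output in (-1, 1)
served_action = rescale(raw_action, min_train, max_train, min_serve, max_serve)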
Example no. 2
class ConvolutionalNetwork(nn.Module):
    def __init__(
        self, cnn_parameters, layers, activations, use_noisy_linear_layers=False
    ) -> None:
        super(ConvolutionalNetwork, self).__init__()
        self.conv_dims = cnn_parameters.conv_dims
        self.conv_height_kernels = cnn_parameters.conv_height_kernels
        self.conv_width_kernels = cnn_parameters.conv_width_kernels
        self.conv_layers: nn.ModuleList = nn.ModuleList()
        self.pool_layers: nn.ModuleList = nn.ModuleList()

        for i, _ in enumerate(self.conv_dims[1:]):
            self.conv_layers.append(
                nn.Conv2d(
                    self.conv_dims[i],
                    self.conv_dims[i + 1],
                    kernel_size=(
                        self.conv_height_kernels[i],
                        self.conv_width_kernels[i],
                    ),
                )
            )
            nn.init.kaiming_normal_(self.conv_layers[i].weight)
            if cnn_parameters.pool_types[i] == "max":
                self.pool_layers.append(
                    nn.MaxPool2d(kernel_size=cnn_parameters.pool_kernels_strides[i])
                )
            else:
                assert False, "Unknown pooling type: {}".format(
                    cnn_parameters.pool_types[i])

        input_size = (
            cnn_parameters.num_input_channels,
            cnn_parameters.input_height,
            cnn_parameters.input_width,
        )
        conv_out = self.conv_forward(torch.ones(1, *input_size))
        self.fc_input_dim = int(np.prod(conv_out.size()[1:]))
        layers[0] = self.fc_input_dim
        self.feed_forward = FullyConnectedNetwork(
            layers, activations, use_noisy_linear_layers=use_noisy_linear_layers
        )

    def conv_forward(self, input):
        x = input
        for i, _ in enumerate(self.conv_layers):
            x = F.relu(self.conv_layers[i](x))
            x = self.pool_layers[i](x)
        return x

    def forward(self, input) -> torch.FloatTensor:
        """ Forward pass for generic convnet DNNs. Assumes activation names
        are valid pytorch activation names.
        :param input image tensor
        """
        x = self.conv_forward(input)
        x = x.view(-1, self.fc_input_dim)
        return self.feed_forward.forward(x)
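
The constructor above sizes its fully connected head by pushing a dummy ones tensor through the convolutional stack and flattening the result. A self-contained sketch of that trick, with hypothetical layer shapes:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=(3, 3))   # hypothetical single conv layer
pool = nn.MaxPool2d(kernel_size=2)

with torch.no_grad():
    probe = torch.ones(1, 3, 32, 32)               # (batch, channels, height, width)
    conv_out = pool(F.relu(conv(probe)))
fc_input_dim = int(np.prod(conv_out.size()[1:]))   # channels * H' * W' after conv + pool
head = nn.Linear(fc_input_dim, 4)                  # head sized from the probe, not by hand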
Example no. 3
    def _init_bcq_network(self, parameters, use_all_avail_gpus):
        # Batch constrained q-learning
        if not parameters.rainbow.bcq:
            return

        self.bcq_imitator = FullyConnectedNetwork(
            parameters.training.layers,
            parameters.training.activations,
            min_std=parameters.training.weight_init_min_std,
            use_batch_norm=parameters.training.use_batch_norm,
        )
        self.bcq_imitator_optimizer = self.optimizer_func(
            self.bcq_imitator.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        if self.use_gpu:
            self.bcq_imitator.cuda()
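
The imitator built here is consumed at training time to mask out actions the logged policy would rarely take (see the DQNTrainer.train example further below). A rough sketch of that filtering step, with illustrative shapes and a plain softmax standing in for the library's temperature-scaled one:

import torch
import torch.nn.functional as F

bcq_drop_threshold = 0.05
imitator_logits = torch.randn(4, 3)                        # (batch, num_actions)
probs = F.softmax(imitator_logits, dim=1)
filter_values = probs / probs.max(dim=1, keepdim=True)[0]  # relative to per-row max
action_on_policy = (filter_values >= bcq_drop_threshold).float()
possible_next_actions_mask = torch.ones(4, 3) * action_on_policy  # drop unlikely actions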
Example no. 4
    def build_value_network(
            self,
            state_normalization_data: NormalizationData) -> torch.nn.Module:
        state_dim = get_num_output_features(
            state_normalization_data.dense_normalization_parameters)
        return FullyConnectedNetwork(
            [state_dim] + self.sizes + [1],
            self.activations + ["linear"],
            use_layer_norm=self.use_layer_norm,
        )
class ConvolutionalNetwork(nn.Module):
    def __init__(self, cnn_parameters, layers, activations) -> None:
        super(ConvolutionalNetwork, self).__init__()
        self.conv_dims = cnn_parameters.conv_dims
        self.conv_height_kernels = cnn_parameters.conv_height_kernels
        self.conv_width_kernels = cnn_parameters.conv_width_kernels
        self.conv_layers: nn.ModuleList = nn.ModuleList()
        self.pool_layers: nn.ModuleList = nn.ModuleList()

        for i, _ in enumerate(self.conv_dims[1:]):
            self.conv_layers.append(
                nn.Conv2d(
                    self.conv_dims[i],
                    self.conv_dims[i + 1],
                    kernel_size=(
                        self.conv_height_kernels[i],
                        self.conv_width_kernels[i],
                    ),
                ))
            nn.init.kaiming_normal_(self.conv_layers[i].weight)
            if cnn_parameters.pool_types[i] == "max":
                self.pool_layers.append(
                    nn.MaxPool2d(
                        kernel_size=cnn_parameters.pool_kernels_strides[i]))
            else:
                assert False, "Unknown pooling type: {}".format(
                    cnn_parameters.pool_types[i])

        input_size = (
            cnn_parameters.num_input_channels,
            cnn_parameters.input_height,
            cnn_parameters.input_width,
        )
        conv_out = self.conv_forward(torch.ones(1, *input_size))
        self.fc_input_dim = int(np.prod(conv_out.size()[1:]))
        layers[0] = self.fc_input_dim
        self.feed_forward = FullyConnectedNetwork(layers, activations)

    def conv_forward(self, input):
        x = input
        for i, _ in enumerate(self.conv_layers):
            x = F.relu(self.conv_layers[i](x))
            x = self.pool_layers[i](x)
        return x

    def forward(self, input) -> torch.FloatTensor:
        """ Forward pass for generic convnet DNNs. Assumes activation names
        are valid pytorch activation names.
        :param input image tensor
        """
        x = self.conv_forward(input)
        x = x.view(-1, self.fc_input_dim)
        return self.feed_forward.forward(x)
    def __init__(
        self,
        state_dim,
        action_dim,
        sizes,
        activations,
        model_feature_config: rlt.ModelFeatureConfig,
        embedding_dim: int,
        use_batch_norm=False,
        dropout_ratio=0.0,
    ):
        super().__init__()
        assert state_dim > 0, "state_dim must be > 0, got {}".format(state_dim)
        assert action_dim > 0, "action_dim must be > 0, got {}".format(
            action_dim)
        self.state_dim = state_dim
        self.action_dim = action_dim
        assert len(sizes) == len(
            activations
        ), "The numbers of sizes and activations must match; got {} vs {}".format(
            len(sizes), len(activations))

        self.embedding_bags = torch.nn.ModuleDict(  # type: ignore
            {
                id_list_feature.name: torch.nn.EmbeddingBag(  # type: ignore
                    len(
                        model_feature_config.id_mapping_config[
                            id_list_feature.id_mapping_name
                        ].ids
                    ),
                    embedding_dim,
                )
                for id_list_feature in model_feature_config.id_list_feature_configs
            }
        )

        fc_input_dim = (
            state_dim +
            len(model_feature_config.id_list_feature_configs) * embedding_dim)

        self.fc = FullyConnectedNetwork(
            [fc_input_dim] + sizes + [action_dim],
            activations + ["linear"],
            use_batch_norm=use_batch_norm,
            dropout_ratio=dropout_ratio,
        )
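
A minimal illustration of how the nn.EmbeddingBag modules declared above pool a variable-length id list into a fixed-size vector that is then concatenated with the dense state; feature names and sizes here are made up.

import torch
import torch.nn as nn

embedding_dim = 4
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=embedding_dim)

# Two examples: ids [1, 3, 5] and [2]; offsets mark where each example starts.
ids = torch.tensor([1, 3, 5, 2])
offsets = torch.tensor([0, 3])
pooled = bag(ids, offsets)                    # shape (2, embedding_dim), mean-pooled by default

state = torch.randn(2, 8)                     # dense state features
fc_input = torch.cat((state, pooled), dim=1)  # fed to the FullyConnectedNetwork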
Example no. 8
    def test_save_load(self):
        state_dim = 8
        action_dim = 4
        q_network = FullyConnectedDQN(state_dim,
                                      action_dim,
                                      sizes=[8, 4],
                                      activations=["relu", "relu"])
        imitator_network = FullyConnectedNetwork(
            layers=[state_dim, 8, 4, action_dim],
            activations=["relu", "relu", "linear"])
        model = BatchConstrainedDQN(
            state_dim=state_dim,
            q_network=q_network,
            imitator_network=imitator_network,
            bcq_drop_threshold=0.05,
        )
        # 6 for DQN + 6 for Imitator Network + 2 for BCQ constants
        expected_num_params, expected_num_inputs, expected_num_outputs = 14, 1, 1
        check_save_load(self, model, expected_num_params, expected_num_inputs,
                        expected_num_outputs)
Example no. 9
    def test_basic(self):
        state_dim = 8
        action_dim = 4
        q_network = FullyConnectedDQN(state_dim,
                                      action_dim,
                                      sizes=[8, 4],
                                      activations=["relu", "relu"])
        imitator_network = FullyConnectedNetwork(
            layers=[state_dim, 8, 4, action_dim],
            activations=["relu", "relu", "linear"])
        model = BatchConstrainedDQN(
            state_dim=state_dim,
            q_network=q_network,
            imitator_network=imitator_network,
            bcq_drop_threshold=0.05,
        )

        input = model.input_prototype()
        self.assertEqual((1, state_dim), input.state.float_features.shape)
        q_values = model(input)
        self.assertEqual((1, action_dim), q_values.q_values.shape)
Example no. 10
    def __init__(self,
                 state_dim,
                 action_dim,
                 sizes,
                 activations,
                 use_batch_norm=False):
        super().__init__()
        assert state_dim > 0, "state_dim must be > 0, got {}".format(state_dim)
        assert action_dim > 0, "action_dim must be > 0, got {}".format(
            action_dim)
        self.state_dim = state_dim
        self.action_dim = action_dim
        assert len(sizes) == len(
            activations
        ), "The numbers of sizes and activations must match; got {} vs {}".format(
            len(sizes), len(activations))
        self.fc = FullyConnectedNetwork(
            [state_dim + action_dim] + sizes + [1],
            activations + ["linear"],
            use_batch_norm=use_batch_norm,
        )
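
A short usage sketch for the parametric critic defined above (shapes illustrative): the state and action are concatenated along the feature dimension and the network returns one Q-value per pair.

import torch

state = torch.randn(5, 8)                     # (batch, state_dim)
action = torch.randn(5, 2)                    # (batch, action_dim)
q_input = torch.cat((state, action), dim=1)   # (batch, state_dim + action_dim)
# q_values = critic.fc(q_input)               # hypothetical instance; output shape (batch, 1)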
Example no. 11
    def __init__(self, state_dim, action_dim, sizes, activations, use_batch_norm=False):
        """
        AKA the multivariate beta distribution. Used in cases where actor's action
        must sum to 1.
        """
        super().__init__()
        assert state_dim > 0, "state_dim must be > 0, got {}".format(state_dim)
        assert action_dim > 0, "action_dim must be > 0, got {}".format(action_dim)
        self.state_dim = state_dim
        self.action_dim = action_dim
        assert len(sizes) == len(
            activations
        ), "The numbers of sizes and activations must match; got {} vs {}".format(
            len(sizes), len(activations)
        )

        # The last layer gives the concentration of the distribution.
        self.fc = FullyConnectedNetwork(
            [state_dim] + sizes + [action_dim],
            activations + ["relu"],
            use_batch_norm=use_batch_norm,
        )
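
A hedged sketch of how a concentration vector like the one produced above can drive a Dirichlet (multivariate beta) distribution so that sampled actions are non-negative and sum to 1. The softplus shift is an assumption to keep concentrations strictly positive, which the relu head alone does not guarantee.

import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet

raw = torch.randn(2, 3)                 # network output, (batch, action_dim)
concentration = F.softplus(raw) + 1e-6  # strictly positive concentration (assumption)
dist = Dirichlet(concentration)
action = dist.sample()                  # each row sums to 1
log_prob = dist.log_prob(action)        # (batch,)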
Example no. 12
    def __init__(
        self,
        state_dim,
        action_dim,
        sizes,
        activations,
        scale=0.05,
        use_batch_norm=False,
        use_layer_norm=False,
    ):
        super().__init__()
        assert state_dim > 0, "state_dim must be > 0, got {}".format(state_dim)
        assert action_dim > 0, "action_dim must be > 0, got {}".format(
            action_dim)
        self.state_dim = state_dim
        self.action_dim = action_dim
        assert len(sizes) == len(
            activations
        ), "The numbers of sizes and activations must match; got {} vs {}".format(
            len(sizes), len(activations))
        # The last layer is mean & scale for reparameterization trick
        self.fc = FullyConnectedNetwork(
            [state_dim] + sizes + [action_dim * 2],
            activations + ["linear"],
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
        )
        self.use_layer_norm = use_layer_norm
        if self.use_layer_norm:
            self.loc_layer_norm = nn.LayerNorm(action_dim)
            self.scale_layer_norm = nn.LayerNorm(action_dim)

        # used to calculate log-prob
        self.const = math.log(math.sqrt(2 * math.pi))
        self.eps = 1e-6
        self._log_min_max = (-20.0, 2.0)
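
The mean/scale head, the clamped log range, eps, and the log(sqrt(2*pi)) constant above support the standard SAC-style reparameterized, tanh-squashed Gaussian policy. A minimal sketch of that sampling path; names and shapes are illustrative and the actual actor's forward may differ in detail.

import torch
from torch.distributions import Normal

log_min, log_max, eps = -20.0, 2.0, 1e-6
fc_out = torch.randn(2, 6)                      # (batch, 2 * action_dim)
loc, log_scale = fc_out.chunk(2, dim=1)
scale = log_scale.clamp(log_min, log_max).exp()

dist = Normal(loc, scale)
raw_action = dist.rsample()                     # differentiable sample (reparameterization trick)
action = torch.tanh(raw_action)                 # squash into (-1, 1)
# Change-of-variables correction for the tanh squashing, summed over action dims.
log_prob = (dist.log_prob(raw_action) - torch.log(1 - action.pow(2) + eps)).sum(
    dim=1, keepdim=True
)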
Example no. 13
class ParametricDQNTrainer(RLTrainer):
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)

    def _get_model(self, training_parameters, dueling_architecture=False):
        if dueling_architecture:
            return DuelingQNetwork(
                training_parameters.layers,
                training_parameters.activations,
                action_dim=self.num_action_features,
            )
        elif training_parameters.factorization_parameters is None:
            return FullyConnectedNetwork(
                training_parameters.layers,
                training_parameters.activations,
                use_noisy_linear_layers=training_parameters.
                use_noisy_linear_layers,
            )
        else:
            return ParametricInnerProduct(
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.
                    activations,
                ),
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.
                    activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )

    def calculate_q_values(self, state_pas_concats, pas_lens):
        row_nums = np.arange(pas_lens.shape[0], dtype=np.int64)
        row_idxs = np.repeat(row_nums, pas_lens.cpu().numpy())
        col_idxs = arange_expand(pas_lens)

        dense_idxs = torch.LongTensor(
            (row_idxs, col_idxs)).type(self.dtypelong)

        q_values = self.q_network(state_pas_concats).squeeze().detach()

        dense_dim = [len(pas_lens), int(torch.max(pas_lens))]
        # Add specific fingerprint to q-values so that after sparse -> dense we can
        # subtract the fingerprint to identify the 0's added in sparse -> dense
        q_values.add_(self.FINGERPRINT)
        sparse_q = torch.sparse_coo_tensor(dense_idxs,
                                           q_values,
                                           size=dense_dim)
        dense_q = sparse_q.to_dense()
        dense_q.add_(self.FINGERPRINT * -1)
        dense_q[dense_q == self.FINGERPRINT *
                -1] = self.ACTION_NOT_POSSIBLE_VAL

        return dense_q

    def get_max_q_values(self, possible_next_actions_state_concat, pnas_lens,
                         double_q_learning):
        """
        :param possible_next_actions_state_concat: Numpy array with shape
            (sum(pnas_lens), state_dim + action_dim). Each row
            contains a representation of a state + possible next action pair.
        :param pnas_lens: Numpy array that describes number of
            possible_actions per item in minibatch
        :param double_q_learning: bool to use double q-learning
        """
        row_nums = np.arange(len(pnas_lens))
        row_idxs = np.repeat(row_nums, pnas_lens.cpu().numpy())
        col_idxs = arange_expand(pnas_lens).cpu().numpy()

        dense_idxs = torch.LongTensor(
            (row_idxs, col_idxs)).type(self.dtypelong)
        if isinstance(possible_next_actions_state_concat, torch.Tensor):
            q_network_input = possible_next_actions_state_concat
        else:
            q_network_input = torch.from_numpy(
                possible_next_actions_state_concat).type(self.dtype)

        if double_q_learning:
            q_values = self.q_network(q_network_input).squeeze().detach()
            q_values_target = self.q_network_target(
                q_network_input).squeeze().detach()
        else:
            q_values = self.q_network_target(
                q_network_input).squeeze().detach()

        dense_dim = [len(pnas_lens), max(pnas_lens)]
        # Add specific fingerprint to q-values so that after sparse -> dense we can
        # subtract the fingerprint to identify the 0's added in sparse -> dense
        q_values.add_(self.FINGERPRINT)
        sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values, dense_dim)
        dense_q = sparse_q.to_dense()
        dense_q.add_(self.FINGERPRINT * -1)
        dense_q[dense_q == self.FINGERPRINT *
                -1] = self.ACTION_NOT_POSSIBLE_VAL
        max_q_values, max_indexes = torch.max(dense_q, dim=1)

        if double_q_learning:
            sparse_q_target = torch.sparse_coo_tensor(dense_idxs,
                                                      q_values_target,
                                                      dense_dim)
            dense_q_values_target = sparse_q_target.to_dense()
            max_q_values = torch.gather(dense_q_values_target, 1,
                                        max_indexes.unsqueeze(1))

        return max_q_values.squeeze()

    def get_next_action_q_values(self, state_action_pairs):
        return self.q_network_target(state_action_pairs)

    def train(self,
              training_samples: TrainingDataPage,
              evaluator=None) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (training_samples.states.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.states.shape)
            assert (training_samples.actions.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size,
                 1]), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (training_samples.next_states.shape ==
                    training_samples.states.shape), "Invalid shape: " + str(
                        training_samples.next_states.shape)
            assert (training_samples.not_terminals.shape ==
                    training_samples.rewards.shape), "Invalid shape: " + str(
                        training_samples.not_terminals.shape)
            assert training_samples.possible_next_actions_state_concat.shape[
                1] == (
                    training_samples.states.shape[1] +
                    training_samples.actions.shape[1]
                ), ("Invalid shape: " + str(
                    training_samples.possible_next_actions_state_concat.shape))
            assert training_samples.possible_next_actions_lengths.shape == torch.Size(
                [
                    self.minibatch_size
                ]), ("Invalid shape: " +
                     str(training_samples.possible_next_actions_lengths.shape))

        self.minibatch += 1

        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        state_action_pairs = torch.cat((states, actions), dim=1)

        rewards = training_samples.rewards
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminals

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values = self.get_max_q_values(
                training_samples.possible_next_actions_state_concat,
                training_samples.possible_next_actions_lengths,
                self.double_q_learning,
            )
        else:
            # SARSA
            next_state_action_pairs = torch.cat(
                (training_samples.next_states, training_samples.next_actions),
                dim=1)
            next_q_values = self.get_next_action_q_values(
                next_state_action_pairs)

        filtered_max_q_vals = next_q_values.reshape(-1, 1) * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        all_action_scores = q_values.detach()
        self.model_values_on_logged_actions = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs)
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(td_loss=float(self.loss),
                                  reward_loss=float(reward_loss))

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(
                model_values_on_logged_actions=all_action_scores)
            evaluator.report(cpe_stats)

    def predictor(self) -> ParametricDQNPredictor:
        """Builds a ParametricDQNPredictor."""
        return ParametricDQNPredictor.export(
            self,
            self.state_normalization_parameters,
            self.action_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )

    def export(self) -> ParametricDQNPredictor:
        return self.predictor()
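
The fingerprint bookkeeping in calculate_q_values and get_max_q_values above scatters a ragged list of per-state Q-values into a dense matrix, then distinguishes real zeros from the padding that to_dense() introduces. A standalone illustration with made-up numbers:

import torch

FINGERPRINT = 12345.0
ACTION_NOT_POSSIBLE_VAL = -1e9

q_values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])   # state 0 has 3 actions, state 1 has 2
row_idxs = torch.tensor([0, 0, 0, 1, 1])
col_idxs = torch.tensor([0, 1, 2, 0, 1])

q_values = q_values + FINGERPRINT                    # tag every real value
sparse_q = torch.sparse_coo_tensor(
    torch.stack((row_idxs, col_idxs)), q_values, size=(2, 3)
)
dense_q = sparse_q.to_dense() - FINGERPRINT          # padded entries now equal -FINGERPRINT
dense_q[dense_q == -FINGERPRINT] = ACTION_NOT_POSSIBLE_VAL
max_q_values, max_indexes = torch.max(dense_q, dim=1)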
Example no. 14
def get_sac_trainer(
    env: OpenAIGymEnvironment,
    rl_parameters: RLParameters,
    trainer_parameters: SACTrainerParameters,
    critic_training: FeedForwardParameters,
    actor_training: FeedForwardParameters,
    sac_value_training: Optional[FeedForwardParameters],
    use_gpu: bool,
) -> SACTrainer:
    assert rl_parameters == trainer_parameters.rl
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)
    q1_network = FullyConnectedParametricDQN(state_dim, action_dim,
                                             critic_training.layers,
                                             critic_training.activations)
    q2_network = None
    # TODO:
    # if trainer_parameters.use_2_q_functions:
    #     q2_network = FullyConnectedParametricDQN(
    #         state_dim,
    #         action_dim,
    #         critic_training.layers,
    #         critic_training.activations,
    #     )
    value_network = None
    if sac_value_training:
        value_network = FullyConnectedNetwork(
            [state_dim] + sac_value_training.layers + [1],
            sac_value_training.activations + ["linear"],
        )
    actor_network = GaussianFullyConnectedActor(state_dim, action_dim,
                                                actor_training.layers,
                                                actor_training.activations)

    min_action_range_tensor_training = torch.full((1, action_dim), -1 + 1e-6)
    max_action_range_tensor_training = torch.full((1, action_dim), 1 - 1e-6)
    min_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.low).float().unsqueeze(
            dim=0)  # type: ignore
    )
    max_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.high).float().unsqueeze(
            dim=0)  # type: ignore
    )

    if use_gpu:
        q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        actor_network.cuda()

        min_action_range_tensor_training = min_action_range_tensor_training.cuda(
        )
        max_action_range_tensor_training = max_action_range_tensor_training.cuda(
        )
        min_action_range_tensor_serving = min_action_range_tensor_serving.cuda(
        )
        max_action_range_tensor_serving = max_action_range_tensor_serving.cuda(
        )

    return SACTrainer(
        q1_network,
        actor_network,
        trainer_parameters,
        use_gpu=use_gpu,
        value_network=value_network,
        q2_network=q2_network,
        min_action_range_tensor_training=min_action_range_tensor_training,
        max_action_range_tensor_training=max_action_range_tensor_training,
        min_action_range_tensor_serving=min_action_range_tensor_serving,
        max_action_range_tensor_serving=max_action_range_tensor_serving,
    )
Example no. 15
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types:
        AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(self, parameters, use_gpu, additional_feature_types,
                           gradient_handler)

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)
Example no. 16
class DQNTrainer(DQNTrainerBase):
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.bcq = parameters.rainbow.bcq
        self.bcq_drop_threshold = parameters.rainbow.bcq_drop_threshold
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[Dict[
                int, NormalizationParameters]] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            metrics_to_score,
            gradient_handler,
            actions=self._actions,
        )

        self.reward_boosts = torch.zeros([1,
                                          len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(
                parameters.training.layers,
                parameters.training.activations,
                use_batch_norm=parameters.training.use_batch_norm,
            )
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.
                    use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.
                    use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )

        self.q_network_target = deepcopy(self.q_network)
        self.q_network._name = "training"
        self.q_network_target._name = "target"
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )
        self.clip_grad_norm = parameters.training.clip_grad_norm

        self._init_cpe_networks(parameters, use_all_avail_gpus)
        self._init_bcq_network(parameters, use_all_avail_gpus)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)

    def _init_bcq_network(self, parameters, use_all_avail_gpus):
        # Batch constrained q-learning
        if not parameters.rainbow.bcq:
            return

        self.bcq_imitator = FullyConnectedNetwork(
            parameters.training.layers,
            parameters.training.activations,
            min_std=parameters.training.weight_init_min_std,
            use_batch_norm=parameters.training.use_batch_norm,
        )
        self.bcq_imitator_optimizer = self.optimizer_func(
            self.bcq_imitator.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        if self.use_gpu:
            self.bcq_imitator.cuda()

    def _init_cpe_networks(self, parameters, use_all_avail_gpus):
        if not self.calc_cpe_in_training:
            return

        reward_network_layers = deepcopy(parameters.training.layers)
        if self.metrics_to_score:
            num_output_nodes = len(self.metrics_to_score) * self.num_actions
        else:
            num_output_nodes = self.num_actions

        reward_network_layers[-1] = num_output_nodes
        self.reward_idx_offsets = torch.arange(
            0, num_output_nodes, self.num_actions).type(self.dtypelong)
        logger.info("Reward network for CPE will have {} output nodes.".format(
            num_output_nodes))

        if parameters.training.cnn_parameters is None:
            self.reward_network = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations)
            self.q_network_cpe = FullyConnectedNetwork(
                reward_network_layers, parameters.training.activations)
        else:
            self.reward_network = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
            self.q_network_cpe = ConvolutionalNetwork(
                parameters.training.cnn_parameters,
                reward_network_layers,
                parameters.training.activations,
            )
        self.q_network_cpe_target = deepcopy(self.q_network_cpe)
        self.q_network_cpe_optimizer = self.optimizer_func(
            self.q_network_cpe.parameters(),
            lr=parameters.training.learning_rate)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)
        if self.use_gpu:
            self.reward_network.cuda()
            self.q_network_cpe.cuda()
            self.q_network_cpe_target.cuda()
            if use_all_avail_gpus:
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)
                self.q_network_cpe = torch.nn.DataParallel(self.q_network_cpe)
                self.q_network_cpe_target = torch.nn.DataParallel(
                    self.q_network_cpe_target)

    @property
    def num_actions(self) -> int:
        return len(self._actions)

    def calculate_q_values(self, states):
        return self.q_network(states).detach()

    def calculate_metric_q_values(self, states):
        return self.q_network_cpe(states).detach()

    def get_detached_q_values(self,
                              states) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            q_values = self.q_network(states)
            q_values_target = self.q_network_target(states)
        return q_values, q_values_target

    def get_next_action_q_values(self, states, next_actions):
        """
        Used in SARSA update.
        :param states: Numpy array with shape (batch_size, state_dim). Each row
            contains a representation of a state.
        :param next_actions: Numpy array with shape (batch_size, action_dim).
        """
        q_values = self.q_network_target(states).detach()
        # Max-q action indexes used in CPE
        max_q_values, max_indicies = torch.max(q_values, dim=1, keepdim=True)
        return (torch.sum(q_values * next_actions, dim=1,
                          keepdim=True), max_indicies)

    def train(self, training_samples: TrainingDataPage):

        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            assert (training_samples.states.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.states.shape)
            assert training_samples.actions.shape == torch.Size([
                self.minibatch_size, len(self._actions)
            ]), "Invalid shape: " + str(training_samples.actions.shape)
            assert training_samples.rewards.shape == torch.Size(
                [self.minibatch_size,
                 1]), "Invalid shape: " + str(training_samples.rewards.shape)
            assert (training_samples.next_states.shape ==
                    training_samples.states.shape), "Invalid shape: " + str(
                        training_samples.next_states.shape)
            assert (training_samples.not_terminal.shape ==
                    training_samples.rewards.shape), "Invalid shape: " + str(
                        training_samples.not_terminal.shape)
            if training_samples.possible_next_actions_mask is not None:
                assert (
                    training_samples.possible_next_actions_mask.shape ==
                    training_samples.actions.shape), (
                        "Invalid shape: " +
                        str(training_samples.possible_next_actions_mask.shape))
            if training_samples.propensities is not None:
                assert (training_samples.propensities.shape == training_samples
                        .rewards.shape), "Invalid shape: " + str(
                            training_samples.propensities.shape)
            if training_samples.metrics is not None:
                assert (
                    training_samples.metrics.shape[0] == self.minibatch_size
                ), "Invalid shape: " + str(training_samples.metrics.shape)

        boosted_rewards = self.boost_rewards(training_samples.rewards,
                                             training_samples.actions)

        self.minibatch += 1
        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        rewards = boosted_rewards
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            time_diff = training_samples.time_diffs / self.time_diff_unit_length
            discount_tensor = discount_tensor.pow(time_diff)

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            training_samples.next_states)

        if self.bcq:
            # Batch constrained q-learning
            on_policy_actions = self.bcq_imitator(training_samples.next_states)
            on_policy_action_probs = softmax(on_policy_actions, temperature=1)
            filter_values = (
                on_policy_action_probs /
                on_policy_action_probs.max(keepdim=True, dim=1)[0])
            action_on_policy = (filter_values >=
                                self.bcq_drop_threshold).float()
            training_samples.possible_next_actions_mask *= action_on_policy
        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.next_actions,
            )

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * actions, 1, keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.q_network.parameters(),
                                           self.clip_grad_norm)
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        bcq_loss = None
        if self.bcq:
            # Batch constrained q-learning
            action_preds = self.bcq_imitator(states)
            imitator_loss = torch.nn.CrossEntropyLoss()
            # Classification label is index of action with value 1
            bcq_loss = imitator_loss(action_preds,
                                     torch.max(actions, dim=1)[1])
            self.bcq_imitator_optimizer.zero_grad()
            bcq_loss.backward()
            self.bcq_imitator_optimizer.step()

        logged_action_idxs = actions.argmax(dim=1, keepdim=True)
        reward_loss, model_rewards, model_propensities = self.calculate_cpes(
            training_samples,
            states,
            logged_action_idxs,
            max_q_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        self.loss_reporter.report(
            td_loss=self.loss,
            imitator_loss=bcq_loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_samples.propensities,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_values_on_logged_actions=
            None,  # Compute at end of each epoch for CPE
            model_action_idxs=self.get_max_q_values(
                self.all_action_scores,
                training_samples.possible_actions_mask)[1],
        )

    def calculate_cpes(
        self,
        training_samples,
        states,
        logged_action_idxs,
        max_q_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_samples.metrics is None:
            metrics_reward_concat_real_vals = training_samples.rewards
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_samples.rewards, training_samples.metrics), dim=1)

        ######### Train separate reward network for CPE evaluation #############
        reward_estimates = self.reward_network(states)
        reward_estimates_for_logged_actions = reward_estimates.gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        reward_loss = F.mse_loss(reward_estimates_for_logged_actions,
                                 metrics_reward_concat_real_vals)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        ######### Train separate q-network for CPE evaluation #############
        metric_q_values = self.q_network_cpe(states).gather(
            1, self.reward_idx_offsets + logged_action_idxs)
        metric_target_q_values = self.q_network_cpe_target(states).detach()
        max_q_values_metrics = metric_target_q_values.gather(
            1, self.reward_idx_offsets + max_q_action_idxs)
        filtered_max_q_values_metrics = max_q_values_metrics * not_done_mask
        if self.minibatch < self.reward_burnin:
            target_metric_q_values = metrics_reward_concat_real_vals
        else:
            target_metric_q_values = metrics_reward_concat_real_vals + (
                discount_tensor * filtered_max_q_values_metrics)
        metric_q_value_loss = self.q_network_loss(metric_q_values,
                                                  target_metric_q_values)
        self.q_network_cpe.zero_grad()
        metric_q_value_loss.backward()
        self.q_network_cpe_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network_cpe, self.q_network_cpe_target,
                              self.tau)

        model_propensities = masked_softmax(
            self.all_action_scores,
            training_samples.possible_actions_mask,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[:,
                                         torch.arange(
                                             self.reward_idx_offsets[0],
                                             self.reward_idx_offsets[0] +
                                             self.num_actions,
                                         ), ]
        return reward_loss, model_rewards, model_propensities

    def boost_rewards(self, rewards: torch.Tensor,
                      actions: torch.Tensor) -> torch.Tensor:
        # Apply reward boost if specified
        reward_boosts = torch.sum(actions.float() * self.reward_boosts,
                                  dim=1,
                                  keepdim=True)
        return rewards + reward_boosts

    def predictor(self, set_missing_value_to_zero=False) -> DQNPredictor:
        """Builds a DQNPredictor."""
        return DQNPredictor.export(
            self,
            self._actions,
            self.state_normalization_parameters,
            self.use_gpu,
            set_missing_value_to_zero=set_missing_value_to_zero,
        )

    def export(self) -> DQNPredictor:
        return self.predictor()
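
A sketch of the soft target-network update that _soft_update performs in the trainers above: during reward burnin the target is forced to match the online network (tau = 1.0), after which it tracks the online network exponentially with a small tau.

import torch
import torch.nn as nn

def soft_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target, parameter by parameter
    with torch.no_grad():
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.mul_(1.0 - tau).add_(tau * param.data)

online_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
soft_update(online_net, target_net, tau=1.0)    # burnin: hard copy
soft_update(online_net, target_net, tau=0.005)  # afterwards: slow tracking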
Example no. 17
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters
        )
        self.num_action_features = get_num_output_features(
            action_normalization_parameters
        )
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(state_normalization_parameters.keys()) & set(
            action_normalization_parameters.keys()
        )
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: "
            + str(overlapping_features)
        )

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0
            ] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0
            ] = self.num_action_features

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
        )

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture
        )

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations
        )
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(), lr=parameters.training.learning_rate
        )

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(self.q_network_target)
                self.reward_network = torch.nn.DataParallel(self.reward_network)
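
Both this constructor and the next example derive the reward network's shape by copying the configured Q-network layers and overwriting the first and last entries. A small worked sketch with hypothetical sizes:

from copy import deepcopy

# Hypothetical config: placeholders (-1) for the input and output widths.
training_layers = [-1, 256, 128, -1]
num_state_features, num_action_features = 30, 4
num_features = num_state_features + num_action_features  # 34

reward_network_layers = deepcopy(training_layers)
reward_network_layers[0] = num_features   # input: concatenated state and action features
reward_network_layers[-1] = 1             # output: scalar reward estimate
print(reward_network_layers)              # [34, 256, 128, 1]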
Esempio n. 18
0
class ParametricDQNTrainer(DQNTrainerBase):
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters
        )
        self.num_action_features = get_num_output_features(
            action_normalization_parameters
        )
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(state_normalization_parameters.keys()) & set(
            action_normalization_parameters.keys()
        )
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: "
            + str(overlapping_features)
        )

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0
            ] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0
            ] = self.num_action_features

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
        )

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture
        )

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations
        )
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(), lr=parameters.training.learning_rate
        )

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(self.q_network_target)
                self.reward_network = torch.nn.DataParallel(self.reward_network)

    def _get_model(self, training_parameters, dueling_architecture=False):
        if dueling_architecture:
            return DuelingQNetwork(
                training_parameters.layers,
                training_parameters.activations,
                action_dim=self.num_action_features,
            )
        elif training_parameters.factorization_parameters is None:
            return FullyConnectedNetwork(
                training_parameters.layers,
                training_parameters.activations,
                use_noisy_linear_layers=training_parameters.use_noisy_linear_layers,
            )
        else:
            return ParametricInnerProduct(
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.activations,
                ),
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )

    def get_detached_q_values(
        self, state_action_pairs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Gets the q values from the model and target networks """
        with torch.no_grad():
            q_values = self.q_network(state_action_pairs)
            q_values_target = self.q_network_target(state_action_pairs)
        return q_values, q_values_target

    def train(self, training_samples: TrainingDataPage) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            max_num_actions = training_samples.possible_next_actions_mask.shape[1]

            assert (
                training_samples.states.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.states.shape)
            assert (
                training_samples.next_states.shape == training_samples.states.shape
            ), "Invalid shape: " + str(training_samples.next_states.shape)
            assert (
                training_samples.not_terminal.shape == training_samples.rewards.shape
            ), "Invalid shape: " + str(training_samples.not_terminal.shape)

            assert (
                training_samples.actions.shape[0] == self.minibatch_size
            ), "Invalid shape: " + str(training_samples.actions.shape)
            assert (
                training_samples.possible_next_actions_mask.shape[0]
                == self.minibatch_size
            ), "Invalid shape: " + str(
                training_samples.possible_next_actions_mask.shape
            )
            assert training_samples.actions.shape[1] == self.num_action_features, (
                "Invalid shape: "
                + str(training_samples.actions.shape[1])
                + " != "
                + str(self.num_action_features)
            )

            assert (
                training_samples.possible_next_actions_state_concat.shape[0]
                == self.minibatch_size * max_num_actions
            ), (
                "Invalid shape: "
                + str(training_samples.possible_next_actions_state_concat.shape)
                + " != "
                + str(self.minibatch_size * max_num_actions)
            )

        self.minibatch += 1

        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        state_action_pairs = torch.cat((states, actions), dim=1)

        rewards = training_samples.rewards
        discount_tensor = torch.full(
            training_samples.time_diffs.shape, self.gamma
        ).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                training_samples.possible_next_actions_state_concat
            )
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, _ = self.get_detached_q_values(
                torch.cat(
                    (training_samples.next_states, training_samples.next_actions), dim=1
                )
            )

        assert next_q_values.shape == not_done_mask.shape, (
            "Invalid shapes: "
            + str(next_q_values.shape)
            + " != "
            + str(not_done_mask.shape)
        )
        filtered_max_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            assert discount_tensor.shape == filtered_max_q_vals.shape, (
                "Invalid shapes: "
                + str(discount_tensor.shape)
                + " != "
                + str(filtered_max_q_vals.shape)
            )
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        all_action_scores = q_values.detach()
        self.model_values_on_logged_actions = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs)
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=all_action_scores,
        )

    def predictor(self) -> ParametricDQNPredictor:
        """Builds a ParametricDQNPredictor."""
        return ParametricDQNPredictor.export(
            self,
            self.state_normalization_parameters,
            self.action_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )

    def export(self) -> ParametricDQNPredictor:
        return self.predictor()
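
The TD target in `train` above is the reward plus the discounted next-state Q-value, zeroed out for terminal transitions (and during reward burn-in the target is the reward alone). A minimal sketch of that target with toy tensors:

import torch

gamma = 0.9
rewards = torch.tensor([[1.0], [0.5], [2.0]])
next_q_values = torch.tensor([[3.0], [4.0], [1.0]])   # max_a' Q_target(s', a')
not_done_mask = torch.tensor([[1.0], [1.0], [0.0]])   # last transition is terminal

discount_tensor = torch.full_like(rewards, gamma)
filtered_max_q_vals = next_q_values * not_done_mask
target_q_values = rewards + discount_tensor * filtered_max_q_vals
print(target_q_values)  # -> [[3.7], [4.1], [2.0]]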
Esempio n. 19
0
    def __init__(
        self,
        parameters: DiscreteActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self._actions = parameters.actions if parameters.actions is not None else []

        if parameters.training.cnn_parameters is None:
            self.state_normalization_parameters: Optional[
                Dict[int, NormalizationParameters]
            ] = state_normalization_parameters
            self.num_features = get_num_output_features(
                state_normalization_parameters)
            logger.info("Number of state features: " + str(self.num_features))
            parameters.training.layers[0] = self.num_features
        else:
            self.state_normalization_parameters = None
        parameters.training.layers[-1] = self.num_actions

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
            actions=self._actions,
        )

        self.reward_boosts = torch.zeros([1, len(self._actions)]).type(self.dtype)
        if parameters.rl.reward_boost is not None:
            for k in parameters.rl.reward_boost.keys():
                i = self._actions.index(k)
                self.reward_boosts[0, i] = parameters.rl.reward_boost[k]

        if parameters.rainbow.dueling_architecture:
            self.q_network = DuelingQNetwork(
                parameters.training.layers,
                parameters.training.activations,
                use_batch_norm=parameters.training.use_batch_norm,
            )
        else:
            if parameters.training.cnn_parameters is None:
                self.q_network = FullyConnectedNetwork(
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )
            else:
                self.q_network = ConvolutionalNetwork(
                    parameters.training.cnn_parameters,
                    parameters.training.layers,
                    parameters.training.activations,
                    use_noisy_linear_layers=parameters.training.use_noisy_linear_layers,
                    min_std=parameters.training.weight_init_min_std,
                    use_batch_norm=parameters.training.use_batch_norm,
                )

        self.q_network_target = deepcopy(self.q_network)
        self.q_network._name = "training"
        self.q_network_target._name = "target"
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self._init_cpe_networks(parameters, use_all_avail_gpus)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
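
The `reward_boosts` tensor above is filled from the optional `reward_boost` mapping in the RL parameters, one column per action. A standalone sketch with hypothetical action names and boost values:

import torch

actions = ["up", "down", "noop"]            # hypothetical discrete action set
reward_boost = {"up": 1.0, "noop": -0.5}    # hypothetical boost config

reward_boosts = torch.zeros([1, len(actions)])
for name, boost in reward_boost.items():
    reward_boosts[0, actions.index(name)] = boost
print(reward_boosts)  # -> [[1.0, 0.0, -0.5]]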
Esempio n. 20
0
class ParametricDQNTrainer(DQNTrainerBase):
    def __init__(
        self,
        parameters: ContinuousActionModelParameters,
        state_normalization_parameters: Dict[int, NormalizationParameters],
        action_normalization_parameters: Dict[int, NormalizationParameters],
        use_gpu: bool = False,
        additional_feature_types: AdditionalFeatureTypes = DEFAULT_ADDITIONAL_FEATURE_TYPES,
        metrics_to_score=None,
        gradient_handler=None,
        use_all_avail_gpus: bool = False,
    ) -> None:

        self.double_q_learning = parameters.rainbow.double_q_learning
        self.warm_start_model_path = parameters.training.warm_start_model_path
        self.minibatch_size = parameters.training.minibatch_size
        self.state_normalization_parameters = state_normalization_parameters
        self.action_normalization_parameters = action_normalization_parameters
        self.num_state_features = get_num_output_features(
            state_normalization_parameters)
        self.num_action_features = get_num_output_features(
            action_normalization_parameters)
        self.num_features = self.num_state_features + self.num_action_features

        # ensure state and action IDs have no intersection
        overlapping_features = set(
            state_normalization_parameters.keys()) & set(
                action_normalization_parameters.keys())
        assert len(overlapping_features) == 0, (
            "There are some overlapping state and action features: " +
            str(overlapping_features))

        reward_network_layers = deepcopy(parameters.training.layers)
        reward_network_layers[0] = self.num_features
        reward_network_layers[-1] = 1

        if parameters.rainbow.dueling_architecture:
            parameters.training.layers[0] = self.num_state_features
            parameters.training.layers[-1] = 1
        elif parameters.training.factorization_parameters is None:
            parameters.training.layers[0] = self.num_features
            parameters.training.layers[-1] = 1
        else:
            parameters.training.factorization_parameters.state.layers[
                0] = self.num_state_features
            parameters.training.factorization_parameters.action.layers[
                0] = self.num_action_features

        RLTrainer.__init__(
            self,
            parameters,
            use_gpu,
            additional_feature_types,
            metrics_to_score,
            gradient_handler,
        )

        self.q_network = self._get_model(
            parameters.training, parameters.rainbow.dueling_architecture)

        self.q_network_target = deepcopy(self.q_network)
        self._set_optimizer(parameters.training.optimizer)
        self.q_network_optimizer = self.optimizer_func(
            self.q_network.parameters(),
            lr=parameters.training.learning_rate,
            weight_decay=parameters.training.l2_decay,
        )

        self.reward_network = FullyConnectedNetwork(
            reward_network_layers, parameters.training.activations)
        self.reward_network_optimizer = self.optimizer_func(
            self.reward_network.parameters(),
            lr=parameters.training.learning_rate)

        if self.use_gpu:
            self.q_network.cuda()
            self.q_network_target.cuda()
            self.reward_network.cuda()

            if use_all_avail_gpus:
                self.q_network = torch.nn.DataParallel(self.q_network)
                self.q_network_target = torch.nn.DataParallel(
                    self.q_network_target)
                self.reward_network = torch.nn.DataParallel(
                    self.reward_network)

    def _get_model(self, training_parameters, dueling_architecture=False):
        if dueling_architecture:
            return DuelingQNetwork(
                training_parameters.layers,
                training_parameters.activations,
                action_dim=self.num_action_features,
            )
        elif training_parameters.factorization_parameters is None:
            return FullyConnectedNetwork(
                training_parameters.layers,
                training_parameters.activations,
                use_noisy_linear_layers=training_parameters.use_noisy_linear_layers,
            )
        else:
            return ParametricInnerProduct(
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.state.layers,
                    training_parameters.factorization_parameters.state.activations,
                ),
                FullyConnectedNetwork(
                    training_parameters.factorization_parameters.action.layers,
                    training_parameters.factorization_parameters.action.activations,
                ),
                self.num_state_features,
                self.num_action_features,
            )

    def get_detached_q_values(
            self, state_action_pairs) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Gets the q values from the model and target networks """
        with torch.no_grad():
            q_values = self.q_network(state_action_pairs)
            q_values_target = self.q_network_target(state_action_pairs)
        return q_values, q_values_target

    def train(self, training_samples: TrainingDataPage) -> None:
        if self.minibatch == 0:
            # Assume that the tensors are the right shape after the first minibatch
            max_num_actions = training_samples.possible_next_actions_mask.shape[1]

            assert (training_samples.states.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.states.shape)
            assert (training_samples.next_states.shape ==
                    training_samples.states.shape), "Invalid shape: " + str(
                        training_samples.next_states.shape)
            assert (training_samples.not_terminal.shape ==
                    training_samples.rewards.shape), "Invalid shape: " + str(
                        training_samples.not_terminal.shape)

            assert (training_samples.actions.shape[0] == self.minibatch_size
                    ), "Invalid shape: " + str(training_samples.actions.shape)
            assert (training_samples.possible_next_actions_mask.shape[0] ==
                    self.minibatch_size), "Invalid shape: " + str(
                        training_samples.possible_next_actions_mask.shape)
            assert training_samples.actions.shape[1] == self.num_action_features, (
                "Invalid shape: "
                + str(training_samples.actions.shape[1])
                + " != "
                + str(self.num_action_features)
            )

            assert (
                training_samples.possible_next_actions_state_concat.shape[0]
                == self.minibatch_size * max_num_actions
            ), (
                "Invalid shape: "
                + str(training_samples.possible_next_actions_state_concat.shape)
                + " != "
                + str(self.minibatch_size * max_num_actions)
            )

        self.minibatch += 1

        states = training_samples.states.detach().requires_grad_(True)
        actions = training_samples.actions
        state_action_pairs = torch.cat((states, actions), dim=1)

        rewards = training_samples.rewards
        discount_tensor = torch.full(training_samples.time_diffs.shape,
                                     self.gamma).type(self.dtype)
        not_done_mask = training_samples.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(training_samples.time_diffs)

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                training_samples.possible_next_actions_state_concat)
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values(
                all_next_q_values,
                all_next_q_values_target,
                training_samples.possible_next_actions_mask,
            )
        else:
            # SARSA
            next_q_values, _ = self.get_detached_q_values(
                torch.cat((training_samples.next_states,
                           training_samples.next_actions),
                          dim=1))

        assert next_q_values.shape == not_done_mask.shape, (
            "Invalid shapes: " + str(next_q_values.shape) + " != " +
            str(not_done_mask.shape))
        filtered_max_q_vals = next_q_values * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            assert discount_tensor.shape == filtered_max_q_vals.shape, (
                "Invalid shapes: " + str(discount_tensor.shape) + " != " +
                str(filtered_max_q_vals.shape))
            target_q_values = rewards + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        q_values = self.q_network(state_action_pairs)
        all_action_scores = q_values.detach()
        self.model_values_on_logged_actions = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(state_action_pairs)
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=all_action_scores,
        )

    def predictor(self) -> ParametricDQNPredictor:
        """Builds a ParametricDQNPredictor."""
        return ParametricDQNPredictor.export(
            self,
            self.state_normalization_parameters,
            self.action_normalization_parameters,
            self._additional_feature_types.int_features,
            self.use_gpu,
        )

    def export(self) -> ParametricDQNPredictor:
        return self.predictor()
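
Both parametric trainers call `_soft_update` with `tau=1.0` during reward burn-in (a hard copy) and with `self.tau` afterwards. Its implementation is not shown in these snippets; a minimal sketch of the Polyak-style blend it presumably performs, using a hypothetical `soft_update` function:

import torch
import torch.nn as nn

def soft_update(network: nn.Module, target_network: nn.Module, tau: float) -> None:
    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    with torch.no_grad():
        for param, target_param in zip(network.parameters(), target_network.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)

net, target = nn.Linear(4, 2), nn.Linear(4, 2)
soft_update(net, target, tau=1.0)                 # hard copy, as during reward burn-in
assert torch.allclose(net.weight, target.weight)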
Esempio n. 21
0
def _get_sac_trainer_params(env, sac_model_params, use_gpu):
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)
    q1_network = FullyConnectedParametricDQN(
        state_dim,
        action_dim,
        sac_model_params.q_network.layers,
        sac_model_params.q_network.activations,
    )
    q2_network = None
    if sac_model_params.training.use_2_q_functions:
        q2_network = FullyConnectedParametricDQN(
            state_dim,
            action_dim,
            sac_model_params.q_network.layers,
            sac_model_params.q_network.activations,
        )
    value_network = None
    if sac_model_params.training.use_value_network:
        value_network = FullyConnectedNetwork(
            [state_dim] + sac_model_params.value_network.layers + [1],
            sac_model_params.value_network.activations + ["linear"],
        )
    actor_network = GaussianFullyConnectedActor(
        state_dim,
        action_dim,
        sac_model_params.actor_network.layers,
        sac_model_params.actor_network.activations,
    )

    min_action_range_tensor_training = torch.full((1, action_dim), -1 + 1e-6)
    max_action_range_tensor_training = torch.full((1, action_dim), 1 - 1e-6)
    min_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.low).float().unsqueeze(dim=0)
    )
    max_action_range_tensor_serving = (
        torch.from_numpy(env.action_space.high).float().unsqueeze(dim=0)
    )

    if use_gpu:
        q1_network.cuda()
        if q2_network:
            q2_network.cuda()
        if value_network:
            value_network.cuda()
        actor_network.cuda()

        min_action_range_tensor_training = min_action_range_tensor_training.cuda()
        max_action_range_tensor_training = max_action_range_tensor_training.cuda()
        min_action_range_tensor_serving = min_action_range_tensor_serving.cuda()
        max_action_range_tensor_serving = max_action_range_tensor_serving.cuda()

    trainer_args = [q1_network, actor_network, sac_model_params]
    trainer_kwargs = {
        "value_network": value_network,
        "q2_network": q2_network,
        "min_action_range_tensor_training": min_action_range_tensor_training,
        "max_action_range_tensor_training": max_action_range_tensor_training,
        "min_action_range_tensor_serving": min_action_range_tensor_serving,
        "max_action_range_tensor_serving": max_action_range_tensor_serving,
    }
    return trainer_args, trainer_kwargs
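
The training range tensors above sit just inside (-1, 1), matching a tanh-squashed Gaussian actor, while the serving range comes from `env.action_space`. How the trainer maps between the two is not shown here; a sketch of the affine rescaling presumably intended, with a hypothetical `rescale_actions` helper and made-up serving bounds:

import torch

def rescale_actions(actions, from_low, from_high, to_low, to_high):
    # Element-wise affine map from [from_low, from_high] to [to_low, to_high].
    scale = (to_high - to_low) / (from_high - from_low)
    return to_low + (actions - from_low) * scale

min_train = torch.full((1, 2), -1 + 1e-6)
max_train = torch.full((1, 2), 1 - 1e-6)
min_serve = torch.tensor([[0.0, -5.0]])   # hypothetical env.action_space.low
max_serve = torch.tensor([[2.0, 5.0]])    # hypothetical env.action_space.high

squashed = torch.tensor([[0.0, 1.0 - 1e-6]])  # e.g. tanh outputs from the actor
print(rescale_actions(squashed, min_train, max_train, min_serve, max_serve))
# approximately [[1.0, 5.0]]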
Esempio n. 22
0
    def get_sac_trainer(
        self,
        env,
        use_gpu,
        use_2_q_functions=False,
        logged_action_uniform_prior=True,
        constrain_action_sum=False,
        use_value_network=True,
    ):
        q_network_params = FeedForwardParameters(layers=[128, 64],
                                                 activations=["relu", "relu"])
        value_network_params = FeedForwardParameters(
            layers=[128, 64], activations=["relu", "relu"])
        actor_network_params = FeedForwardParameters(
            layers=[128, 64], activations=["relu", "relu"])

        state_dim = get_num_output_features(env.normalization)
        action_dim = get_num_output_features(
            env.normalization_continuous_action)
        q1_network = FullyConnectedParametricDQN(state_dim, action_dim,
                                                 q_network_params.layers,
                                                 q_network_params.activations)
        q2_network = None
        if use_2_q_functions:
            q2_network = FullyConnectedParametricDQN(
                state_dim,
                action_dim,
                q_network_params.layers,
                q_network_params.activations,
            )
        if constrain_action_sum:
            actor_network = DirichletFullyConnectedActor(
                state_dim,
                action_dim,
                actor_network_params.layers,
                actor_network_params.activations,
            )
        else:
            actor_network = GaussianFullyConnectedActor(
                state_dim,
                action_dim,
                actor_network_params.layers,
                actor_network_params.activations,
            )

        value_network = None
        if use_value_network:
            value_network = FullyConnectedNetwork(
                [state_dim] + value_network_params.layers + [1],
                value_network_params.activations + ["linear"],
            )

        if use_gpu:
            q1_network.cuda()
            if q2_network:
                q2_network.cuda()
            if value_network:
                value_network.cuda()
            actor_network.cuda()

        parameters = SACTrainerParameters(
            rl=RLParameters(gamma=DISCOUNT, target_update_rate=0.5),
            minibatch_size=self.minibatch_size,
            q_network_optimizer=OptimizerParameters(),
            value_network_optimizer=OptimizerParameters(),
            actor_network_optimizer=OptimizerParameters(),
            alpha_optimizer=OptimizerParameters(),
            logged_action_uniform_prior=logged_action_uniform_prior,
        )

        return SACTrainer(
            q1_network,
            actor_network,
            parameters,
            use_gpu=use_gpu,
            value_network=value_network,
            q2_network=q2_network,
        )
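
The value network above is sized by wrapping the configured hidden layers with the state width on the input side and a scalar output, and by appending a linear output activation. A tiny worked sketch with hypothetical dimensions:

# Hypothetical dimensions matching the FeedForwardParameters used above.
state_dim = 8
value_layers = [128, 64]
value_activations = ["relu", "relu"]

layers = [state_dim] + value_layers + [1]
activations = value_activations + ["linear"]
print(layers)       # [8, 128, 64, 1]
print(activations)  # ['relu', 'relu', 'linear']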