Example 1
 def forward(self, input):
     q_values = self.q_network(input)
     # Score each action with the imitator (behavior-cloning) network.
     imitator_outputs = self.imitator_network(input.state.float_features)
     imitator_probs = torch.nn.functional.softmax(imitator_outputs, dim=1)
     # Normalize by the per-state maximum so the best-supported action maps to 1.0.
     filter_values = imitator_probs / imitator_probs.max(keepdim=True, dim=1)[0]
     # Actions the imitator considers unlikely receive a large negative penalty.
     invalid_actions = (filter_values < self.bcq_drop_threshold).float()
     invalid_action_penalty = self.invalid_action_penalty * invalid_actions
     constrained_q_values = q_values.q_values + invalid_action_penalty
     return rlt.AllActionQValues(q_values=constrained_q_values)
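
The example above is a batch-constrained (BCQ-style) filter: each action's imitator probability is normalized by the per-state maximum, and actions that fall below self.bcq_drop_threshold are pushed down by a large negative penalty so the greedy policy cannot select them. A minimal self-contained sketch of just that masking arithmetic, with illustrative threshold, penalty, and tensor values that are not taken from the example:

 import torch
 import torch.nn.functional as F

 # Illustrative hyperparameters; in the example they live on the module.
 bcq_drop_threshold = 0.1
 invalid_action_penalty = -1e10

 q_values = torch.tensor([[1.0, 3.0, 2.0]])          # raw Q-values, shape (batch, actions)
 imitator_logits = torch.tensor([[4.0, 0.0, 3.0]])   # imitator network output

 imitator_probs = F.softmax(imitator_logits, dim=1)
 # Normalize each row by its maximum so the best-supported action maps to 1.0.
 filter_values = imitator_probs / imitator_probs.max(keepdim=True, dim=1)[0]
 invalid_actions = (filter_values < bcq_drop_threshold).float()
 constrained_q_values = q_values + invalid_action_penalty * invalid_actions

 # Without the mask the argmax would pick action 1; with it, action 2 wins.
 print(constrained_q_values.argmax(dim=1))  # tensor([2])
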
Example 2
 def forward(self, input: rlt.PreprocessedState):
     # Average over the last (distribution) dimension to get one scalar Q-value per action.
     q_values = self.dist(input).mean(dim=2)
     return rlt.AllActionQValues(q_values=q_values)
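
Here self.dist(input) presumably returns a per-action distribution (for example quantiles, as in QR-DQN) of shape (batch, num_actions, num_quantiles), so the mean over dim=2 collapses it to one Q-value per action. A shape-level sketch under that assumption:

 import torch

 batch, num_actions, num_quantiles = 4, 3, 51
 dist = torch.randn(batch, num_actions, num_quantiles)  # stand-in for self.dist(input)
 q_values = dist.mean(dim=2)                             # shape (batch, num_actions)
 assert q_values.shape == (batch, num_actions)
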
Example 3
    def forward(self,
                input) -> Union[NamedTuple, torch.FloatTensor]:  # type: ignore
        # When True, return the raw Q-value tensor instead of a named tuple (checked at the end).
        output_tensor = False
        if self.parametric_action:
            state = input.state.float_features
            action = input.action.float_features
        else:
            state = input.state.float_features
            action = None

        # Shared trunk: optional batch norm, then a linear layer and the configured activation.
        x = state
        for i, activation in enumerate(self.activations[:-1]):
            if self.use_batch_norm:
                x = self.batch_norm_ops[i](x)

            x = self.layers[i](x)
            if activation == "linear":
                continue
            elif activation == "tanh":
                activation_func = torch.tanh
            else:
                activation_func = getattr(F, activation)
            x = activation_func(x)

        # Dueling heads: state value V(s) and action advantage A(s, a).
        value = self.value(x)
        if action is not None:
            # Parametric actions are concatenated onto the features before the advantage head.
            x = torch.cat((x, action), dim=1)
        raw_advantage = self.advantage(x)
        if self.parametric_action:
            advantage = raw_advantage
        else:
            # Mean-center the advantages over actions so V and A stay identifiable.
            advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)

        q_value = value + advantage

        # Periodically log value / advantage / Q-value statistics via the summary writer.
        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/value".format(self._name),
                value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_value".format(self._name),
                value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/q_value".format(self._name),
                q_value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_q_value".format(self._name),
                q_value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/raw_advantage".format(self._name),
                raw_advantage.detach().cpu(),
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_raw_advantage".format(self._name),
                raw_advantage.detach().mean().cpu(),
            )
            if not self.parametric_action:
                advantage = advantage.detach()
                for i in range(advantage.shape[1]):
                    a = advantage[:, i]
                    SummaryWriterContext.add_histogram(
                        "dueling_network/{}/advantage/{}".format(
                            self._name, i), a.cpu())
                    SummaryWriterContext.add_scalar(
                        "dueling_network/{}/mean_advantage/{}".format(
                            self._name, i),
                        a.mean().cpu(),
                    )

        if output_tensor:
            return q_value  # type: ignore
        elif self.parametric_action:
            return rlt.SingleQValue(q_value=q_value)  # type: ignore
        else:
            return rlt.AllActionQValues(q_values=q_value)  # type: ignore
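
The heart of this example is the dueling decomposition Q(s, a) = V(s) + A(s, a), where in the discrete-action case the advantage is mean-centered over actions so the value and advantage heads stay identifiable. A minimal sketch of just that arithmetic, with illustrative tensor shapes:

 import torch

 batch, num_actions = 2, 4
 value = torch.randn(batch, 1)                     # V(s): one scalar per state
 raw_advantage = torch.randn(batch, num_actions)   # A(s, a): one entry per action

 # Subtracting the per-state mean forces the advantages to sum to zero,
 # so the value head alone carries the average Q-value of the state.
 advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)
 q_value = value + advantage                       # V(s) broadcasts over the action dimension

 assert torch.allclose(q_value.mean(dim=1, keepdim=True), value, atol=1e-6)
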
Example 4
 def forward(self, input: rlt.PreprocessedState):
     # Convert log-probabilities back to probabilities over the support atoms.
     dist = self.log_dist(input).exp()
     # The Q-value is the expectation of the distribution: sum of probability * atom value.
     q_values = (dist * self.support).sum(2)
     return rlt.AllActionQValues(q_values=q_values)
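
This matches the categorical (C51-style) case: self.log_dist(input) presumably yields log-probabilities over a fixed support of return atoms, and the Q-value is the expectation of that distribution. A small sketch under that assumption, with an illustrative support range:

 import torch

 batch, num_actions, num_atoms = 2, 3, 51
 support = torch.linspace(-10.0, 10.0, num_atoms)   # fixed return atoms
 log_dist = torch.log_softmax(torch.randn(batch, num_actions, num_atoms), dim=2)

 dist = log_dist.exp()                              # probabilities over atoms, per action
 q_values = (dist * support).sum(2)                 # expected return, shape (batch, num_actions)
 assert q_values.shape == (batch, num_actions)
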
Example 5
 def forward(self, input):
     # Forward the state features through the data-parallel-wrapped Q-network.
     q_values = self.data_parallel(input.state.float_features)
     return rlt.AllActionQValues(q_values=q_values)
Example 6
 def forward(self, input: rlt.PreprocessedState):
     # A plain fully connected network mapping state features to per-action Q-values.
     q_values = self.fc(input.state.float_features)
     return rlt.AllActionQValues(q_values=q_values)