Example #1
 def forward(self, input):
     q_values = self.q_network(input)
     # Imitator (behavior-cloning) network: how likely each action is under the logged policy.
     imitator_outputs = self.imitator_network(input.state.float_features)
     imitator_probs = torch.nn.functional.softmax(imitator_outputs, dim=1)
     # Normalize by the most likely action; actions far below it are treated as invalid.
     filter_values = imitator_probs / imitator_probs.max(keepdim=True, dim=1)[0]
     invalid_actions = (filter_values < self.bcq_drop_threshold).float()
     # self.invalid_action_penalty is presumably a large negative value, so filtered
     # actions can never end up with the highest constrained Q-value.
     invalid_action_penalty = self.invalid_action_penalty * invalid_actions
     constrained_q_values = q_values.q_values + invalid_action_penalty
     return rlt.AllActionQValues(q_values=constrained_q_values)
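For context, here is a minimal self-contained sketch of the filtering step above, using made-up Q-values, imitator logits, a drop threshold of 0.1 and a penalty of -1e10 (all illustrative values, not the project's defaults):

    import torch

    q_values = torch.tensor([[1.0, 3.0, 2.0, 0.5]])           # batch of 1 state, 4 actions
    imitator_outputs = torch.tensor([[2.0, 0.0, -1.0, 1.5]])   # behavior-cloning logits
    bcq_drop_threshold = 0.1
    invalid_action_penalty = -1e10

    imitator_probs = torch.nn.functional.softmax(imitator_outputs, dim=1)
    filter_values = imitator_probs / imitator_probs.max(keepdim=True, dim=1)[0]
    invalid_actions = (filter_values < bcq_drop_threshold).float()
    # Actions the imitator considers very unlikely get an enormous negative Q-value.
    constrained_q_values = q_values + invalid_action_penalty * invalid_actions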
Example #2
 def forward(self, input: rlt.PreprocessedState):
     embeddings = [
         m(
             input.state.id_list_features[name][1],
             input.state.id_list_features[name][0],
         ) for name, m in self.embedding_bags.items()
     ]
     fc_input = torch.cat(embeddings + [input.state.float_features], dim=1)
     q_values = self.fc(fc_input)
     return rlt.AllActionQValues(q_values=q_values)
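A small standalone illustration of the EmbeddingBag call pattern above (toy sizes, not Horizon code); it assumes id_list_features[name] is an (offsets, values) pair, which is why the snippet passes index [1] (the values) first and [0] (the offsets) second:

    import torch
    import torch.nn as nn

    bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
    values = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])   # flat id list for the whole batch
    offsets = torch.tensor([0, 4])                     # row i starts at values[offsets[i]]
    pooled = bag(values, offsets)                      # shape (2, 4): one pooled vector per row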
Example #3
    def forward(self,
                input) -> Union[NamedTuple, torch.FloatTensor]:  # type: ignore
        output_tensor = False  # never set to True in this variant; the plain-tensor return path below is unused
        if self.parametric_action:
            state = input.state.float_features
            action = input.action.float_features
        else:
            state = input.state.float_features
            action = None

        x = state
        for i, activation in enumerate(self.activations[:-1]):
            if self.use_batch_norm:
                x = self.batch_norm_ops[i](x)

            x = self.layers[i](x)
            if activation == "linear":
                continue
            elif activation == "tanh":
                activation_func = torch.tanh
            else:
                activation_func = getattr(F, activation)
            x = activation_func(x)

        value = self.value(x)
        if action is not None:
            x = torch.cat((x, action), dim=1)
        raw_advantage = self.advantage(x)
        if self.parametric_action:
            advantage = raw_advantage
        else:
            advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)

        q_value = value + advantage

        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/value".format(self._name),
                value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_value".format(self._name),
                value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/q_value".format(self._name),
                q_value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_q_value".format(self._name),
                q_value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/raw_advantage".format(self._name),
                raw_advantage.detach().cpu(),
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_raw_advantage".format(self._name),
                raw_advantage.detach().mean().cpu(),
            )
            if not self.parametric_action:
                for i in range(advantage.shape[1]):
                    a = advantage.detach()[:, i]
                    SummaryWriterContext.add_histogram(
                        "dueling_network/{}/advantage/{}".format(
                            self._name, i), a.cpu())
                    SummaryWriterContext.add_scalar(
                        "dueling_network/{}/mean_advantage/{}".format(
                            self._name, i),
                        a.mean().cpu(),
                    )

        if output_tensor:
            return q_value  # type: ignore
        elif self.parametric_action:
            return rlt.SingleQValue(q_value=q_value)  # type: ignore
        else:
            return rlt.AllActionQValues(q_values=q_value)  # type: ignore
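A toy, self-contained illustration of the dueling combination used above: the scalar state value broadcasts across actions, and for discrete actions the advantages are mean-centered so the value/advantage decomposition is identifiable:

    import torch

    value = torch.tensor([[2.0]])                      # shape (batch, 1)
    raw_advantage = torch.tensor([[1.0, -1.0, 3.0]])   # shape (batch, num_actions)
    advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)
    q_value = value + advantage                        # tensor([[2., 0., 4.]])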
Example #4
File: dqn.py Project: xaxis-code/Horizon
 def forward(self, input):
     q_values = self.data_parallel(input.state.float_features)
     return rlt.AllActionQValues(q_values=q_values)
Example #5
 def forward(self, input: rlt.PreprocessedState):
     dist = self.log_dist(input).exp()
     q_values = (dist * self.support).sum(2)
     return rlt.AllActionQValues(q_values=q_values)
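An assumed reading of this head as a C51-style distributional network: log_dist(...).exp() gives a per-action categorical distribution over a fixed support of return values, and the Q-value is its expectation. A toy example with made-up numbers:

    import torch

    support = torch.tensor([-1.0, 0.0, 1.0])               # atom values (fixed support)
    dist = torch.tensor([[[0.2, 0.3, 0.5],                  # shape (batch, actions, atoms)
                          [0.6, 0.3, 0.1]]])
    q_values = (dist * support).sum(2)                      # tensor([[ 0.3000, -0.5000]])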
Example #6
 def forward(self, input: rlt.PreprocessedState):
     q_values = self.fc(input.state.float_features)
     return rlt.AllActionQValues(q_values=q_values)
Example #7
 def forward(self, input: rlt.PreprocessedState):
     q_values = self.dist(input).mean(dim=2)
     return rlt.AllActionQValues(q_values=q_values)
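If, as the .mean(dim=2) suggests, self.dist returns per-action quantile estimates (a QR-DQN-style head, which is an assumption here), the Q-value is simply their average:

    import torch

    quantiles = torch.tensor([[[0.0, 1.0, 2.0],             # shape (batch, actions, quantiles)
                               [1.0, 1.0, 1.0]]])
    q_values = quantiles.mean(dim=2)                         # tensor([[1., 1.]])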
Example #8
    def forward(self, input) -> torch.FloatTensor:
        output_tensor = False
        if isinstance(input, torch.Tensor):
            # Maintaining backward compatibility for a bit
            state_dim = self.layers[0].in_features
            state = input[:, :state_dim]
            action = input[:, state_dim:]
            output_tensor = True
        elif self.parametric_action:
            state = input.state.float_features
            action = input.action.float_features
        else:
            state = input.state.float_features
            action = None

        x = state
        for i, activation in enumerate(self.activations[:-1]):
            if self.use_batch_norm:
                x = self.batch_norm_ops[i](x)
            activation_func = getattr(F, activation)
            fc_func = self.layers[i]
            x = fc_func(x) if activation == "linear" else activation_func(fc_func(x))

        value = self.value(x)
        if action is not None:
            x = torch.cat((x, action), dim=1)
        raw_advantage = self.advantage(x)
        if self.parametric_action:
            advantage = raw_advantage
        else:
            advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)

        q_value = value + advantage

        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/value".format(self._name), value.detach().cpu()
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_value".format(self._name),
                value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/q_value".format(self._name), q_value.detach().cpu()
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_q_value".format(self._name),
                q_value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/raw_advantage".format(self._name),
                raw_advantage.detach().cpu(),
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_raw_advantage".format(self._name),
                raw_advantage.detach().mean().cpu(),
            )
            if not self.parametric_action:
                for i in range(advantage.shape[1]):
                    a = advantage.detach()[:, i]
                    SummaryWriterContext.add_histogram(
                        "dueling_network/{}/advatage/{}".format(self._name, i), a.cpu()
                    )
                    SummaryWriterContext.add_scalar(
                        "dueling_network/{}/mean_advatage/{}".format(self._name, i),
                        a.mean().cpu(),
                    )

        if output_tensor:
            return q_value
        elif self.parametric_action:
            return rlt.SingleQValue(q_value=q_value)
        else:
            return rlt.AllActionQValues(q_values=q_value)