Code example #1
File: ddpg_trainer.py  Project: xavierzw/Horizon
    def forward(self, input: rlt.StateAction) -> rlt.SingleQValue:
        """Forward pass for the critic network. Assumes activation names are
        valid PyTorch activation names.

        :param input: ml.rl.types.StateAction of combined states and actions
        """
        return rlt.SingleQValue(
            q_value=self.network.forward(
                [input.state.float_features, input.action.float_features]
            )
        )
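As a rough usage sketch: assuming the rlt.FeatureVector and rlt.StateAction types from ml.rl.types can be built with the keyword fields accessed above (float_features, state, action), and given some already-constructed critic instance (here called critic, a hypothetical name), the call would look roughly like this. The feature dimensions are made up for illustration.

    import torch
    from ml.rl import types as rlt

    # Hypothetical batch of 32 states (dim 10) and actions (dim 4); the real
    # dimensions depend on how the critic network was constructed.
    states = rlt.FeatureVector(float_features=torch.randn(32, 10))
    actions = rlt.FeatureVector(float_features=torch.randn(32, 4))

    output = critic.forward(rlt.StateAction(state=states, action=actions))
    print(output.q_value.shape)  # expected: one Q-value per batch row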
Code example #2
File: dueling_q_network.py  Project: joshrose/Horizon
    def forward(self,
                input) -> Union[NamedTuple, torch.FloatTensor]:  # type: ignore
        """Dueling Q-network forward pass: compute V(s) and A(s, a), combine
        them into Q-values, and log TensorBoard summaries every 1000 steps."""
        output_tensor = False
        if self.parametric_action:
            state = input.state.float_features
            action = input.action.float_features
        else:
            state = input.state.float_features
            action = None

        x = state
        for i, activation in enumerate(self.activations[:-1]):
            if self.use_batch_norm:
                x = self.batch_norm_ops[i](x)

            x = self.layers[i](x)
            if activation == "linear":
                continue
            elif activation == "tanh":
                activation_func = torch.tanh
            else:
                activation_func = getattr(F, activation)
            x = activation_func(x)

        value = self.value(x)
        if action is not None:
            x = torch.cat((x, action), dim=1)
        raw_advantage = self.advantage(x)
        if self.parametric_action:
            advantage = raw_advantage
        else:
            advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)

        q_value = value + advantage

        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/value".format(self._name),
                value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_value".format(self._name),
                value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/q_value".format(self._name),
                q_value.detach().cpu())
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_q_value".format(self._name),
                q_value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/raw_advantage".format(self._name),
                raw_advantage.detach().cpu(),
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_raw_advantage".format(self._name),
                raw_advantage.detach().mean().cpu(),
            )
            if not self.parametric_action:
                for i in range(advantage.shape[1]):
                    a = advantage.detach()[:, i]
                    SummaryWriterContext.add_histogram(
                        "dueling_network/{}/advantage/{}".format(
                            self._name, i), a.cpu())
                    SummaryWriterContext.add_scalar(
                        "dueling_network/{}/mean_advantage/{}".format(
                            self._name, i),
                        a.mean().cpu(),
                    )

        if output_tensor:
            return q_value  # type: ignore
        elif self.parametric_action:
            return rlt.SingleQValue(q_value=q_value)  # type: ignore
        else:
            return rlt.AllActionQValues(q_values=q_value)  # type: ignore
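The core of this forward pass is the dueling decomposition Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)) used in the non-parametric (discrete-action) branch. A minimal standalone sketch of that decomposition with plain tensors and hypothetical value/advantage heads (not the classes above):

    import torch
    import torch.nn as nn

    state_dim, num_actions, hidden = 10, 4, 64
    shared = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())
    value_head = nn.Linear(hidden, 1)                 # V(s)
    advantage_head = nn.Linear(hidden, num_actions)   # A(s, a) for every action

    x = shared(torch.randn(32, state_dim))
    value = value_head(x)
    raw_advantage = advantage_head(x)
    # Subtract the mean advantage so V and A are identifiable, exactly as in
    # the non-parametric branch of the forward pass above.
    advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)
    q_values = value + advantage                      # shape: (32, num_actions)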
Code example #3
File: parametric_dqn.py  Project: ananthc/ReAgent
    def forward(self, input):
        """Concatenate state and action features and score them with the
        fully connected network, returning one Q-value per row."""
        cat_input = torch.cat(
            (input.state.float_features, input.action.float_features), dim=1
        )
        q_value = self.fc(cat_input)
        return rlt.SingleQValue(q_value=q_value)
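The parametric DQN critic simply concatenates state and action features along the feature dimension and pushes them through one fully connected network. A hedged sketch with dummy tensors, where fc stands in for whatever MLP the real class builds:

    import torch
    import torch.nn as nn

    state_dim, action_dim = 10, 4
    fc = nn.Sequential(
        nn.Linear(state_dim + action_dim, 64), nn.ReLU(), nn.Linear(64, 1)
    )

    state_features = torch.randn(32, state_dim)
    action_features = torch.randn(32, action_dim)
    cat_input = torch.cat((state_features, action_features), dim=1)  # (32, 14)
    q_value = fc(cat_input)                                          # (32, 1)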
Code example #4
    def forward(self, input) -> torch.FloatTensor:
        output_tensor = False
        if isinstance(input, torch.Tensor):
            # Maintaining backward compatibility for a bit
            state_dim = self.layers[0].in_features
            state = input[:, :state_dim]
            action = input[:, state_dim:]
            output_tensor = True
        elif self.parametric_action:
            state = input.state.float_features
            action = input.action.float_features
        else:
            state = input.state.float_features
            action = None

        x = state
        for i, activation in enumerate(self.activations[:-1]):
            if self.use_batch_norm:
                x = self.batch_norm_ops[i](x)
            activation_func = getattr(F, activation)
            fc_func = self.layers[i]
            x = fc_func(x) if activation == "linear" else activation_func(fc_func(x))

        value = self.value(x)
        if action is not None:
            x = torch.cat((x, action), dim=1)
        raw_advantage = self.advantage(x)
        if self.parametric_action:
            advantage = raw_advantage
        else:
            advantage = raw_advantage - raw_advantage.mean(dim=1, keepdim=True)

        q_value = value + advantage

        if SummaryWriterContext._global_step % 1000 == 0:
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/value".format(self._name), value.detach().cpu()
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_value".format(self._name),
                value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/q_value".format(self._name), q_value.detach().cpu()
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_q_value".format(self._name),
                q_value.detach().mean().cpu(),
            )
            SummaryWriterContext.add_histogram(
                "dueling_network/{}/raw_advantage".format(self._name),
                raw_advantage.detach().cpu(),
            )
            SummaryWriterContext.add_scalar(
                "dueling_network/{}/mean_raw_advantage".format(self._name),
                raw_advantage.detach().mean().cpu(),
            )
            if not self.parametric_action:
                for i in range(advantage.shape[1]):
                    a = advantage.detach()[:, i]
                    SummaryWriterContext.add_histogram(
                        "dueling_network/{}/advatage/{}".format(self._name, i), a.cpu()
                    )
                    SummaryWriterContext.add_scalar(
                        "dueling_network/{}/mean_advatage/{}".format(self._name, i),
                        a.mean().cpu(),
                    )

        if output_tensor:
            return q_value
        elif self.parametric_action:
            return rlt.SingleQValue(q_value=q_value)
        else:
            return rlt.AllActionQValues(q_values=q_value)
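SummaryWriterContext appears to be Horizon's wrapper around a TensorBoard writer that tracks a global step; the throttling pattern above (heavy histogram/scalar logging only every 1000 steps) can be reproduced with the standard torch.utils.tensorboard.SummaryWriter, as in this sketch. The log directory, tag names, and the random q_value tensor are placeholders.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs/dueling_q")  # hypothetical log directory

    for global_step in range(3001):
        q_value = torch.randn(32, 4)  # stand-in for the q_value computed above
        if global_step % 1000 == 0:
            # Same throttling as the forward pass: log histograms and scalars
            # only once every 1000 steps to keep overhead low.
            writer.add_histogram(
                "dueling_network/q_value", q_value.detach().cpu(), global_step
            )
            writer.add_scalar(
                "dueling_network/mean_q_value",
                q_value.detach().mean().item(),
                global_step,
            )
    writer.close()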