    def __init__(self,
                 n_actions,
                 n_input_channels=4,
                 activation=F.relu,
                 bias=0.1,
                 reward_boundaries=None,
                 reward_channel_scale=1.):
        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation
        # Scaled reward-bucket boundaries; the small epsilon presumably keeps
        # rewards that land exactly on a boundary in the lower bucket.
        # np.asarray with an empty-list fallback keeps the default None usable
        # (no boundaries degenerates to a single stream).
        self.boundaries = torch.from_numpy(
            np.asarray([] if reward_boundaries is None else reward_boundaries,
                       dtype=np.float32)) * reward_channel_scale - 1e-8

        super().__init__()
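        # Nature-DQN style convolutional torso (Mnih et al., 2015).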
        self.conv_layers = nn.ModuleList([
            nn.Conv2d(n_input_channels, 32, 8, stride=4),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.Conv2d(64, 64, 3, stride=1),
        ])

        # Flattened conv-feature size modified from 3136 (64*7*7 for 84x84
        # frames) to 1024 (64*4*4, e.g. for 60x60 frames).
        # One dueling (advantage, value) stream pair per reward bucket:
        # len(boundaries) boundaries give len(boundaries) + 1 buckets.
        self.a_streams = nn.ModuleList([
            MLP(1024, n_actions, [512])
            for _ in range(len(self.boundaries) + 1)
        ])
        self.v_streams = nn.ModuleList(
            [MLP(1024, 1, [512]) for _ in range(len(self.boundaries) + 1)])

        self.conv_layers.apply(init_chainer_default)  # MLP already applies it internally
        self.conv_layers.apply(constant_bias_initializer(bias=bias))
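
    # Hedged sketch (not from the original source): a forward pass consistent
    # with the attributes above. Each (a_stream, v_stream) pair is assumed to
    # be one dueling head for one reward bucket delimited by self.boundaries;
    # how the per-bucket Q-values are combined downstream is an assumption.
    def forward(self, x):
        h = x
        for layer in self.conv_layers:
            h = self.activation(layer(h))
        h = h.reshape(h.shape[0], -1)  # flatten to (batch, 1024)
        qs = []
        for a_stream, v_stream in zip(self.a_streams, self.v_streams):
            advantage = a_stream(h)    # (batch, n_actions)
            value = v_stream(h)        # (batch, 1)
            # Standard dueling aggregation: Q = V + A - mean(A)
            qs.append(value + advantage - advantage.mean(dim=1, keepdim=True))
        return torch.stack(qs, dim=1)  # (batch, n_buckets, n_actions)
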
    def __init__(self,
                 n_actions,
                 n_input_channels=4,
                 activation=F.relu,
                 bias=0.1):
        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation

        super().__init__()
        self.conv_layers = nn.ModuleList([
            nn.Conv2d(n_input_channels, 32, 8, stride=4),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.Conv2d(64, 64, 3, stride=1),
        ])

        # Flattened conv-feature size modified from 3136 (64*7*7 for 84x84
        # frames) to 1024 (64*4*4, e.g. for 60x60 frames).
        self.a_stream = MLP(1024, n_actions, [512])
        self.v_stream = MLP(1024, 1, [512])

        self.conv_layers.apply(init_chainer_default)  # MLP already applies it internally
        self.conv_layers.apply(constant_bias_initializer(bias=bias))
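
    # Hedged sketch (not from the original source): the usual dueling forward
    # pass matching the attributes above; the original class may instead wrap
    # the result in an action-value object (e.g. pfrl's DiscreteActionValue).
    def forward(self, x):
        h = x
        for layer in self.conv_layers:
            h = self.activation(layer(h))
        h = h.reshape(h.shape[0], -1)   # (batch, 1024)
        advantage = self.a_stream(h)    # (batch, n_actions)
        value = self.v_stream(h)        # (batch, 1)
        # Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return value + advantage - advantage.mean(dim=1, keepdim=True)
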
    def __init__(
        self,
        n_actions,
        n_atoms,
        v_min,
        v_max,
        n_input_channels=4,
        activation=torch.relu,
        bias=0.1,
    ):
        assert n_atoms >= 2
        assert v_min < v_max

        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation
        self.n_atoms = n_atoms

        super().__init__()
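        # Fixed support atoms of the categorical return distribution (C51).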
        self.z_values = torch.linspace(v_min,
                                       v_max,
                                       n_atoms,
                                       dtype=torch.float32)

        self.conv_layers = nn.ModuleList([
            nn.Conv2d(n_input_channels, 32, 8, stride=4),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.Conv2d(64, 64, 3, stride=1),
        ])

        # Only this part needed to change (3136 -> 1024, as above).
        # self.main_stream = nn.Linear(3136, 1024)
        self.main_stream = nn.Linear(1024, 1024)
        self.a_stream = nn.Linear(512, n_actions * n_atoms)
        self.v_stream = nn.Linear(512, n_atoms)

        self.apply(init_chainer_default)
        self.conv_layers.apply(constant_bias_initializer(bias=bias))
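
    # Hedged sketch (not from the original source): a C51-style dueling forward
    # pass consistent with the layer sizes above. The 1024 main_stream outputs
    # are split into two 512-wide halves, one for a_stream and one for
    # v_stream, which is why both take 512 inputs.
    def forward(self, x):
        h = x
        for layer in self.conv_layers:
            h = self.activation(layer(h))
        h = self.activation(self.main_stream(h.reshape(h.shape[0], -1)))
        h_a, h_v = torch.chunk(h, 2, dim=1)  # two (batch, 512) halves
        advantage = self.a_stream(h_a).reshape(-1, self.n_actions, self.n_atoms)
        value = self.v_stream(h_v).reshape(-1, 1, self.n_atoms)
        # Dueling aggregation per atom, then softmax over atoms so each action
        # gets a probability distribution supported on self.z_values.
        logits = value + advantage - advantage.mean(dim=1, keepdim=True)
        return torch.softmax(logits, dim=2)  # (batch, n_actions, n_atoms)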