Example #1
0
    def __call__(self, x):
        """Compute per-action atom probabilities for a batch of inputs.

        The input is passed through every layer in ``self.conv_layers``,
        each followed by ``self.activation``.  The resulting features feed
        an advantage stream and a state-value stream; their outputs are
        summed per atom and normalised with a softmax over the atom axis.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size.

        Returns:
            An ``action_value.DistributionalDiscreteActionValue`` built from
            the ``(batch, n_actions, n_atoms)`` probabilities and
            ``self.z_values``.
        """
        features = x
        for layer in self.conv_layers:
            features = self.activation(layer(features))

        n_batch = x.shape[0]
        per_action_shape = (n_batch, self.n_actions, self.n_atoms)
        per_state_shape = (n_batch, 1, self.n_atoms)

        # Advantage: one row of atom logits per action, centred by removing
        # the mean taken over the action axis.
        advantage = F.reshape(self.a_stream(features), per_action_shape)
        action_mean = F.reshape(
            F.sum(advantage, axis=1) / self.n_actions, per_state_shape)
        advantage, action_mean = F.broadcast(advantage, action_mean)
        advantage = advantage - action_mean

        # State value: a single row of atom logits shared by all actions.
        state_value = F.reshape(self.v_stream(features), per_state_shape)
        advantage, state_value = F.broadcast(advantage, state_value)

        # Keep the explicit (batch, n_actions, n_atoms) layout before
        # normalising each action's logits over the atom axis.
        logits = F.reshape(
            advantage + state_value, (-1, self.n_actions, self.n_atoms))
        atom_probs = F.softmax(logits, axis=2)

        return action_value.DistributionalDiscreteActionValue(
            atom_probs, self.z_values)
    def __call__(self, x):
        """Compute per-action atom probabilities for a batch of inputs.

        The input is passed through every layer in ``self.state_layers`` and
        then ``self.main_stream``, each followed by ``self.activation``.  The
        result is split in two along the last axis: one half feeds the
        advantage stream and the other the state-value stream.  Their outputs
        are summed per atom and normalised with a softmax over the atom axis.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size.

        Returns:
            An ``action_value.DistributionalDiscreteActionValue`` built from
            the ``(batch, n_actions, n_atoms)`` probabilities and
            ``self.z_values``.
        """
        features = x
        for layer in self.state_layers:
            features = self.activation(layer(features))

        n_batch = x.shape[0]

        # Shared trunk, then split its output evenly between the two heads.
        features = self.activation(self.main_stream(features))
        advantage_in, value_in = F.split_axis(features, 2, axis=-1)

        # Advantage: one row of atom logits per action, centred by removing
        # the mean taken over the action axis.
        advantage = F.reshape(
            self.a_stream(advantage_in),
            (n_batch, self.n_actions, self.n_atoms))
        action_mean = F.sum(advantage, axis=1, keepdims=True) / self.n_actions
        advantage, action_mean = F.broadcast(advantage, action_mean)
        advantage = advantage - action_mean

        # State value: a single row of atom logits shared by all actions.
        state_value = F.reshape(
            self.v_stream(value_in), (n_batch, 1, self.n_atoms))
        advantage, state_value = F.broadcast(advantage, state_value)

        logits = advantage + state_value
        atom_probs = F.softmax(logits, axis=2)

        return action_value.DistributionalDiscreteActionValue(
            atom_probs, self.z_values)
 def setUp(self):
     """Create a random distributional action-value fixture.

     Stores the fixture sizes, Dirichlet-sampled atom probabilities of shape
     ``(batch_size, action_size, n_atoms)``, evenly spaced atom values on
     ``[-10, 10]``, the ``DistributionalDiscreteActionValue`` under test, and
     the Q-values obtained by weighting the atom values by the probabilities.
     """
     self.batch_size, self.action_size, self.n_atoms = 30, 3, 51

     # One Dirichlet draw over the atoms for every (batch, action) pair.
     concentration = np.ones(self.n_atoms)
     sample_shape = (self.batch_size, self.action_size)
     sampled = np.random.dirichlet(alpha=concentration, size=sample_shape)
     self.atom_probs = sampled.astype(np.float32)

     # Support of the distribution: n_atoms evenly spaced points.
     self.z_values = np.linspace(
         -10, 10, num=self.n_atoms, dtype=np.float32)

     wrapped_probs = chainer.Variable(self.atom_probs)
     self.qout = action_value.DistributionalDiscreteActionValue(
         wrapped_probs, self.z_values)

     # Reference Q-values: probability-weighted sum over the atom axis.
     weighted_atoms = self.atom_probs * self.z_values
     self.q_values = np.sum(weighted_atoms, axis=2)