示例#1
0
    def check_forward(self, diag_data, non_diag_data):
        """Check ``lower_triangular_matrix`` against a NumPy-built reference.

        Args:
            diag_data: Values written onto the main diagonal of every
                reference matrix.
            non_diag_data: Values written strictly below the main diagonal
                of every reference matrix.

        The reference has shape ``(self.batch_size, self.n, self.n)`` and is
        compared with the function output via
        ``gradient_check.assert_allclose``.
        """
        actual = lower_triangular_matrix(
            chainer.Variable(diag_data), chainer.Variable(non_diag_data))

        size = self.n
        expected = numpy.zeros(
            (self.batch_size, size, size), dtype=numpy.float32)

        # Offset -1 selects the entries strictly below the diagonal.
        lower_r, lower_c = numpy.tril_indices(size, -1)
        expected[:, lower_r, lower_c] = cuda.to_cpu(non_diag_data)

        # The diagonal itself is filled separately from diag_data.
        main_r, main_c = numpy.diag_indices(size)
        expected[:, main_r, main_c] = cuda.to_cpu(diag_data)

        gradient_check.assert_allclose(expected, cuda.to_cpu(actual.data))
示例#2
0
    def __call__(self, state, test=False):
        """Build a ``QuadraticActionValue`` for the given state.

        Args:
            state: Input forwarded to ``self.hidden_layers``.
            test: Forwarded to ``self.hidden_layers`` as ``test``.

        Returns:
            A ``QuadraticActionValue`` constructed from the ``mu`` output,
            the matrix described below and the ``v`` output, bounded by
            ``self.action_space.low`` / ``self.action_space.high``.
        """
        features = self.hidden_layers(state, test=test)
        value = self.v(features)
        mean = self.mu(features)

        if self.scale_mu:
            mean = scale_by_tanh(mean, high=self.action_space.high,
                                 low=self.action_space.low)

        # exp() keeps every diagonal entry strictly positive.
        diag = F.exp(self.mat_diag(features))
        if not hasattr(self, 'mat_non_diag'):
            # No off-diagonal head: use the squared diagonal with an extra
            # axis appended at position 2.
            matrix = F.expand_dims(diag ** 2, axis=2)
        else:
            off_diag = self.mat_non_diag(features)
            lower = lower_triangular_matrix(diag, off_diag)
            # Presumably the batched product lower * lower^T (transb=True).
            matrix = F.batch_matmul(lower, lower, transb=True)

        return QuadraticActionValue(
            mean, matrix, value,
            min_action=self.action_space.low,
            max_action=self.action_space.high)
示例#3
0
 def __call__(self, s):
     """Map input ``s`` to the output of ``lower_triangular_matrix``.

     The input passes through two linear + ReLU stages. When
     ``self._use_batch_norm`` is truthy, ``_L_bn0`` is applied to the
     input and ``_L_bn1`` / ``_L_bn2`` are applied after the respective
     linear layers (before the ReLU). The final features feed two heads:
     the exponentiated ``_linear_L3_diag`` output and the raw
     ``_linear_L3_rest`` output.
     """
     h = self._L_bn0(s) if self._use_batch_norm else s

     # First stage: linear -> (batch norm) -> ReLU.
     h = self._linear_L1(h)
     if self._use_batch_norm:
         h = self._L_bn1(h)
     h = F.relu(h)

     # Second stage: linear -> (batch norm) -> ReLU.
     h = self._linear_L2(h)
     if self._use_batch_norm:
         h = self._L_bn2(h)
     h = F.relu(h)

     # exp() makes the diagonal head strictly positive.
     return lower_triangular_matrix(
         F.exp(self._linear_L3_diag(h)), self._linear_L3_rest(h))
 def _L_matrix(self, x):
     """Return ``lower_triangular_matrix`` built from the two heads on ``x``.

     The first argument is the exponential of ``self._linear_L_diag(x)``
     (hence strictly positive); the second is ``self._linear_L_rest(x)``
     unchanged.
     """
     return lower_triangular_matrix(
         F.exp(self._linear_L_diag(x)), self._linear_L_rest(x))