コード例 #1
0
ファイル: nn.py プロジェクト: yuchen8807/parasol
 def map_fn(data):
     """Apply ``network`` along the last axis of ``data``, keeping leading axes.

     ``data`` is flattened to ``[-1, dim_in]`` so that ``network`` (a free
     name captured from the enclosing scope, not visible here) only ever
     receives rank-2 input.  The parameters of the distribution it returns
     are then reshaped so they carry the original leading axes of ``data``
     again.

     Args:
         data: tensor whose last axis is the network input dimension; any
             number of leading axes is accepted.

     Returns:
         A distribution of the same kind ``network`` produced
         (``stats.GaussianScaleDiag``, ``stats.Gaussian`` or
         ``stats.Bernoulli``) whose parameters have leading shape
         ``T.shape(data)[:-1]``.

     Raises:
         NotImplementedError: if ``network`` returns any other distribution
             type.  This subclasses ``Exception``, so callers that caught the
             previous bare ``Exception`` still catch it.
     """
     data_shape = T.shape(data)
     leading = data_shape[:-1]
     dim_in = data_shape[-1]
     flattened = T.reshape(data, [-1, dim_in])
     net_out = network(flattened)

     def unflatten(param, *trailing):
         # Put the leading axes of ``data`` back in front of ``trailing``.
         return T.reshape(param, T.concatenate([leading, list(trailing)]))

     # GaussianScaleDiag is tested before Gaussian, as in the original order.
     # NOTE(review): presumably matters if one subclasses the other; confirm.
     if isinstance(net_out, stats.GaussianScaleDiag):
         scale_diag, mu = net_out.get_parameters('regular')
         dim_out = T.shape(mu)[-1]
         return stats.GaussianScaleDiag([
             unflatten(scale_diag, dim_out),
             unflatten(mu, dim_out),
         ])
     if isinstance(net_out, stats.Gaussian):
         sigma, mu = net_out.get_parameters('regular')
         dim_out = T.shape(mu)[-1]
         return stats.Gaussian([
             # Full covariance: restore leading axes in front of [dim_out, dim_out].
             unflatten(sigma, dim_out, dim_out),
             unflatten(mu, dim_out),
         ])
     if isinstance(net_out, stats.Bernoulli):
         params = net_out.get_parameters('natural')
         dim_out = T.shape(params)[-1]
         return stats.Bernoulli(unflatten(params, dim_out), 'natural')
     raise NotImplementedError(
         "Unimplemented distribution: %s" % type(net_out).__name__)
コード例 #2
0
ファイル: car.py プロジェクト: yuchen8807/parasol
 def make_summary(self, observations, name):
     """Emit an image summary of *observations* when image mode is enabled.

     Does nothing unless ``self.image`` is truthy.  Otherwise the
     observations are reshaped to ``[-1] + self.image_size()`` and handed to
     ``T.core.summary.image`` under the tag *name*.  Because of the list
     concatenation, ``image_size()`` must return a list.
     """
     if not self.image:
         return
     target_shape = [-1] + self.image_size()
     T.core.summary.image(name, T.reshape(observations, target_shape))
コード例 #3
0
ファイル: pointmass.py プロジェクト: yuchen8807/parasol
 def make_summary(self, observations, name):
     """Emit one image summary per slice of image observations.

     Does nothing unless ``self.image`` is truthy.  Otherwise the
     observations are reshaped to ``[-1] + self.image_size()`` (so
     ``image_size()`` must return a list).  Index 0 of the last axis is then
     logged under ``name + "-point"`` and index 1 under ``name + "-goal"``.
     """
     if not self.image:
         return
     frames = T.reshape(observations, [-1] + self.image_size())
     for suffix, channel in (("-point", 0), ("-goal", 1)):
         # Range slice (not an integer index) keeps the last axis with size 1.
         T.core.summary.image(name + suffix, frames[..., channel:channel + 1])