def map_fn(data):
    """Apply ``network`` over the last axis of ``data``, keeping leading axes.

    ``data`` is flattened to ``[-1, dim_in]`` (``dim_in`` being its last
    axis), passed through ``network`` (a free name resolved outside this
    function), and every parameter of the returned distribution is reshaped
    so the leading axes of ``data`` are restored in front of the event axes.

    Args:
        data: tensor whose last axis is fed to ``network``.

    Returns:
        A distribution of the same family ``network`` returned:
        ``stats.GaussianScaleDiag`` (scale diagonal and mean reshaped to
        ``leading + [dim_out]``), ``stats.Gaussian`` (covariance reshaped to
        ``leading + [dim_out, dim_out]``, mean to ``leading + [dim_out]``), or
        ``stats.Bernoulli`` (natural parameters reshaped to
        ``leading + [dim_out]``).

    Raises:
        NotImplementedError: if ``network`` returns any other distribution
            type. This subclasses ``Exception``, so handlers written for the
            previous bare ``Exception`` still catch it.
    """
    data_shape = T.shape(data)
    leading = data_shape[:-1]
    dim_in = data_shape[-1]
    net_out = network(T.reshape(data, [-1, dim_in]))

    def unflatten(param, *event_dims):
        # Put the leading axes of ``data`` back in front of the event axes.
        return T.reshape(param, T.concatenate([leading, list(event_dims)]))

    # NOTE(review): GaussianScaleDiag is checked before Gaussian; keep this
    # order in case one is a subclass of the other (not visible here).
    if isinstance(net_out, stats.GaussianScaleDiag):
        scale_diag, mu = net_out.get_parameters('regular')
        dim_out = T.shape(mu)[-1]
        return stats.GaussianScaleDiag([
            unflatten(scale_diag, dim_out),
            unflatten(mu, dim_out),
        ])
    if isinstance(net_out, stats.Gaussian):
        sigma, mu = net_out.get_parameters('regular')
        dim_out = T.shape(mu)[-1]
        return stats.Gaussian([
            unflatten(sigma, dim_out, dim_out),
            unflatten(mu, dim_out),
        ])
    if isinstance(net_out, stats.Bernoulli):
        params = net_out.get_parameters('natural')
        dim_out = T.shape(params)[-1]
        return stats.Bernoulli(unflatten(params, dim_out), 'natural')
    raise NotImplementedError(
        "Unimplemented distribution: {}".format(type(net_out).__name__))
def make_summary(self, observations, name):
    """Emit an image summary of ``observations`` under ``name``.

    Only acts when ``self.image`` is truthy; otherwise returns immediately.
    The observations are reshaped to ``[-1] + self.image_size()`` before
    being handed to ``T.core.summary.image``.
    """
    if not self.image:
        return
    image_shape = [-1] + self.image_size()
    images = T.reshape(observations, image_shape)
    T.core.summary.image(name, images)
def make_summary(self, observations, name):
    """Emit one image summary per channel of interest in ``observations``.

    Only acts when ``self.image`` is truthy; otherwise returns immediately.
    After reshaping to ``[-1] + self.image_size()``, channel 0 of the last
    axis is logged as ``name + "-point"`` and channel 1 as ``name + "-goal"``
    (in that order), each kept as a single-channel slice.
    """
    if not self.image:
        return
    images = T.reshape(observations, [-1] + self.image_size())
    for suffix, channel in (("-point", 0), ("-goal", 1)):
        # Slice rather than index so the channel axis is kept (width 1).
        T.core.summary.image(name + suffix, images[..., channel:channel + 1])