Example #1
    def _build(self, inputs):
        # Build a Bernoulli output distribution from a linear projection of the inputs.
        
        output_shape = [-1] + self._output_shape
        logits = snt.Linear(
            np.prod(self._output_shape),
            initializers=self._initializers,
            regularizers=self._regularizers)(inputs)
        logits = tf.reshape(logits, output_shape)

        dtype = inputs.dtype
        if self._dtype is not None:
            dtype = self._dtype

        if self._clip_value > 0:
            probs = tf.nn.sigmoid(logits)
            # Clip probabilities away from 0 and 1 for numerical stability.
            probs = tf.clip_by_value(probs, self._clip_value, 1 - self._clip_value)
            bernoulli = tfd.Bernoulli(probs=probs, dtype=dtype)
        else:
            bernoulli = tfd.Bernoulli(logits=logits, dtype=dtype)
        
        def reconstruction_node(self):
            return self.mean()
        bernoulli.reconstruction_node = types.MethodType(reconstruction_node, bernoulli)

        def distribution_parameters(self):
            return [self.mean()]
        bernoulli.distribution_parameters = types.MethodType(distribution_parameters, bernoulli)

        def get_probs(self):
            return self.probs

        bernoulli.get_probs = types.MethodType(get_probs, bernoulli)

        return bernoulli
Example #2
def define_graph(config):
  network_tpl = tf.make_template('network', network, config=config)
  inputs = tf.placeholder(tf.float32, [None, config.num_inputs])
  targets = tf.placeholder(tf.float32, [None, 1])
  num_visible = tf.placeholder(tf.int32, [])
  batch_size = tf.to_float(tf.shape(inputs)[0])
  data_mean, data_noise, data_uncertainty = network_tpl(inputs)
  ood_inputs = inputs + tf.random_normal(
      tf.shape(inputs), 0.0, config.noise_std)
  ood_mean, ood_noise, ood_uncertainty = network_tpl(ood_inputs)
  losses = [
      -tfd.Normal(data_mean, data_noise).log_prob(targets),
      -tfd.Bernoulli(data_uncertainty).log_prob(0),
      -tfd.Bernoulli(ood_uncertainty).log_prob(1),
  ]
  if config.center_at_target:
    losses.append(-tfd.Normal(ood_mean, ood_noise).log_prob(targets))
  loss = sum(tf.reduce_sum(loss) for loss in losses) / batch_size
  optimizer = tf.train.AdamOptimizer(config.learning_rate)
  gradients, variables = zip(*optimizer.compute_gradients(
      loss, colocate_gradients_with_ops=True))
  if config.clip_gradient:
    gradients, _ = tf.clip_by_global_norm(gradients, config.clip_gradient)
  optimize = optimizer.apply_gradients(zip(gradients, variables))
  data_uncertainty = tf.sigmoid(data_uncertainty)
  if not config.center_at_target:
    data_mean = (1 - data_uncertainty) * data_mean + data_uncertainty * 0
  data_noise = (1 - data_uncertainty) * data_noise + data_uncertainty * 0.1
  return tools.AttrDict(locals())
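A hedged configuration sketch for calling define_graph above: the field names are exactly the ones the function reads, the values are illustrative, and tools.AttrDict is assumed to be the project's dict-with-attribute-access helper (a global network function must also be in scope for tf.make_template).

config = tools.AttrDict(
    num_inputs=8,           # feature dimensionality of `inputs`
    noise_std=0.5,          # stddev of the additive OOD perturbation
    center_at_target=True,  # also fit the OOD predictions to the targets
    learning_rate=1e-3,
    clip_gradient=10.0,     # falsy value disables clipping
)
graph = define_graph(config)  # exposes graph.loss, graph.optimize, graph.inputs, ...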
Example #3
 def call(self, inputs, training=None):
     logits = self.logits(inputs)
     q = distributions.Bernoulli(
         logits=logits,
         dtype=tf.float32,
     )
     if self.beta > 0.0:
         kld = q.kl_divergence(
             distributions.Bernoulli(
                 logits=tf.zeros_like(logits),
                 dtype=tf.float32,
             ))
         self.add_loss(tf.reduce_mean(kld))
     return q
Example #4
def get_distributions_from_tensor(t, dimension, num_mixes):
    y_pred = tf.reshape(t, [-1, (2 * num_mixes * dimension + 1) + num_mixes],
                        name='reshape_ypreds')
    out_e, out_pi, out_mus, out_stds = tf.split(y_pred,
                                                num_or_size_splits=[
                                                    1, num_mixes,
                                                    num_mixes * dimension,
                                                    num_mixes * dimension
                                                ],
                                                name='mdn_coef_split',
                                                axis=-1)

    cat = tfd.Categorical(logits=out_pi)
    components_splits = [dimension] * num_mixes
    mus = tf.split(out_mus, num_or_size_splits=components_splits, axis=1)
    stds = tf.split(out_stds, num_or_size_splits=components_splits, axis=1)

    components = [
        tfd.MultivariateNormalDiag(loc=mu_i, scale_diag=std_i)
        for mu_i, std_i in zip(mus, stds)
    ]

    mix = tfd.Mixture(cat=cat, components=components)
    stroke = tfd.Bernoulli(logits=out_e)
    return mix, stroke
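A minimal usage sketch for get_distributions_from_tensor, assuming the usual aliases (import tensorflow as tf; tfd = tensorflow_probability.distributions). The sizes are illustrative; in real use the std slice of the network output would first pass through a positivity transform such as exp or softplus.

num_mixes, dimension = 3, 2
width = 2 * num_mixes * dimension + 1 + num_mixes  # e, pi, mus, stds
t = tf.random.normal([4, width])
mix, stroke = get_distributions_from_tensor(t, dimension, num_mixes)
print(mix.sample().shape)     # (4, 2)
print(stroke.sample().shape)  # (4, 1)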
Example #5
def generate_2pl_data(n_sample, n_factor, n_item, 
                      alpha, beta, rho, 
                      dtype = tf.float64):
    if (n_item % n_factor) != 0:
        n_item = n_factor * (n_item // n_factor)
    item_per_factor = (n_item // n_factor)
    intercept = tf.fill((n_item,), value = tf.constant(alpha, dtype = dtype))
    loading = np.zeros((n_item, n_factor))
    for i in range(n_factor):
        for j in range(i * item_per_factor,
                       (i + 1) * item_per_factor):
            loading[j, i] = beta  # item discrimination (loading)
    loading = tf.constant(loading, dtype = dtype)
    if rho is None:
        cor = tf.eye(n_factor, dtype = dtype)
    else:
        unit = tf.ones((n_factor, 1), dtype = dtype)
        identity = tf.eye(n_factor, dtype = dtype)
        cor = rho * (unit @ tf.transpose(unit)) + (1 - rho) * identity
    dist_eta = tfd.MultivariateNormalTriL(
        loc = tf.zeros(n_factor, dtype = dtype), scale_tril = tf.linalg.cholesky(cor))
    eta = dist_eta.sample(n_sample)
    logits = intercept + eta @ tf.transpose(loading)
    x = tfd.Bernoulli(logits=logits, dtype=dtype).sample()
    return x
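A usage sketch under the same assumed imports (tf, np, tfd): six items split evenly across two factors with inter-factor correlation 0.3; the result is a matrix of binary item responses.

x = generate_2pl_data(n_sample=1000, n_factor=2, n_item=6,
                      alpha=0.0, beta=1.2, rho=0.3)
print(x.shape)  # (1000, 6)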
Example #6
    def output_function(self, state):
        params = dense_layer(state.h3,
                             self.output_units,
                             scope='gmm',
                             reuse=tf.compat.v1.AUTO_REUSE)
        pis, mus, sigmas, rhos, es = self._parse_parameters(params)
        mu1, mu2 = tf.split(mus, 2, axis=1)
        mus = tf.stack([mu1, mu2], axis=2)
        sigma1, sigma2 = tf.split(sigmas, 2, axis=1)

        # Per-component 2x2 covariance: [[s1^2, rho*s1*s2], [rho*s1*s2, s2^2]].
        covar_matrix = [
            tf.square(sigma1), rhos * sigma1 * sigma2, rhos * sigma1 * sigma2,
            tf.square(sigma2)
        ]
        covar_matrix = tf.stack(covar_matrix, axis=2)
        covar_matrix = tf.reshape(
            covar_matrix,
            (self.batch_size, self.num_output_mixture_components, 2, 2))

        mvn = tfd.MultivariateNormalFullCovariance(
            loc=mus, covariance_matrix=covar_matrix)
        b = tfd.Bernoulli(probs=es)
        c = tfd.Categorical(probs=pis)

        sampled_e = b.sample()
        sampled_coords = mvn.sample()
        sampled_idx = c.sample()

        idx = tf.stack([tf.range(self.batch_size), sampled_idx], axis=1)
        coords = tf.gather_nd(sampled_coords, idx)
        return tf.concat([coords, tf.cast(sampled_e, tf.float32)], axis=1)
Example #7
 def __call__(self, inputs):
   out = self.get('out', tfkl.Dense, np.prod(self._shape))(inputs)
   out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
   out = tf.cast(out, tf.float32)
   if self._dist in ('normal', 'tanh_normal', 'trunc_normal'):
     std = self.get('std', tfkl.Dense, np.prod(self._shape))(inputs)
     std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
     std = tf.cast(std, tf.float32)
   if self._dist == 'mse':
     dist = tfd.Normal(out, 1.0)
     return tfd.Independent(dist, len(self._shape))
   if self._dist == 'normal':
     dist = tfd.Normal(out, std)
     return tfd.Independent(dist, len(self._shape))
   if self._dist == 'binary':
     dist = tfd.Bernoulli(out)
     return tfd.Independent(dist, len(self._shape))
   if self._dist == 'tanh_normal':
     mean = 5 * tf.tanh(out / 5)
     std = tf.nn.softplus(std + self._init_std) + self._min_std
     dist = tfd.Normal(mean, std)
     dist = tfd.TransformedDistribution(dist, common.TanhBijector())
     dist = tfd.Independent(dist, len(self._shape))
     return common.SampleDist(dist)
   if self._dist == 'trunc_normal':
     std = 2 * tf.nn.sigmoid((std + self._init_std) / 2) + self._min_std
     dist = common.TruncNormalDist(tf.tanh(out), std, -1, 1)
     return tfd.Independent(dist, 1)
   if self._dist == 'onehot':
     return common.OneHotDist(out)
   raise NotImplementedError(self._dist)
Example #8
    def _build(self, inputs):
        prior = self._prior()
        enc_out = self._encoder(inputs['s'])
        loc = self._loc(enc_out)
        log_scale = self._log_scale(enc_out)
        scale = tf.nn.softplus(
            log_scale + softplus_inverse(1.0)
        )  # shift so that scale == 1.0 when log_scale == 0, i.e. centered at init

        q = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale,
                                       name='code')  # approximate posterior
        q_sample = q.sample(
            self.FLAGS['num_vae_samples'])  # approximate posterior sample
        logits = self._decoder_conv(loc)
        #logits = self._decoder_conv(q_sample)
        phat = tfd.Independent(tfd.Bernoulli(logits=logits),
                               reinterpreted_batch_ndims=len(IMAGE_SHAPE),
                               name="image")
        return dict(prior=prior,
                    q=q,
                    q_mean=loc,
                    q_sample=q_sample,
                    phat=phat,
                    phi_s=loc,
                    logits=logits)
Example #9
 def sample(self, time, outputs, state, name=None):
     """Gets a sample for one step."""
     with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                         [time, outputs, state]):
         sampler = tfpd.Bernoulli(probs=self._sampling_probability)
         return sampler.sample(sample_shape=self.batch_size,
                               seed=self._seed)
Example #10
 def testBernoulliLogProb(self, logits, n):
   rv = ed.Bernoulli(logits)
   dist = tfd.Bernoulli(logits)
   x = rv.distribution.sample(n)
   rv_log_prob, dist_log_prob = self.evaluate(
       [rv.distribution.log_prob(x), dist.log_prob(x)])
   self.assertAllEqual(rv_log_prob, dist_log_prob)
Example #11
        def decoder(state_sample, observation_dist="gaussian"):
            """Compute the data distribution of an observation from its state [1]."""
            check_in(
                "observation_dist",
                observation_dist,
                ("gaussian", "laplace", "bernoulli", "multinomial"),
            )

            timesteps = tf.shape(state_sample)[1]

            if self.pixel_observations:
                # original decoder from [1] for deepmind lab envs
                hidden = tf.layers.dense(state_sample, 1024, None)
                kwargs = dict(strides=2, activation=tf.nn.relu)
                hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1]])
                # 1 x 1
                hidden = tf.layers.conv2d_transpose(hidden, 128, 5, **kwargs)
                # 5 x 5 x 128
                hidden = tf.layers.conv2d_transpose(hidden, 64, 5, **kwargs)
                # 13 x 13 x 64
                hidden = tf.layers.conv2d_transpose(hidden, 32, 6, **kwargs)
                # 30 x 30 x 32
                mean = 255 * tf.layers.conv2d_transpose(
                    hidden, 3, 6, strides=2, activation=tf.nn.sigmoid)
                # 64 x 64 x 3
                assert mean.shape[1:].as_list() == [64, 64, 3], mean.shape
            else:
                # decoder for gridworlds / structured observations
                hidden = state_sample
                d = self._hidden_layer_size
                for _ in range(4):
                    hidden = tf.layers.dense(hidden, d, tf.nn.relu)
                mean = tf.layers.dense(hidden, np.prod(self.data_shape), None)

            mean = tf.reshape(mean, [-1, timesteps] + list(self.data_shape))

            check_in(
                "observation_dist",
                observation_dist,
                ("gaussian", "laplace", "bernoulli", "multinomial"),
            )
            if observation_dist == "gaussian":
                dist = tfd.Normal(mean, self._obs_stddev)
            elif observation_dist == "laplace":
                dist = tfd.Laplace(mean, self._obs_stddev / np.sqrt(2))
            elif observation_dist == "bernoulli":
                dist = tfd.Bernoulli(probs=mean)
            else:
                mean = tf.reshape(mean, [-1, timesteps] +
                                  [np.prod(list(self.data_shape))])
                dist = tfd.Multinomial(total_count=1, probs=mean)
                reshape = tfp.bijectors.Reshape(
                    event_shape_out=list(self.data_shape))
                dist = reshape(dist)
                return dist

            dist = tfd.Independent(dist, len(self.data_shape))
            return dist
Example #12
 def __call__(self):
     """Get the distribution object from the backend"""
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         return tod.bernoulli.Bernoulli(logits=self.logits,
                                        probs=self.probs)
     else:
         from tensorflow_probability import distributions as tfd
         return tfd.Bernoulli(logits=self.logits, probs=self.probs)
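Both backends accept exactly one of logits/probs, so the wrapper above relies on the unused attribute being None. A quick sanity check with the TensorFlow backend (the PyTorch constructor enforces the same rule):

import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.Bernoulli(logits=0.3, probs=None)  # valid: a single parameterization
# tfd.Bernoulli(logits=0.3, probs=0.6)     # raises ValueError: mutually exclusive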
Example #13
 def _common(cls, node, **kwargs):
     x = kwargs["tensor_dict"][node.inputs[0]]
     dtype = node.attrs.get("dtype", x.dtype)
     dist = tfd.Bernoulli(probs=x, dtype=dtype)
     if 'seed' in node.attrs:
         ret = dist.sample(seed=int(node.attrs.get('seed')))
     else:
         ret = dist.sample()
     return [tf.cast(tf.reshape(ret, x.shape), dtype)]
Example #14
 def __call__(self, features):
     x = features
     for index in range(self._layers):
         x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
     x = self.get('hout', tfkl.Dense, np.prod(self._shape))(x)
     x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
     if self._dist == 'normal':
         return tfd.Independent(tfd.Normal(x, 1), len(self._shape))
     elif self._dist == 'binary':
         return tfd.Independent(tfd.Bernoulli(x), len(self._shape))
     raise NotImplementedError(self._dist)
Example #15
 def __call__(self, observation, action):
     x = self._dynamics(tf.concat([observation, action], axis=1))
     # TODO (yarden): maybe it's better to feed the reward and terminals s, a instead of x.
     # The world model predicts the difference between next_observation and observation.
     return dict(next_observation=tfd.MultivariateNormalDiag(
         loc=self._next_observation_residual_mu(x) +
         tf.stop_gradient(observation),
         scale_diag=self._next_observation_stddev(x)),
                 reward=tfd.Normal(loc=self._reward_mu(x), scale=1.0),
                 terminal=tfd.Bernoulli(logits=self._terminal_logit(x),
                                        dtype=tf.float32))
Example #16
def generate_logistic_reg_data(
    n_sample, weight, intercept = 0, 
    dtype = tf.float64, seed = None):
    weight = tf.constant(weight, dtype = dtype)
    weight = tf.reshape(weight, shape = (-1, 1))
    n_feature = weight.shape[0]
    x = tf.random.normal(shape = (n_sample, n_feature),
                         seed = seed, dtype = dtype)
    logits = intercept + x @ weight
    y = tfd.Bernoulli(logits=logits, dtype=dtype).sample()
    return x, y
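A usage sketch with illustrative values; x is drawn in the default float64, and y comes back with shape (n_sample, 1) because the weight vector is reshaped to a column.

x, y = generate_logistic_reg_data(
    n_sample=500, weight=[0.8, -1.5], intercept=0.5, seed=42)
print(x.shape, y.shape)  # (500, 2) (500, 1)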
Example #17
 def __call__(self, features):
     kwargs = dict(strides=2, activation=self._act)
     x = self.get('h1', tfkl.Dense, 8 * self._depth, None)(features)
     x = tf.reshape(x, [-1, 1, 1, 8 * self._depth])
     x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, 5, **kwargs)(x)
     x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, 5, **kwargs)(x)
     x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, 6, **kwargs)(x)
     x = self.get('h5', tfkl.Conv2DTranspose, 1, 6, **kwargs)(x)
     shape = tf.concat([tf.shape(features)[:-1], self._shape], axis=0)
     x = tf.reshape(x, shape)
     return tfd.Independent(tfd.Bernoulli(x), 3)  # last 3 dimensions (row, col, chan) define 1 pixel
Example #18
 def sample(self, time, outputs, state, name=None):
     """Gets a sample for one step."""
     with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                         [time, outputs, state]):
         # Return -1s where we did not sample, and sample_ids elsewhere
         select_sampler = tfpd.Bernoulli(probs=self._sampling_probability,
                                         dtype=dtypes.bool)
         select_sample = select_sampler.sample(sample_shape=self.batch_size,
                                               seed=self._scheduling_seed)
         sample_id_sampler = tfpd.Categorical(logits=outputs)
         return array_ops.where(select_sample,
                                sample_id_sampler.sample(seed=self._seed),
                                gen_array_ops.fill([self.batch_size], -1))
Example #19
File: nn.py  Project: xlnwel/d2rl
    def call(self, x):
        x = self._layers(x)
        if not getattr(self, '_has_cnn', None):
            # Number of reinterpreted batch dimensions.
            rbd = 0 if x.shape[-1] == 1 else 1
            x = tf.squeeze(x)
            if self._dist == 'normal':
                return tfd.Independent(tfd.Normal(x, 1), rbd)
            if self._dist == 'binary':
                return tfd.Independent(tfd.Bernoulli(x), rbd)
            raise NotImplementedError(self._dist)

        return x
Example #20
    def decoder(codes):
        """Builds a distribution over images given codes.

    Args:
      codes: A `Tensor` representing the inputs to be decoded, of shape `[...,
        code_size]`.

    Returns:
      decoder_distribution: A multivariate `Bernoulli` distribution.
    """
        logits = decoder_net(codes)
        return tfd.Independent(tfd.Bernoulli(logits=logits),
                               reinterpreted_batch_ndims=len(output_shape),
                               name="decoder_distribution")
Example #21
 def __call__(self, x):
     n_sample = len(x)
     n_factor = self.loading.shape[-1]  # latent dimension inferred from the loading matrix
     joint_prob = tfd.JointDistributionSequential([
         tfd.Independent(
             tfd.Normal(
                 loc = tf.zeros((n_sample, n_factor), dtype=self.dtype),
                 scale = 1.0), 
             reinterpreted_batch_ndims=1),
         lambda eta: tfd.Independent(
             tfd.Bernoulli(
                 logits= self.intercept + eta @ tf.transpose(self.loading), 
                 dtype=self.dtype), 
             reinterpreted_batch_ndims=1)])
     joint_prob._to_track = self
     return joint_prob
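A hedged sketch of querying the joint distribution built above; model stands in for an instance of the surrounding class (exposing intercept, loading, and dtype) and is hypothetical here.

x = tf.zeros((7, 5), dtype=tf.float64)  # 7 respondents, 5 items
joint = model(x)                        # `model`: instance of the class above
eta, x_sim = joint.sample()             # latent traits and simulated responses
log_p = joint.log_prob([eta, x_sim])    # per-respondent log-density, shape (7,)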
Example #22
 def __call__(self, features):
     x = features
     for index in range(self.layers_num):  # 3
         x = self.get(f"h{index}", tf.keras.layers.Dense, self._units,
                      self._act)(x)
         # print("x:",x.shape) #  (15, 1250, 400)
     x = self.get(f"hout", tf.keras.layers.Dense, np.prod(self._shape))(x)
     # print("x:",x.shape)
     x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
     # print("x:",x.shape)
     if self._dist == "normal":
         return tfd.Independent(tfd.Normal(x, 1), len(self._shape))
     if self._dist == "binary":
         return tfd.Independent(tfd.Bernoulli(x), len(self._shape))
     raise NotImplementedError(self._dist)
Example #23
def calculate_log_px_z(x, x_logit, pi):
    sample_size = x_logit.shape[4]
    n_required = x_logit.shape[5]
    # x_broad = tf.repeat(tf.expand_dims(x, 4), axis=4, repeats=pi.shape[1])
    # x_broad = tf.reshape(x_broad, shape=(16, 4))
    # x_logit = tf.reshape(x_logit, shape=(16, 4))
    x_broad = tf.repeat(tf.expand_dims(x, -1), axis=-1, repeats=sample_size)
    x_broad = tf.repeat(tf.expand_dims(x_broad, -1), axis=-1, repeats=n_required)
    dist = tfpd.Bernoulli(logits=x_logit)
    log_px_z = dist.log_prob(x_broad)
    log_px_z = tf.reduce_sum(log_px_z, axis=(1, 2, 3))
    log_px_z = tf.reduce_mean(log_px_z, axis=1)
    log_px_z = tf.reduce_sum(pi[:, :, 0, 0] * log_px_z, axis=1)
    log_px_z = tf.reduce_mean(log_px_z, axis=0)
    return log_px_z
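A shape sketch for calculate_log_px_z with hypothetical sizes, assuming tfpd = tensorflow_probability.distributions: x is a batch of binary images, x_logit holds one logit per pixel, per Monte Carlo sample, per component, and pi carries the component weights.

B, H, W, C, S, K = 2, 4, 4, 1, 3, 5
x = tf.round(tf.random.uniform([B, H, W, C]))   # binary inputs
x_logit = tf.random.normal([B, H, W, C, S, K])  # (..., samples, components)
pi = tf.nn.softmax(tf.random.normal([B, K, 1, 1]), axis=1)
print(calculate_log_px_z(x, x_logit, pi))       # scalar log-likelihood estimate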
Example #24
    def decode(code, data_shape):
        """
        :param code: latent code tensor to be decoded
        :param data_shape: shape (list) of the data to reconstruct
        :return: an Independent Bernoulli distribution over the data
        """
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            logit = tf.reshape(
                tf.layers.dense(
                    tf.layers.dense(
                        tf.layers.dense(code, 256, tf.nn.relu), 784, tf.nn.relu
                    ), np.prod(data_shape)
                ), [-1] + data_shape
            )

            return distributions.Independent(distributions.Bernoulli(logit), 2)
Example #25
    def perform_fwd_pass(self):
        mean, log_var = self.nets.encoder(self.x)
        stddev = tf.exp(0.5 * log_var)

        qz_x = tfpd.Normal(loc=mean, scale=stddev)
        z = qz_x.sample()

        logits = self.nets.decoder(z)
        px_z = tfpd.Bernoulli(logits=logits)

        p_z = tfpd.Normal(loc=tf.zeros_like(z), scale=tf.ones_like(z))
        kl = tf.reduce_sum(tfpd.kl_divergence(qz_x, p_z), axis=1)
        expected_log_likelihood = tf.reduce_sum(px_z.log_prob(self.x),
                                                axis=(1, 2, 3))

        self.elbo = tf.reduce_mean(expected_log_likelihood - kl, axis=0)
Example #26
    def rollout_in_dream(self, z_init, h_init, video=False):
        """
        Inputs:
            h_init: (L, B, 1)
            z_init: (L, B, latent_dim * n_atoms)
            done_init: (L, B, 1)
        """
        L, B = h_init.shape[:2]

        horizon = self.config.imagination_horizon

        z, h = tf.reshape(z_init, [L * B, -1]), tf.reshape(h_init, [L * B, -1])
        feats = tf.concat([z, h], axis=-1)

        #: s_t, a_t, s_t+1
        trajectory = {"state": [], "action": [], 'next_state': []}

        for t in range(horizon):

            actions = tf.cast(self.policy.sample(feats), dtype=tf.float32)

            trajectory["state"].append(feats)
            trajectory["action"].append(actions)

            h = self.world_model.step_h(z, h, actions)
            z, _ = self.world_model.rssm.sample_z_prior(h)
            z = tf.reshape(z, [L * B, -1])

            feats = tf.concat([z, h], axis=-1)
            trajectory["next_state"].append(feats)

        trajectory = {k: tf.stack(v, axis=0) for k, v in trajectory.items()}

        #: reward_head(s_t+1) -> r_t
        #: Distribution.mode() returns the highest-probability value, so for a
        #: Normal distribution trajectory["reward"] == rewards
        rewards = self.world_model.reward_head(trajectory['next_state'])
        trajectory["reward"] = rewards

        disc_logits = self.world_model.discount_head(trajectory['next_state'])
        trajectory["discount"] = tfd.Independent(
            tfd.Bernoulli(logits=disc_logits),
            reinterpreted_batch_ndims=1).mean()

        return trajectory
Example #27
def feed_forward(
    features, data_shape, num_layers=2, activation=tf.nn.relu,
    mean_activation=None, stop_gradient=False, trainable=True, units=100,
    std=1.0, low=-1.0, high=1.0, dist='normal', min_std=1e-2, init_std=1.0):
  hidden = features
  if stop_gradient:
    hidden = tf.stop_gradient(hidden)
  for _ in range(num_layers):
    hidden = tf.layers.dense(hidden, units, activation, trainable=trainable)
  mean = tf.layers.dense(
      hidden, int(np.prod(data_shape)), mean_activation, trainable=trainable)
  mean = tf.reshape(mean, tools.shape(features)[:-1] + data_shape)
  if std == 'learned':
    std = tf.layers.dense(
        hidden, int(np.prod(data_shape)), None, trainable=trainable)
    init_std = np.log(np.exp(init_std) - 1)
    std = tf.nn.softplus(std + init_std) + min_std
    std = tf.reshape(std, tools.shape(features)[:-1] + data_shape)
  if dist == 'normal':
    dist = tfd.Normal(mean, std)
    dist = tfd.Independent(dist, len(data_shape))
  elif dist == 'deterministic':
    dist = tfd.Deterministic(mean)
    dist = tfd.Independent(dist, len(data_shape))
  elif dist == 'binary':
    dist = tfd.Bernoulli(mean)
    dist = tfd.Independent(dist, len(data_shape))
  elif dist == 'trunc_normal':
    # https://www.desmos.com/calculator/rnksmhtgui
    dist = tfd.TruncatedNormal(mean, std, low, high)
    dist = tfd.Independent(dist, len(data_shape))
  elif dist == 'tanh_normal':
    # https://www.desmos.com/calculator/794s8kf0es
    dist = distributions.TanhNormal(mean, std)
  elif dist == 'tanh_normal_tanh':
    # https://www.desmos.com/calculator/794s8kf0es
    mean = 5.0 * tf.tanh(mean / 5.0)
    dist = distributions.TanhNormal(mean, std)
  elif dist == 'onehot_score':
    dist = distributions.OneHot(mean, gradient='score')
  elif dist == 'onehot_straight':
    dist = distributions.OneHot(mean, gradient='straight')
  else:
    raise NotImplementedError(dist)
  return dist
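A usage sketch for feed_forward; this is TF1-style code (tf.layers, graph mode), and tools.shape is assumed to return a list of static dimensions as in the surrounding project.

features = tf.placeholder(tf.float32, [16, 32])
dist = feed_forward(features, data_shape=[3], dist='normal')
sample = dist.sample()                       # Tensor of shape (16, 3)
log_prob = dist.log_prob(tf.zeros([16, 3]))  # Tensor of shape (16,)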
Example #28
    def __call__(self, codes):
        """Build decoder which takes a code and returns a distribution over images.

    Args:
      codes: A `float`-like `Tensor` representing the inputs to be decoded.
        The first dimension (axis 0) indexes batch elements; all other
        dimensions index event elements.

    Returns:
      decoder: A multivariate `Bernoulli` distribution.
    """
        net = self.decoder_net(codes)
        new_shape = tf.concat([tf.shape(net)[:-1], self.output_shape], axis=0)
        logits = tf.reshape(net, shape=new_shape)
        return tfd.Independent(tfd.Bernoulli(logits=logits),
                               reinterpreted_batch_ndims=len(
                                   self.output_shape),
                               name="decoder_distribution")
Example #29
    def _compute_log_loss(self, y_true, y_pred, mode):
        """
        Inputs:
            y_true: (L, B, 1)
            y_pred: (L, B, 1)
            mode: "reward" or "discount"
        """
        if mode == "discount":
            dist = tfd.Independent(tfd.Bernoulli(logits=y_pred),
                                   reinterpreted_batch_ndims=1)
        elif mode == "reward":
            dist = tfd.Independent(tfd.Normal(loc=y_pred, scale=1.),
                                   reinterpreted_batch_ndims=1)

        log_prob = dist.log_prob(y_true)

        loss = tf.reduce_mean(log_prob)  # mean log-likelihood (higher is better)

        return loss
Example #30
  def call(self, codes):
    """Builds a distribution over images given codes.

    Args:
      codes: A `Tensor` representing the inputs to be decoded, of shape `[...,
        code_size]`.

    Returns:
      decoder_distribution: A multivariate `Bernoulli` distribution.
    """
    num_samples, batch_size, latent_size, code_size = common_layers.shape_list(
        codes)
    codes = tf.reshape(codes,
                       [num_samples * batch_size, latent_size, code_size])
    logits = self.decoder_net(codes)
    logits = tf.reshape(logits,
                        [num_samples, batch_size] + list(self.image_shape))
    return tfd.Independent(tfd.Bernoulli(logits=logits),
                           reinterpreted_batch_ndims=len(self.image_shape),
                           name="decoder_distribution")