Exemplo n.º 1
0
 def _get_layer_tf_variables(self,
                             prefix,
                             initializer,
                             layers=None,
                             bias=False):
     """Create one ``tf.get_variable`` variable per layer, keyed by layer name.

     Args:
       prefix: String prepended to every variable name; each variable is
         named ``'<prefix>_<layer>'``.
       initializer: Initializer passed unchanged to every ``tf.get_variable``
         call.
       layers: Mapping from layer name to shape. When ``None`` (the default)
         ``self.layers`` is used. An explicitly passed empty mapping yields an
         empty result; it is not replaced by ``self.layers``.
       bias: If true, each variable is created with ``shape[1:]`` (the layer's
         shape without its leading dimension) instead of the full ``shape``.

     Returns:
       A ``pdict`` mapping each layer name in ``layers`` to the variable
       created for it.
     """
     # `is None` rather than `or`: an empty mapping is a legitimate request
     # for zero variables and must not silently expand to every layer.
     if layers is None:
         layers = self.layers
     return pdict({
         layer: tf.get_variable(name='{0}_{1}'.format(prefix, layer),
                                shape=shape[1:] if bias else shape,
                                initializer=initializer)
         for layer, shape in layers.items()
     })
Exemplo n.º 2
0
    def decoder_catgen_dist(self, critic_dist, z):
        """Build a contrastive CatGen decoder distribution.

        The returned ``pdict`` exposes a zero-argument ``log_prob`` callable.
        Calling it scores every representation in ``z`` against every element
        of the critic's [B T] batch, stores the resulting
        ``tfd.Categorical`` on the same pdict as ``cat_dist``, and returns the
        log-probability of the matching critic index plus ``log(B*T)``.

        Reads ``self.ood_log_density`` and ``self.tau`` at call time of
        ``log_prob`` (not when this method is called).

        Args:
          critic_dist: Critic distribution with batch shape [B T] and event
            shape [Z]. Its ``log_prob`` must broadcast a [BxT 1 1 Z] input
            to a [BxT B T] output.
          z: Representation tensor with shape [B T Z].

        Returns:
          A pdict whose ``log_prob`` attribute is a zero-argument callable
          returning a [B T] tensor (described by the original author as
          log(c(y|z))). ``cat_dist`` is set on the same pdict each time
          ``log_prob`` is called.
        """
        dist = pdict()

        def ood_logits(logits, ood_log_density):
            # Append one extra class along the last axis whose logit is the
            # constant `ood_log_density` for every [B T] position. No target
            # index ever points at it; presumably it acts as an
            # "out-of-distribution" catch-all class - TODO confirm intent.
            logits = tf.concat(
                [logits,
                 tf.ones_like(logits[..., :1]) * ood_log_density],
                axis=-1)
            return logits

        def log_prob():
            # Static batch dims of the critic; used for every reshape below.
            batch_shape = critic_dist.batch_shape
            batch_size, seq_size = batch_shape[0], batch_shape[1]
            z_dim = z.shape[-1]
            # z_squash: [BxT Z]
            z_squash = tf.reshape(z, shape=(batch_size * seq_size, z_dim))
            # logits: [BxT B T], z_squash -> [BxT 1 1 Z]
            # Each flattened z is scored under every critic batch element.
            logits = critic_dist.log_prob(z_squash[..., None, None, :])
            # logits: [B T BxT]
            # Row-major reshape: logits[b, t, j] is the score of z[b, t]
            # under critic element j (flattened as b' * T + t').
            logits = tf.reshape(logits,
                                shape=(batch_size, seq_size,
                                       batch_size * seq_size))

            # The OOD column is only added for a strictly negative density;
            # None or any value >= 0 disables it.
            if self.ood_log_density is not None and self.ood_log_density < 0:
                logits = ood_logits(logits, self.ood_log_density)

            # temperature
            if self.tau != 1.0:
                logits = logits / self.tau

            # dist.cat_dist: [B T (BxT)]
            # (event size is BxT + 1 when the OOD column was appended)
            dist.cat_dist = tfd.Categorical(logits=logits)

            # Positive target for position (b, t) is its own flattened index
            # b * T + t, i.e. the critic element built from the same sample.
            inds = tf.range(batch_size * seq_size)
            inds = tf.reshape(inds, shape=(batch_size, seq_size))
            # log_probs: [B T]
            log_probs = dist.cat_dist.log_prob(inds)
            # Constant offset log(B*T). NOTE(review): this uses B*T even when
            # the OOD column makes the categorical have B*T + 1 classes -
            # confirm that is intended.
            mi_upper_bound = tf.math.log(
                tf.cast(batch_size * seq_size, tf.float32))
            log_probs = log_probs + mi_upper_bound
            return log_probs

        # Exposed as a callable so the computation runs lazily when requested.
        dist.log_prob = log_prob
        return dist
Exemplo n.º 3
0
def default_config(**kwargs):
    """Return the module's configuration values as a ``pdict``.

    This is a deliberate hack. Instead of listing configuration keys by hand,
    it scans the module-level namespace (``globals()``) and keeps every entry
    that looks like a configuration value. An entry is dropped when:

    * its name is a dunder name (starts and ends with ``__``),
    * its value is a module,
    * its value is callable (functions, classes and other callables), or
    * its value is the module-level ``logger`` object.

    Names with a single leading underscore are *not* filtered out.

    Keyword arguments are applied last. Each one either overrides a collected
    value with the same name or adds a new key.

    Parameters
    ----------
    **kwargs
        Extra or overriding configuration entries.

    Returns
    -------
    default_config : pdict
        The collected module-level configuration with ``kwargs`` applied on
        top.
    """
    namespace = globals()
    config = pdict()

    def _looks_like_setting(name, value):
        # The checks run in a fixed order so that the global `logger` is
        # only looked up for entries that survive the cheaper filters.
        if name.startswith('__') and name.endswith('__'):
            return False
        if isinstance(value, ModuleType):
            return False
        if callable(value):
            return False
        return value is not logger

    for name, value in namespace.items():
        if _looks_like_setting(name, value):
            config[name] = value

    # Caller-supplied overrides win over anything collected above.
    config.update(**kwargs)
    return config