Beispiel #1
0
 def _get_prior_params(self):
     """Creates the trainable parameters of the memory prior.

     Returns:
       A tuple `(prior_mean, prior_cov)`. `prior_mean` is a trainable
       `[memory_size, code_size]` variable initialised from a truncated normal
       with mean 0.0 and stddev 1.0. `prior_cov` is the diagonal matrix built
       by `tf.matrix_diag` from `self._prior_var`.

     As a side effect `self._prior_var` is set to a `[memory_size]` vector
     whose entries all equal `exp(prior_var_scale) + EPSILON`, where
     `prior_var_scale` is a trainable scalar initialised to log(1.0).
     """
     scale_initializers = {'w': tf.constant_initializer(np.log(1.0))}
     log_scale = snt.TrainableVariable(
         [], name='prior_var_scale', initializers=scale_initializers)()
     # One shared scalar variance is broadcast over every memory row.
     # EPSILON is defined elsewhere in the file (presumably a small positive
     # constant added for numerical safety -- confirm).
     self._prior_var = (
         tf.ones([self._memory_size]) * tf.exp(log_scale) + EPSILON)
     cov = tf.matrix_diag(self._prior_var)

     mean_initializers = {
         'w': tf.truncated_normal_initializer(mean=0.0, stddev=1.0)}
     mean = snt.TrainableVariable(
         [self._memory_size, self._code_size],
         name='prior_mean',
         initializers=mean_initializers)()
     return mean, cov
Beispiel #2
0
  def __init__(self, discriminator, generator,
               num_z_iters=None, z_step_size=None,
               z_project_method=None, optimisation_cost_weight=None):
    """Constructs the module.

    Args:
      discriminator: The discriminator network. A sonnet module. See `nets.py`.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive float, the initial latent optimisation step
        size. It is stored in log space as a trainable variable. If falsy
        (e.g. `None` or 0), no step-size variable is created and both
        `self.z_step_size` and `self._log_step_size_module` are `None`.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
      optimisation_cost_weight: a float, how much to penalise the distance of z
        moved by latent optimisation.
    """
    self._discriminator = discriminator
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method
    # Always define these attributes so callers can test
    # `self.z_step_size is None` instead of hitting an AttributeError when the
    # module was built without a step size.
    self._log_step_size_module = None
    self.z_step_size = None
    if z_step_size:
      # The step size is parameterised in log space; exp() keeps it positive.
      self._log_step_size_module = snt.TrainableVariable(
          [],
          initializers={'w': tf.constant_initializer(math.log(z_step_size))})
      self.z_step_size = tf.exp(self._log_step_size_module())
    self._optimisation_cost_weight = optimisation_cost_weight
Beispiel #3
0
    def __init__(self,
                 code_size,
                 memory_size,
                 num_opt_iters=1,
                 w_prior_stddev=1.0,
                 obs_noise_stddev=1.0,
                 sample_w=False,
                 sample_M=False,
                 name='KanervaMemory'):
        """Initialise the memory module.

        Args:
          code_size: Integer specifying the size of each encoded input.
          memory_size: Integer specifying the total number of rows in the
            memory.
          num_opt_iters: Integer specifying the number of optimisation
            iterations.
          w_prior_stddev: Float specifying the standard deviation of w's prior.
          obs_noise_stddev: Float specifying the standard deviation of the
            observational noise. If positive it is used as a fixed constant;
            otherwise a trainable standard deviation, initialised to 1.0, is
            created instead.
          sample_w: Boolean specifying whether to sample w or simply take its
            mean.
          sample_M: Boolean specifying whether to sample M or simply take its
            mean.
          name: String specifying the name of this module.
        """
        super(KanervaMemory, self).__init__(name=name)
        self._memory_size = memory_size
        self._code_size = code_size
        self._num_opt_iters = num_opt_iters
        self._sample_w = sample_w
        self._sample_M = sample_M
        # Coerce to a Python float so int (or numpy float64) arguments still
        # produce a float32 tensor, presumably matching the dtype of the
        # trainable variables created below -- confirm if dtypes change.
        self._w_prior_stddev = tf.constant(float(w_prior_stddev))

        with self._enter_variable_scope():
            # Trainable stddev of w, kept in log space and initialised to 0.3.
            log_w_stddev = snt.TrainableVariable(
                [],
                name='w_stddev',
                initializers={'w': tf.constant_initializer(np.log(0.3))})()
            if obs_noise_stddev > 0.0:
                self._obs_noise_stddev = tf.constant(float(obs_noise_stddev))
            else:
                # Non-positive value requested: learn the observation noise
                # (log space, initialised to log(1.0)). The misspelled
                # variable name 'obs_stdddev' is left unchanged; renaming it
                # would presumably break restoring existing checkpoints.
                log_obs_stddev = snt.TrainableVariable(
                    [],
                    name='obs_stdddev',
                    initializers={'w': tf.constant_initializer(np.log(1.0))})()
                self._obs_noise_stddev = tf.exp(log_obs_stddev)
        self._w_stddev = tf.exp(log_w_stddev)
        self._w_prior_dist = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros([self._memory_size]),
            scale_identity_multiplier=self._w_prior_stddev)
Beispiel #4
0
 def _initial_state_var(self):
     """Returns the initial state of shape `[1, self._input_channels]`.

     If `self._train_initial_state` is truthy this is a trainable Sonnet
     variable named 'initial_state' initialised to zeros; otherwise it is a
     constant `tf.zeros` tensor of the same shape.
     """
     state_shape = [1, self._input_channels]
     if not self._train_initial_state:
         return tf.zeros(state_shape)
     return snt.TrainableVariable(
         state_shape,
         initializers=dict(w=tf.zeros_initializer),
         name='initial_state')()
Beispiel #5
0
  def __init__(self, metric_net, generator,
               num_z_iters, z_step_size, z_project_method):
    """Constructs the module.

    Args:
      metric_net: the measurement network.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive float, the initial latent optimisation step
        size. It is stored in log space as a trainable variable.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.

    Raises:
      ValueError: if `z_step_size` is not positive.
    """

    self._measure = metric_net
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method
    # math.log() below would reject this too, but only with an opaque
    # "math domain error"; fail early with an actionable message instead.
    if z_step_size <= 0:
      raise ValueError(
          'z_step_size must be positive, got {!r}.'.format(z_step_size))
    # The step size is parameterised in log space; exp() keeps it positive.
    self._log_step_size_module = snt.TrainableVariable(
        [],
        initializers={'w': tf.constant_initializer(math.log(z_step_size))})
    self.z_step_size = tf.exp(self._log_step_size_module())