def _get_prior_params(self):
    """Create the trainable parameters describing the memory prior.

    A scalar log-variance variable (initialised to log(1.0)) is exponentiated
    and broadcast to one entry per memory row; the module-level EPSILON is
    added to every entry. The resulting vector is kept on `self._prior_var`.

    Returns:
      A `(prior_mean, prior_cov)` tuple. `prior_mean` is a trainable variable
      created with shape `[memory_size, code_size]` and a truncated-normal
      initialiser; `prior_cov` is the diagonal matrix built from
      `self._prior_var`.
    """
    log_variance_module = snt.TrainableVariable(
        [],
        name='prior_var_scale',
        initializers={'w': tf.constant_initializer(np.log(1.0))})
    log_variance = log_variance_module()

    # The same scalar variance is shared by every memory row.
    per_row_ones = tf.ones([self._memory_size])
    self._prior_var = per_row_ones * tf.exp(log_variance) + EPSILON
    covariance = tf.matrix_diag(self._prior_var)

    mean_initializers = {
        'w': tf.truncated_normal_initializer(mean=0.0, stddev=1.0)
    }
    mean_module = snt.TrainableVariable(
        [self._memory_size, self._code_size],
        name='prior_mean',
        initializers=mean_initializers)
    mean = mean_module()
    return mean, covariance
def __init__(self, discriminator, generator,
             num_z_iters=None, z_step_size=None,
             z_project_method=None, optimisation_cost_weight=None):
    """Constructs the module.

    Args:
      discriminator: The discriminator network. A sonnet module. See `nets.py`.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: a positive number, the initial latent optimisation step
        size. Its logarithm initialises a trainable variable, and
        `self.z_step_size` is the exponential of that variable. When omitted
        (or otherwise falsy) no variable is created and `self.z_step_size`
        is None.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
      optimisation_cost_weight: a float, how much to penalise the distance of z
        moved by latent optimisation.
    """
    self._discriminator = discriminator
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method

    # Define the step-size attributes unconditionally so that reading them on
    # an instance built without a step size gives None rather than raising
    # AttributeError.
    self._log_step_size_module = None
    self.z_step_size = None
    if z_step_size:
        # The step size is learned in log space, so exp() keeps it positive.
        self._log_step_size_module = snt.TrainableVariable(
            [],
            initializers={
                'w': tf.constant_initializer(math.log(z_step_size))
            })
        self.z_step_size = tf.exp(self._log_step_size_module())

    self._optimisation_cost_weight = optimisation_cost_weight
def __init__(self,
             code_size,
             memory_size,
             num_opt_iters=1,
             w_prior_stddev=1.0,
             obs_noise_stddev=1.0,
             sample_w=False,
             sample_M=False,
             name='KanervaMemory'):
    """Initialise the memory module.

    Args:
      code_size: Integer specifying the size of each encoded input.
      memory_size: Integer specifying the total number of rows in the memory.
      num_opt_iters: Integer specifying the number of optimisation iterations.
      w_prior_stddev: Float specifying the standard deviation of w's prior.
      obs_noise_stddev: Float specifying the standard deviation of the
        observational noise. A value greater than zero is used as a fixed
        constant; otherwise a trainable log-stddev variable (initialised to
        log(1.0)) is created and exponentiated.
      sample_w: Boolean specifying whether to sample w or simply take its mean.
      sample_M: Boolean specifying whether to sample M or simply take its mean.
      name: String specifying the name of this module.
    """
    super(KanervaMemory, self).__init__(name=name)

    self._memory_size = memory_size
    self._code_size = code_size
    self._num_opt_iters = num_opt_iters
    self._sample_w = sample_w
    self._sample_M = sample_M
    self._w_prior_stddev = tf.constant(w_prior_stddev)

    with self._enter_variable_scope():
        # Standard deviation of w is learned in log space, starting at 0.3.
        w_log_stddev_module = snt.TrainableVariable(
            [],
            name='w_stddev',
            initializers={'w': tf.constant_initializer(np.log(0.3))})
        w_log_stddev = w_log_stddev_module()

        if not obs_noise_stddev > 0.0:
            # No positive noise level was supplied: learn it instead.
            obs_log_stddev_module = snt.TrainableVariable(
                [],
                name='obs_stdddev',
                initializers={'w': tf.constant_initializer(np.log(1.0))})
            obs_log_stddev = obs_log_stddev_module()
            self._obs_noise_stddev = tf.exp(obs_log_stddev)
        else:
            self._obs_noise_stddev = tf.constant(obs_noise_stddev)

    self._w_stddev = tf.exp(w_log_stddev)

    # Zero-mean isotropic prior over w with one dimension per memory row.
    prior_loc = tf.zeros([self._memory_size])
    self._w_prior_dist = tfp.distributions.MultivariateNormalDiag(
        loc=prior_loc,
        scale_identity_multiplier=self._w_prior_stddev)
def _initial_state_var(self):
    """Return the initial state, created with shape `[1, input_channels]`.

    Returns:
      A zero-initialised trainable variable named 'initial_state' when
      `self._train_initial_state` is truthy, otherwise a constant tensor of
      zeros with the same shape.
    """
    state_shape = [1, self._input_channels]

    # Fixed (non-trainable) initial state: plain zeros, no variable created.
    if not self._train_initial_state:
        return tf.zeros(state_shape)

    state_module = snt.TrainableVariable(
        state_shape,
        initializers={'w': tf.zeros_initializer},
        name='initial_state')
    return state_module()
def __init__(self, metric_net, generator,
             num_z_iters, z_step_size, z_project_method):
    """Constructs the module.

    Args:
      metric_net: the measurement network.
      generator: The generator network. A sonnet module. For examples, see
        `nets.py`.
      num_z_iters: an integer, the number of latent optimisation steps.
      z_step_size: the initial latent optimisation step size. Its logarithm
        initialises a trainable variable, and `self.z_step_size` is the
        exponential of that variable.
      z_project_method: the method for projecting latent after optimisation,
        a string from {'norm', 'clip'}.
    """
    self._measure = metric_net
    self.generator = generator
    self.num_z_iters = num_z_iters
    self.z_project_method = z_project_method

    # The step size is learned in log space, so exp() keeps it positive.
    log_step_initializers = {
        'w': tf.constant_initializer(math.log(z_step_size))
    }
    self._log_step_size_module = snt.TrainableVariable(
        [], initializers=log_step_initializers)
    log_step_size = self._log_step_size_module()
    self.z_step_size = tf.exp(log_step_size)