Esempio n. 1
0
    def set_test_config(self):
        """Set up evaluation-time configuration, including the test HMC sampler.

        Runs the parent class's ``set_test_config`` first. It then builds
        ``self.hmc_test`` from ``self.config.testing.mcmc`` using the test
        batch size ``self.batch_size_te``, with ``testing=True`` and
        ``keep_samples=True``.
        """
        super().set_test_config()

        test_mcmc_config = self.config.testing.mcmc
        self.hmc_test = HamiltonianMonteCarlo(
            test_mcmc_config,
            self.batch_size_te,
            keep_samples=True,
            testing=True,
        )
Esempio n. 2
0
    def set_training_config(self):
        """Configure control-variate handling and the training HMC sampler.

        Steps:
          * Reads ``self.config.training.control_variate`` to set:
              - ``use_control_variate`` (default True)
              - ``use_local_control_variate`` (defaults to
                ``use_control_variate``; forced to False when control
                variates are disabled)
              - ``control_var_decay`` (default 0.9)
              - ``control_var_independent_iters`` (default 3000)
          * Selects ``control_var_device`` and the
            ``maybe_update_control_variate`` callable.
          * Builds ``self.hmc_train``.
          * Runs the parent class's ``set_training_config``.
          * Sets ``loss_shape`` and ``train_output_labels``.
        """
        cv_config = self.config.training.control_variate

        # Control variate flags and hyper-parameters.
        self.use_control_variate = cv_config.get('use_control_variate', True)
        local_cv = cv_config.get('use_local_control_variate',
                                 self.use_control_variate)
        if not self.use_control_variate:
            # A local control variate only makes sense if control variates
            # are enabled at all.
            local_cv = False
        self.use_local_control_variate = local_cv
        self.control_var_decay = cv_config.get('decay', 0.9)
        self.control_var_independent_iters = cv_config.get(
            'independent_iterations', 3000)

        # Pick the device scope and the update rule for the control variate.
        # Evaluated unconditionally, with 'device' only consulted when
        # 'on_ipu' is falsy.
        dev_cfg = self.device_config
        on_ipu_or_cpu = (dev_cfg['on_ipu']
                         or 'cpu' in dev_cfg['device'].lower())
        if self.use_local_control_variate:
            # Local control variate: pinned to CPU when running on IPU/CPU,
            # otherwise to the first GPU.
            cv_device_name = ('/device:CPU:0'
                              if on_ipu_or_cpu else '/device:GPU:0')
            self.control_var_device = get_device_scope_call(cv_device_name)
            self.maybe_update_control_variate = \
                self._update_local_control_variate
        else:
            self.control_var_device = dev_cfg['scoper']
            if self.use_control_variate:
                self.maybe_update_control_variate = \
                    self._update_global_control_variate
            else:

                def dont_update_cv(control_variate,
                                   idx,
                                   elbo_hmc,
                                   assign=True,
                                   decay=0.9):
                    """No-op control-variate update returning a scalar zero.

                    Mirrors the argument list of
                    self._update_global_control_variate() and
                    self._update_local_control_variate() so callers can use
                    any of them interchangeably.
                    """
                    return tf.zeros((), dtype=self.experiment.dtype)

                self.maybe_update_control_variate = dont_update_cv

        # HMC sampler used during training.
        train_mcmc_config = self.config.training.mcmc
        self.hmc_train = HamiltonianMonteCarlo(train_mcmc_config,
                                               self.config.batch_size)

        # Apply the parent class's training configuration, then the
        # attributes specific to this subclass.
        super().set_training_config()
        self.loss_shape = (2,)
        self.train_output_labels = (
            'Loss [encoder, decoder]',
            'Average control variate',
            'Average Unnorm. ELBO (HMC)',
            'HMC step size',
        )