Example #1
0
    def __init__(self, model, batch_size, num_chains=3, betas=None):
        """Initialize the trainer and attach a parallel tempering sampler.

        :param model: Model handed to the superclass constructor.
        :param batch_size: Batch size handed to the superclass constructor.
        :param num_chains: Forwarded unchanged to ``ParallelTemperingSampler`` (default 3).
        :param betas: Forwarded unchanged to ``ParallelTemperingSampler`` (default ``None``).
        """
        super(PT, self).__init__(model=model, batch_size=batch_size)

        # NOTE(review): ``self.rbm`` is not assigned in this method; presumably the
        # superclass constructor derives it from ``model`` -- confirm.
        make_sampler = RBM_SAMPLER.ParallelTemperingSampler
        self.sampler = make_sampler(self.rbm, num_chains, betas)

    def __init__(self, model, betas=3, data=None):
        """ The constructor initializes the PT trainer with a given model and data.

        :param model: The model to sample from.
        :type model: Valid model class.

        :param betas: List of inverse temperatures to sample from. If a scalar is given, the temperatures will be set \
                      linearly from 0.0 to 1.0 in 'betas' steps. Non-scalar values (e.g. plain lists or tuples) are \
                      converted to a numpy array before being handed to the sampler.
        :type betas: int, numpy array [num betas], or array-like [num betas]

        :param data: Data for initialization, only has effect if the centered gradient is used.
        :type data: numpy array [num. samples x input dim]
        """
        # Call super constructor of CD
        super(PT, self).__init__(model, data)
        if numx.isscalar(betas):
            # Scalar: forwarded as the temperature count; no explicit betas are given to the sampler.
            self.sampler = sampler.ParallelTemperingSampler(model, betas, None)
        else:
            # ``.shape`` exists only on numpy arrays, so coerce array-likes such as lists first.
            # ``asarray`` returns an existing ndarray unchanged, so array callers see no difference.
            betas = numx.asarray(betas)
            self.sampler = sampler.ParallelTemperingSampler(
                model, betas.shape[0], betas)
 def test_Parallel_Tempering_sampler(self):
     # Check that the ParallelTemperingSampler's empirical probabilities (as
     # returned by ``self.execute_sampler``) match the expected values within
     # ``self.epsilon``.
     sys.stdout.write('RBM Sampler -> Performing ParallelTemperingSampler test ... ')
     sys.stdout.flush()

     # A fixed seed makes the stochastic sampling run reproducible.
     numx.random.seed(42)
     pt_sampler = Sampler.ParallelTemperingSampler(self.bbrbm, 10)
     results = self.execute_sampler(pt_sampler, self.num_samples)
     prob_cd1, prob_cd2, prob_cs1, prob_cs2, prob_cs3, prob_cs4, sum_probs = results

     # (expected value, observed value) pairs, checked in the original order.
     expectations = (
         (1.0 / 4.0, prob_cd1),
         (1.0 / 4.0, prob_cd2),
         (1.0 / 8.0, prob_cs1),
         (1.0 / 8.0, prob_cs2),
         (1.0 / 8.0, prob_cs3),
         (1.0 / 8.0, prob_cs4),
         (1.0, sum_probs),
     )
     for expected, observed in expectations:
         assert numx.all(numx.abs(expected - observed) < self.epsilon)

     print('successfully passed!')
     sys.stdout.flush()