def __init__(self, model, num_samples, betas=3, data=None):
    """ Set up the IPT trainer and its independent parallel tempering sampler.

    :param model: The model to sample from.
    :type model: Valid model class.

    :param num_samples: The number of samples to produce.
                        .. Note:: you should use the batchsize.
    :type num_samples: int

    :param betas: Inverse temperatures to sample from. A scalar is handed to
                  the sampler as the number of temperatures (per the original
                  notes they are then spread linearly from 0.0 to 1.0); an
                  array is handed over together with the length of its first
                  axis.
    :type betas: int, numpy array [num betas]

    :param data: Data for initialization, only has effect if the centered
                 gradient is used.
    :type data: numpy array [num. samples x input dim]

    """
    # Let the parent trainer set itself up first.
    super(IPT, self).__init__(model, data)

    # Work out how many temperatures there are and whether explicit values
    # were supplied, then build the sampler with a single call.
    if numx.isscalar(betas):
        num_betas, beta_values = betas, None
    else:
        num_betas, beta_values = betas.shape[0], betas

    self.sampler = sampler.IndependentParallelTemperingSampler(model,
                                                               num_samples,
                                                               num_betas,
                                                               beta_values)
def __init__(self, model, batch_size, num_chains=3, betas=None):
    """ Set up the IPT trainer and attach an independent parallel tempering sampler.

    :param model: Model forwarded to the superclass constructor.
    :param batch_size: Batch size forwarded to the superclass constructor.
    :param num_chains: Third positional argument given to the sampler
                       (defaults to 3).
    :param betas: Fourth positional argument given to the sampler
                  (defaults to None).

    """
    # Superclass initialisation comes first; the sampler below is built from
    # self.rbm and self.batch_size, which are presumably set there
    # (NOTE(review): confirm against the parent class).
    super(IPT, self).__init__(model=model, batch_size=batch_size)

    sampler_class = RBM_SAMPLER.IndependentParallelTemperingSampler
    self.sampler = sampler_class(
        self.rbm,
        self.batch_size,
        num_chains,
        betas,
    )
def test_Independent_Parallel_Tempering_sampler(self):
    """ Check the statistics produced by the IndependentParallelTemperingSampler.

    Seeds the random generator, runs the sampler through
    ``self.execute_sampler`` and asserts that each of the seven returned
    statistics lies within ``self.epsilon`` of its expected value.

    """
    sys.stdout.write('RBM Sampler -> Performing IndependentParallelTemperingSampler test ... ')
    sys.stdout.flush()

    # Fixed seed so the sampled statistics are reproducible.
    numx.random.seed(42)
    ipt_sampler = Sampler.IndependentParallelTemperingSampler(self.bbrbm, 10, 10)
    (prob_cd1, prob_cd2,
     prob_cs1, prob_cs2, prob_cs3, prob_cs4,
     sum_probs) = self.execute_sampler(ipt_sampler, self.num_samples)

    # (expected value, measured value) pairs, checked in the original order.
    expectations = (
        (1.0 / 4.0, prob_cd1),
        (1.0 / 4.0, prob_cd2),
        (1.0 / 8.0, prob_cs1),
        (1.0 / 8.0, prob_cs2),
        (1.0 / 8.0, prob_cs3),
        (1.0 / 8.0, prob_cs4),
        (1.0, sum_probs),
    )
    for expected, measured in expectations:
        assert numx.all(numx.abs(expected - measured) < self.epsilon)

    print('successfully passed!')
    sys.stdout.flush()