def train_from_samples(self, patches, patches_original):
    """Fit a Bernoulli mixture model to ``patches`` and store the parts.

    The patches are flattened to ``(n_samples, -1)`` and a ``BernoulliMM``
    with ``self._num_parts`` components is fit on them once.  The fitted
    component means, reshaped back to the patch shape, become
    ``self._parts``; the mixing weights become ``self._weights``.  Finally
    the mean per-element Bernoulli entropy (natural log) of every part is
    computed into ``H``.

    Parameters
    ----------
    patches : ndarray
        Training patches stacked along axis 0.  Presumably binary
        (0/1) since they feed a Bernoulli mixture -- TODO confirm.
    patches_original : ndarray
        Not used by the visible code; kept for interface compatibility.

    Settings read
    -------------
    ``min_prob`` (default 0.01) is passed to the mixture as ``min_prob``;
    ``em_seed`` (default 0) is passed as ``random_state``.
    """
    from pnet.bernoullimm import BernoulliMM

    min_prob = self._settings.get('min_prob', 0.01)
    flatpatches = patches.reshape((patches.shape[0], -1))

    mm = BernoulliMM(n_components=self._num_parts,
                     n_iter=20,
                     tol=1e-15,
                     n_init=2,
                     random_state=self._settings.get('em_seed', 0),
                     min_prob=min_prob,
                     verbose=False)
    # Fit exactly once.  The previous code fit a second time inside a
    # print() call, repeating the whole EM run only to echo the model.
    mm.fit(flatpatches)

    # One part per mixture component, shaped like a single patch.
    self._parts = mm.means_.reshape((self._num_parts,) + patches.shape[1:])
    self._weights = mm.weights_

    # Element-wise p*log(p) + (1-p)*log(1-p) for every part.
    # NOTE(review): assumes the fitted means lie strictly inside (0, 1)
    # (presumably enforced by min_prob in BernoulliMM -- confirm);
    # an exact 0 or 1 would produce nan here.
    Hall = (self._parts * np.log(self._parts) +
            (1 - self._parts) * np.log(1 - self._parts))
    # Mean entropy per part, averaged over all non-part axes.  Equivalent
    # to the former apply_over_axes(np.mean, Hall, [1, 2, 3]) for 4-D
    # parts, but no longer requires exactly four dimensions.
    # NOTE(review): H is not consumed in the visible code; presumably used
    # by code following this point -- confirm.
    H = -Hall.reshape((Hall.shape[0], -1)).mean(axis=1)