Beispiel #1
0
    def train_from_samples(self, patches,patches_original):
        """Fit a Bernoulli mixture to the patches and store the parts.

        The patches are flattened to one row per sample, a ``BernoulliMM``
        with ``self._num_parts`` components is fitted to them, and the
        component means (reshaped back to the per-patch shape) and mixture
        weights are stored on ``self._parts`` and ``self._weights``.
        Afterwards a per-part mean entropy ``H`` is computed.

        Parameters
        ----------
        patches : array
            First axis indexes samples. The entropy step reduces the parts
            over axes 1, 2 and 3, so ``patches`` has to be 4-dimensional.
        patches_original
            Not used anywhere in the visible part of this method.

        Notes
        -----
        Reads ``'min_prob'`` (default 0.01) and ``'em_seed'`` (default 0)
        from ``self._settings``. ``np`` is expected to be imported at module
        level. NOTE(review): ``H`` is not consumed in the lines visible
        here; the method presumably continues below -- confirm.
        """
        #from pnet.latent_bernoulli_mm import LatentBernoulliMM
        from pnet.bernoullimm import BernoulliMM
        min_prob = self._settings.get('min_prob', 0.01)
        # One row per patch, all remaining axes collapsed into one.
        flatpatches = patches.reshape((patches.shape[0], -1))
        
        # Hard-coded switch: the BernoulliMM branch is always taken, the
        # `else` branch below is dead code.
        if 1:
            mm = BernoulliMM(n_components=self._num_parts, n_iter=20, tol=1e-15,n_init=2, random_state=self._settings.get('em_seed',0), min_prob=min_prob, verbose=False)
            mm.fit(flatpatches)
            # NOTE(review): this fits the model a second time just to print
            # the return value of fit(); presumably leftover debugging that
            # doubles the training cost -- confirm before removing.
            print(mm.fit(flatpatches))
            #print('AIC', mm.aic(flatpatches))
            #print('BIC', mm.bic(flatpatches))
            
            # Disabled diagnostic: compares how samples drawn from the fitted
            # model and the training patches distribute over a set of
            # clipped binary templates.
            if 0:
                # Draw samples
                #size = (20, 20)
                import gv
                import os
                N = 10000
                D = np.prod(self._part_shape)
                size = (20, 20)
                grid = gv.plot.ImageGrid(size[0], size[1], self._part_shape)
                samp = mm.sample(n_samples=np.prod(size)).reshape((-1,) + self._part_shape)
                samples = mm.sample(n_samples=N).reshape((-1,) + self._part_shape)
                print('samples', samples.shape)

                # presumably gv.multirange(*[2]*D) enumerates every binary
                # pattern of length D -- verify against gv.
                types = np.asarray(list(gv.multirange(*[2]*D))).reshape((-1,) + self._part_shape)
                # Clip so the logs below stay finite.
                th = np.clip(types, 0.01, 0.99)

                # Add broadcast axes: templates along axis 0, samples along axis 1.
                t = th[:,np.newaxis]
                x = samples[np.newaxis]

                # Bernoulli log-likelihood of every sample under every
                # template; count which template wins for each sample.
                llh0 = x * np.log(t) + (1 - x) * np.log(1 - t)
                counts0 = np.bincount(np.argmax(llh0.sum(-1).sum(-1), 0), minlength=th.shape[0])
                
                # Same computation on the training patches (index 0 of the
                # last axis).
                x1 = patches[np.newaxis,...,0]
                llh1 = x1 * np.log(t) + (1 - x1) * np.log(1 - t)
                counts1 = np.bincount(np.argmax(llh1.sum(-1).sum(-1), 0), minlength=th.shape[0])

                #import pdb; pdb.set_trace()

                # Normalise the counts to relative frequencies.
                w0 = counts0 / counts0.sum()
                w1 = counts1 / counts1.sum()

                print('w0', w0)
                print('w1', w1) 
                #import pdb; pdb.set_trace()


            #import pdb; pdb.set_trace()
            # Component means reshaped to (num_parts,) + per-patch shape.
            self._parts = mm.means_.reshape((self._num_parts,)+patches.shape[1:])
            self._weights = mm.weights_
        else:
            #mm = ag.stats.BernoulliMixture(self._num_parts, flatpatches, max_iter=2000)
            #mm.run_EM(1e-6, min_probability=min_prob)
            #self._parts = mm.templates.reshape((self._num_parts,)+patches.shape[1:])
            #self._weights = mm.weights
            from pnet.bernoulli import em
            ret = em(flatpatches, self._num_parts,20,numpy_rng=self._settings.get('em_seed',0),verbose=True)
            self._parts = ret[1].reshape((self._num_parts,) + patches.shape[1:])
            # NOTE(review): this stores 0..num_parts-1, not mixture weights;
            # looks like a placeholder -- confirm.
            self._weights = np.arange(self._num_parts)
        # Disabled; note that `mm` is only bound in the `if 1:` branch above.
        if 0:
            predictedGroups = mm.predict(flatpatches) 
        
        # Calculate entropy of parts
        # Hall is the elementwise p*log(p) + (1-p)*log(1-p); H is its negated
        # mean over axes 1-3, giving one value per part.
        # NOTE(review): a part probability of exactly 0 or 1 would give
        # -inf/nan here; presumably min_prob keeps the means away from the
        # boundaries -- confirm in BernoulliMM.
        Hall = (self._parts * np.log(self._parts) + (1 - self._parts) * np.log(1 - self._parts))
        H = -np.apply_over_axes(np.mean, Hall, [1, 2, 3])[:,0,0,0]