def train_from_samples(self, patches, num_parts):
        """Fit a Bernoulli mixture to ``patches`` and return the parts sorted by entropy.

        The patches are flattened to one row per sample and passed to
        ``pnet.bernoulli.em`` for 20 iterations. The second element of the EM
        result is reshaped to ``(num_parts,) + patches.shape[1:]`` and treated
        as the per-part Bernoulli probabilities.

        Parameters
        ----------
        patches : ndarray
            Training samples; axis 0 indexes samples, the remaining axes are
            the feature layout of a single patch.
        num_parts : int
            Number of mixture components to fit.

        Returns
        -------
        ndarray
            Part probabilities of shape ``(num_parts,) + patches.shape[1:]``,
            ordered by increasing mean Bernoulli entropy.

        Side effects
        ------------
        Sets ``self._weights``.

        NOTE(review): another ``train_from_samples`` is defined right after
        this one in the file; if both belong to the same class, the later
        definition shadows this one -- confirm which is intended.
        """
        from pnet.bernoulli import em

        flatpatches = patches.reshape((patches.shape[0], -1))
        ret = em(
            flatpatches,
            num_parts,
            20,
            numpy_rng=self._settings.get('em_seed', 0),
            verbose=True,
        )
        parts = ret[1].reshape((num_parts,) + patches.shape[1:])
        # NOTE(review): sized by self._num_parts rather than the num_parts
        # argument -- kept as in the original; confirm the two always agree.
        self._weights = np.arange(self._num_parts)

        # Negative Bernoulli entropy per feature. At p == 0 or p == 1 the
        # direct formula evaluates 0 * log(0) = NaN, which would make argsort
        # push those parts to the end regardless of their real entropy; the
        # correct limit there is 0.
        saturated = (parts == 0) | (parts == 1)
        with np.errstate(divide='ignore', invalid='ignore'):
            neg_entropy = parts * np.log(parts) + (1 - parts) * np.log(1 - parts)
        neg_entropy = np.where(saturated, 0.0, neg_entropy)

        # Average over every non-part axis (works for any patch rank, not
        # only the 4-D layout).
        feature_axes = list(range(1, neg_entropy.ndim))
        H = -np.apply_over_axes(np.mean, neg_entropy, feature_axes).reshape(num_parts)

        # Lowest-entropy (most deterministic) parts first.
        order = np.argsort(H)
        return parts[order]
    def train_from_samples(self, patches, original_patches):
        """Fit parts with ``LatentShiftEM`` and store them plus their visualizations.

        The patches are flattened to one row per sample and handed to
        ``pnet.latentShiftEM.LatentShiftEM`` (25 iterations max, log-likelihood
        tolerance 1e-3, RNG seeded from the ``em_seed`` setting). From the
        returned sequence this method uses:

        * element 1 -- reshaped to
          ``(self._num_parts, part_h, part_w, 8)`` and stored as the parts;
        * element 3 -- a 2-D integer array (``comps``); column 0 is compared
          with the component index and column 1 indexes the second axis of
          ``original_patches`` (presumably the chosen shift -- TODO confirm);
        * element 4 -- stored as ``self._bkg_probability``.

        Parameters
        ----------
        patches : ndarray
            Training samples; axis 0 indexes samples.
        original_patches : ndarray
            Images used only to build the visualization; indexed as
            ``original_patches[sample, comps[sample, 1]]``.

        Returns
        -------
        ndarray
            The learned parts, shape
            ``(self._num_parts, self._part_shape[0], self._part_shape[1], 8)``.

        Side effects
        ------------
        Sets ``self._parts``, ``self._bkg_probability`` and ``self._visparts``,
        prints diagnostic shapes, and writes ``moduleShiftingParts1.png`` via
        ``amitgroup.plot.images``.
        """
        from pnet.latentShiftEM import LatentShiftEM
        import amitgroup.plot as gr

        print(patches.shape)
        flatpatches = patches.reshape((patches.shape[0], -1))
        print(flatpatches.shape)

        rng = np.random.RandomState(self._settings.get("em_seed", 0))
        result = LatentShiftEM(
            flatpatches,
            num_mixture_component=self._num_parts,
            parts_shape=(self._part_shape[0], self._part_shape[1], 8),
            # NOTE(review): index 1 is used for both spatial dimensions here,
            # unlike parts_shape above; harmless for square samples, but
            # confirm this is not meant to be _sample_shape[0].
            region_shape=(self._sample_shape[1], self._sample_shape[1], 8),
            shifting_shape=self._shifting_shape,
            max_num_iteration=25,
            loglike_tolerance=1e-3,
            mu_truncation=(1, 1),
            additional_mu=None,
            permutation=None,
            numpy_rng=rng,
            verbose=True,
        )
        comps = result[3]
        print(comps.shape)
        print(original_patches.shape)
        print(result[1].shape)

        parts = result[1].reshape((self._num_parts, self._part_shape[0], self._part_shape[1], 8))
        self._bkg_probability = result[4]
        self._parts = parts
        print(comps[:50, 0])
        print(comps[:50, 1])

        # Mean original image per component, taken at each member's selected
        # index along axis 1. The membership mask is built once per component.
        # A component with no members averages an empty selection.
        visparts = []
        for k in range(self._num_parts):
            members = comps[:, 0] == k
            visparts.append(original_patches[members, comps[members, 1]].mean(0))
        self._visparts = np.asarray(visparts)
        print(self._visparts.shape)

        gr.images(self._visparts, zero_to_one=False, show=False, vmin=0, vmax=1, fileName="moduleShiftingParts1.png")
        return parts