    def train_from_samples(self, patches):
        #from pnet.latent_bernoulli_mm import LatentBernoulliMM
        from pnet.bernoullimm import BernoulliMM
        min_prob = self._settings.get('min_prob', 0.01)

        support_mask = self._settings.get('support_mask')
        if support_mask is not None:
            kp_patches = patches[:,support_mask]
        else:
            kp_patches = patches.reshape((patches.shape[0], -1, patches.shape[-1]))

        if self._settings.get('kp') == 'funky':
            patches = patches[:,::2,::2]

            # Keep only patches whose total activity reaches the configured threshold
            print('patches before thresholding', patches.shape)
            patches = patches[np.apply_over_axes(np.sum, patches.astype(np.int64), [1, 2, 3]).ravel() >= self._settings['threshold']]
            print('patches after thresholding', patches.shape)

        flatpatches = kp_patches.reshape((kp_patches.shape[0], -1))

        mm = BernoulliMM(n_components=self._num_parts, 
                         n_iter=10, 
                         tol=1e-15,
                         n_init=1, 
                         random_state=self._settings.get('em_seed', 0), 
                         min_prob=min_prob, 
                         verbose=False)
        mm.fit(flatpatches)
        logprob, resp = mm.eval(flatpatches)
        comps = resp.argmax(-1)


        # Per-part entropy of the Bernoulli means (natural log), averaged over dimensions
        Hall = (mm.means_ * np.log(mm.means_) + (1 - mm.means_) * np.log(1 - mm.means_))
        H = -Hall.mean(-1)

        if support_mask is not None:
            self._parts = 0.5 * np.ones((self._num_parts,) + patches.shape[1:])
            self._parts[:,support_mask] = mm.means_.reshape((self._parts.shape[0], -1, self._parts.shape[-1]))
        else:
            self._parts = mm.means_.reshape((self._num_parts,)+patches.shape[1:])
        self._weights = mm.weights_

        # Sort parts by entropy (lowest entropy first)
        II = np.argsort(H)

        self._parts = self._parts[II]

        self._num_parts = II.shape[0]
        #self._train_info['entropy'] = H[II]

        return comps
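    # The helper below is a minimal, self-contained sketch (not part of the
    # original class; the name is illustrative only) of the entropy ordering
    # used above: the per-part Bernoulli entropy averaged over dimensions, and
    # the permutation that sorts parts from low to high entropy. It assumes
    # `means` is the (num_parts, D) matrix of mixture means, clipped away from
    # 0 and 1 as BernoulliMM does with min_prob.
    @staticmethod
    def _entropy_order_sketch(means):
        import numpy as np
        H = -(means * np.log(means) + (1 - means) * np.log(1 - means)).mean(-1)
        return np.argsort(H)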
    def _train_inner(self, flatpatches, shape, hier, depth=0):
        min_prob = self._settings.get('min_prob', 0.01)

        #if len(flatpatches) < 5:
            ## Put all of them in both
            #flatpatches.mean()

        mm = BernoulliMM(n_components=self._num_parts_per_layer, 
                         n_iter=20, 
                         tol=1e-15,
                         n_init=5, 
                         random_state=self._settings.get('em_seed', 0), 
                         #params='m',
                         min_prob=min_prob)
        mm.fit(flatpatches)
        logprob, resp = mm.eval(flatpatches)
        comps = resp.argmax(-1)

        counts = np.bincount(comps, minlength=self._num_parts_per_layer)

        if 0:
            from scipy.stats import scoreatpercentile
            lower0 = scoreatpercentile(resp[...,0].ravel(), 50)
            lower1 = scoreatpercentile(resp[...,1].ravel(), 50)
            lows = [lower0, lower1]
            hier_flatpatches = [flatpatches[resp[...,m] >= lows[m]] for m in range(self._num_parts_per_layer)]

            # Reset the means
            for m in range(self._num_parts_per_layer):
                mm.means_[m] = np.clip(hier_flatpatches[m].mean(0), min_prob, 1 - min_prob)
        else:
            hier_flatpatches = [flatpatches[comps == m] for m in range(self._num_parts_per_layer)]
        
        #if depth == 0:
            #print('counts', counts) 
        #print(depth, 'hier_flatpatches', map(np.shape, hier_flatpatches))

        all_parts = []
        all_weights = []
        all_counts = []
        all_hierarchies = []

        pp = mm.means_.reshape((self._num_parts_per_layer,) + shape)

        hier[depth].append(pp)

        if depth+1 < self._depth:
            # Iterate
            for m in range(self._num_parts_per_layer):
                parts, weights, counts0 = self._train_inner(hier_flatpatches[m], shape, hier, depth=depth+1)

                all_parts.append(parts)
                all_weights.append(weights)
                all_counts.append(counts0)
                #all_hierarchies.append(hierarchy_parts)

            flat_parts = np.concatenate(all_parts, axis=0)
            all_weights = np.concatenate(all_weights, axis=0)
            all_counts = np.concatenate(all_counts, axis=0)
            return flat_parts, all_weights, all_counts#, [pp] + all_hierarchies
        else:
            parts = mm.means_
            weights = mm.weights_
            #all_parts.reshape((self._num_parts,)+patches.shape[1:])
            return parts, weights, counts
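    # For reference, a compact illustrative sketch (not an existing method) of
    # the recursion above with the bookkeeping stripped out: fit an
    # n-component Bernoulli mixture, recurse on each component's samples, and
    # stack the leaf means, giving n**depth parts at the root call. It assumes
    # the same BernoulliMM interface used in `_train_inner`.
    def _recursive_split_sketch(self, flatpatches, depth=0):
        from pnet.bernoullimm import BernoulliMM
        import numpy as np
        mm = BernoulliMM(n_components=self._num_parts_per_layer,
                         n_iter=20,
                         n_init=5,
                         random_state=self._settings.get('em_seed', 0),
                         min_prob=self._settings.get('min_prob', 0.01))
        mm.fit(flatpatches)
        comps = mm.eval(flatpatches)[1].argmax(-1)
        if depth + 1 >= self._depth:
            # Leaf level: return the component means as the parts
            return mm.means_
        return np.concatenate([
            self._recursive_split_sketch(flatpatches[comps == m], depth=depth + 1)
            for m in range(self._num_parts_per_layer)])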
    def train_from_samples(self, patches):
        #from pnet.latent_bernoulli_mm import LatentBernoulliMM
        min_prob = self._settings.get('min_prob', 0.01)

        kp_patches = patches.reshape((patches.shape[0], -1, patches.shape[-1]))
        flatpatches = kp_patches.reshape((kp_patches.shape[0], -1))

        hier = [[] for _ in range(self._depth)]

        all_parts, all_weights, counts  = self._train_inner(flatpatches, patches.shape[1:], hier)

        if 0:
            # Post-training

            # Warm-up the EM with the leaves
            F = self._num_parts_per_layer ** self._depth

            mm = BernoulliMM(n_components=F,
                             n_iter=20, 
                             tol=1e-15,
                             n_init=1, 
                             random_state=100+self._settings.get('em_seed', 0), 
                             init_params='',
                             #params='m',
                             min_prob=min_prob)

            mm.means_ = all_parts
            mm.weights_ = counts / counts.sum()
            mm.fit(flatpatches)
            logprob, resp = mm.eval(flatpatches)
            comps = resp.argmax(-1)

            #for d in xrange(self._depth):
            #import pdb; pdb.set_trace()

            for d in reversed(range(len(hier))):
                for j in range(len(hier[d])):
                    hier_dj = hier[d][j]
                    for k in range(len(hier_dj)):
                        if d == len(hier) - 1:
                            hier[d][j][k] = mm.means_[j*self._num_parts_per_layer + k].reshape(patches.shape[1:])
                        else:
                            hier[d][j][k] = hier[d+1][j*self._num_parts_per_layer + k].mean(0)

                        #hier[d][j][k] = [flatpatches[comps == f].mean(0).reshape((-1,) + patches.shape[1:])
                        #else:
                            #hier[d][j][k] += 


            # Now, update the entire tree with these new results

        if 0:
            def pprint(x):
                if isinstance(x, list):
                    return "[" + ", ".join(map(pprint, x)) + "]"
                elif isinstance(x, np.ndarray):
                    #return "array{}".format(x.shape)
                    return 'A'

            #print(pprint(parts_hierarchy))
            print(pprint(hier))

            #import pdb; pdb.set_trace()

            print(all_parts.shape)

        self._parts = all_parts.reshape((self._num_parts,)+patches.shape[1:])
        self._weights = all_weights 
        self._counts = counts
        #self._parts_hierarchy = parts_hierarchy
        self._hier = hier
        # Flatten the per-depth lists into one (num_nodes, num_parts_per_layer) + patch-shape array
        self._flathier = np.asarray(sum(hier, []))

        from scipy.special import logit

        if self._num_parts_per_layer == 2:
            # Per-node Bernoulli log-likelihood ratio between the two children:
            # linear coefficients (logit difference) and the summed constant term
            self._w = logit(self._flathier[:,1]) - logit(self._flathier[:,0])
            self._constant_terms = np.apply_over_axes(np.sum, np.log(1 - self._flathier[:,1]) - np.log(1 - self._flathier[:,0]), [1, 2, 3])[:,0,0,0]

        if 0:
            # Train it again by initializing with this
            mm = BernoulliMM(n_components=self._num_parts, 
                             n_iter=20, 
                             tol=1e-15,
                             n_init=1, 
                             init_params='w',
                             random_state=self._settings.get('em_seed', 0), 
                             min_prob=min_prob)
            mm.means_ = self._parts.reshape((self._num_parts, -1))
            mm.fit(flatpatches)

            self._parts = mm.means_.reshape((self._num_parts,)+patches.shape[1:])
            self._weights = mm.weights_

        # Per-part entropy of the final parts (natural log), averaged over all patch dimensions
        Hall = (self._parts * np.log(self._parts) + (1 - self._parts) * np.log(1 - self._parts))
        H = -np.apply_over_axes(np.mean, Hall, [1, 2, 3]).ravel()
        self._train_info['entropy'] = H
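    # The `_w` / `_constant_terms` pair built above encodes, for every node in
    # the flattened hierarchy, the Bernoulli log-likelihood ratio between its
    # two children: log P(x|p1) - log P(x|p0) = sum_i x_i*(logit(p1_i) - logit(p0_i))
    # + sum_i log((1-p1_i)/(1-p0_i)). The helper below is an illustrative
    # sketch (not an existing method) of evaluating that ratio for one binary
    # patch; it assumes `_num_parts_per_layer == 2` so that `_w` and
    # `_constant_terms` were actually set.
    def _llr_sketch(self, x, node):
        import numpy as np
        # x: binary patch with the same shape as one part; node: index into _flathier
        return float(np.sum(self._w[node] * x) + self._constant_terms[node])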
    def train_from_samples(self, patches):
        #from pnet.latent_bernoulli_mm import LatentBernoulliMM
        from scipy.special import logit
        min_prob = self._settings.get('min_prob', 0.01)

        kp_patches = patches.reshape((patches.shape[0], -1, patches.shape[-1]))
        flatpatches = kp_patches.reshape((kp_patches.shape[0], -1))

        q = []

        models = []
        constant_terms = []
        constant_terms_unsummed = []

        # Tree over split nodes: each row is (-1, part_id) for a leaf, or
        # (split_model_index, left_child_node) for an internal node (right child at left+1)
        tree = []  # -123*np.ones((1000, 2), dtype=np.int64)
        cur_pos = 0

        s = 0
        cur_part_id = 0
        
        q.insert(0, (cur_pos, 0, flatpatches))
        s += 1
        #cur_pos += 1
        while q:
            p, depth, x = q.pop()

            # Mean activity of this node's samples, clipped away from 0 and 1
            model = np.clip(np.mean(x, 0), min_prob, 1 - min_prob)
            def entropy(p):
                return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
            H = np.mean(entropy(model))

            sc = self._settings.get('split_criterion', 'H')

            if len(x) < self._settings.get('min_samples_per_part', 20) or depth >= self._max_depth or \
               (sc == 'H' and H < self._settings.get('split_entropy', 0.30)):
                #tree[p,0] = -1
                #tree[p,1] = cur_part_id
                # Leaf node: map it to a part id and stop splitting
                tree.append((-1, cur_part_id))
                cur_part_id += 1

            else:
                mm = BernoulliMM(n_components=self._num_parts_per_layer, 
                                 n_iter=self._settings.get('n_iter', 8), 
                                 tol=1e-15,
                                 n_init=self._settings.get('n_init', 1), # May improve a bit to increase this
                                 random_state=self._settings.get('em_seed', 0), 
                                 #params='m',
                                 min_prob=min_prob)
                # Optionally cap the number of samples used for fitting
                mm.fit(x[:self._settings.get('train_limit')])
                logprob, resp = mm.eval(x)
                comps = resp.argmax(-1)

                w = logit(mm.means_[1]) - logit(mm.means_[0])


                # Information gain: entropy before the split minus the weighted entropy after
                Hafter = np.mean(entropy(mm.means_[0])) * mm.weights_[0] + np.mean(entropy(mm.means_[1])) * mm.weights_[1]
                IG = H - Hafter

                if sc == 'IG' and IG < self._settings.get('min_information_gain', 0.05):
                    tree.append((-1, cur_part_id))
                    cur_part_id += 1
                else:
                    # Internal node: store the index of its split model and of its left child
                    tree.append((len(models), s))
                    K_unsummed = np.log((1 - mm.means_[1]) / (1 - mm.means_[0]))
                    K = np.sum(K_unsummed)

                    models.append(w)
                    constant_terms.append(K)
                    constant_terms_unsummed.append(K_unsummed)
                    #tree[p,1] = s


                    q.insert(0, (s, depth+1, x[comps == 0]))
                    #cur_pos += 1
                    q.insert(0, (s+1, depth+1, x[comps == 1]))
                    #cur_pos += 1
                    s += 2

        shape = (len(models),) + patches.shape[1:]
        # Stack the per-node logit-difference coefficients and constant terms
        weights = np.asarray(models).reshape(shape)
        constant_terms = np.asarray(constant_terms)
        constant_terms_unsummed = np.asarray(constant_terms_unsummed).reshape(shape)
        tree = np.asarray(tree, dtype=np.int64)

        self._tree = tree
        self._num_parts = cur_part_id
        #print('num_parts', self._num_parts)
        self._w = weights 
        self._constant_terms = constant_terms

        supp_radius = self._settings.get('keypoint_suppress_radius', 0)
        if supp_radius > 0:

            NW = self._w.shape[0]
            max_indices = self._settings.get('keypoint_max', 1000)
            keypoints = np.zeros((NW, max_indices, 3), dtype=np.int64)
            kp_constant_terms = np.zeros(NW)
            num_keypoints = np.zeros(NW, dtype=np.int64)

            from gv.keypoints import get_key_points
            for k in range(NW):
                kps = get_key_points(self._w[k], suppress_radius=supp_radius, max_indices=max_indices)

                NK = len(kps)
                num_keypoints[k] = NK
                keypoints[k,:NK] = kps

                for kp in kps:
                    kp_constant_terms[k] += constant_terms_unsummed[k,kp[0],kp[1],kp[2]]
 
            self._keypoints = keypoints
            self._num_keypoints = num_keypoints
            self._keypoint_constant_terms = kp_constant_terms
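    # An illustrative sketch (not an existing method) of how the arrays built
    # above can be used to code a single binary patch: walk the tree from the
    # root, at each internal node compare the two children via the Bernoulli
    # log-likelihood ratio w.x + K (ignoring the mixture weights), and return
    # the part id stored at the leaf that is reached. `self._w[i]` and
    # `self._constant_terms[i]` belong to the i-th split model referenced by
    # `self._tree[node, 0]`.
    def _code_patch_sketch(self, x):
        import numpy as np
        node = 0
        while True:
            model, target = self._tree[node]
            if model == -1:
                # Leaf: `target` holds the part id
                return int(target)
            score = np.sum(self._w[model] * x) + self._constant_terms[model]
            # Children sit at consecutive node indices `target` and `target + 1`
            node = target + 1 if score > 0 else target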