Example #1
    def partial_EM(self, data, cond_muh_ijk, indices, weights=None, eps=1e-4, maxiter=10, verbose=0):
        (i, j, k) = indices
        converged = False
        previous_L = utilities.average(
            self.likelihood(data), weights=weights) / self.N
        mini_epochs = 0
        if verbose:
            print('Partial EM %s, L = %.3f' % (mini_epochs, previous_L))
        while not converged:
            if self.nature in ['Bernoulli', 'Spin']:
                f = np.dot(data, self.weights[[i, j, k], :].T)
            elif self.nature == 'Potts':
                f = cy_utilities.compute_output_C(data, self.weights[[i, j, k], :, :], np.zeros([
                                                  data.shape[0], 3], dtype=curr_float))

            tmp = f - self.logZ[np.newaxis, [i, j, k]]
            tmp -= tmp.max(-1)[:, np.newaxis]
            cond_muh = np.exp(tmp) * self.muh[np.newaxis, [i, j, k]]
            cond_muh /= cond_muh.sum(-1)[:, np.newaxis]
            cond_muh *= cond_muh_ijk[:, np.newaxis]

            self.muh[[i, j, k]] = utilities.average(cond_muh, weights=weights)
            self.cum_muh = np.cumsum(self.muh)
            self.gh[[i, j, k]] = np.log(self.muh[[i, j, k]])
            self.gh -= self.gh.mean()
            if self.nature == 'Bernoulli':
                self.cond_muv[[i, j, k]] = utilities.average_product(
                    cond_muh, data, mean1=True, weights=weights) / self.muh[[i, j, k], np.newaxis]
                self.weights[[i, j, k]] = np.log(
                    (self.cond_muv[[i, j, k]] + eps) / (1 - self.cond_muv[[i, j, k]] + eps))
                self.logZ[[i, j, k]] = np.logaddexp(
                    0, self.weights[[i, j, k]]).sum(-1)
            elif self.nature == 'Spin':
                self.cond_muv[[i, j, k]] = utilities.average_product(
                    cond_muh, data, mean1=True, weights=weights) / self.muh[[i, j, k], np.newaxis]
                self.weights[[i, j, k]] = 0.5 * np.log(
                    (1 + self.cond_muv[[i, j, k]] + eps) / (1 - self.cond_muv[[i, j, k]] + eps))
                self.logZ[[i, j, k]] = np.logaddexp(
                    self.weights[[i, j, k]], -self.weights[[i, j, k]]).sum(-1)
            elif self.nature == 'Potts':
                self.cond_muv[[i, j, k]] = utilities.average_product(
                    cond_muh, data, c2=self.n_c, mean1=True, weights=weights) / self.muh[[i, j, k], np.newaxis, np.newaxis]
                self.cum_cond_muv[[i, j, k]] = np.cumsum(
                    self.cond_muv[[i, j, k]], axis=-1)
                self.weights[[i, j, k]] = np.log(
                    self.cond_muv[[i, j, k]] + eps)
                self.weights[[i, j, k]] -= self.weights[[i, j, k]].mean(-1)[:, :, np.newaxis]
                self.logZ[[i, j, k]] = utilities.logsumexp(
                    self.weights[[i, j, k]], axis=-1).sum(-1)

            current_L = utilities.average(
                self.likelihood(data), weights=weights) / self.N
            mini_epochs += 1
            converged = (mini_epochs >= maxiter) | (
                np.abs(current_L - previous_L) < eps)
            previous_L = current_L.copy()
            if verbose:
                print('Partial EM %s, L = %.3f' % (mini_epochs, current_L))
        return current_L
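
The E-step above stabilizes the responsibilities with a log-sum-exp trick before exponentiating. A minimal standalone sketch for the Bernoulli case (e_step_bernoulli and its argument names are illustrative, not from the original module):

import numpy as np

def e_step_bernoulli(data, weights, logZ, muh):
    # Responsibilities P(h = m | v) for a Bernoulli mixture, stabilized
    # by subtracting the per-sample maximum before exponentiating.
    f = data @ weights.T                      # (n_samples, n_components)
    tmp = f - logZ[np.newaxis, :]             # remove per-component log-partition
    tmp -= tmp.max(-1, keepdims=True)         # log-sum-exp stabilization
    cond_muh = np.exp(tmp) * muh[np.newaxis, :]
    return cond_muh / cond_muh.sum(-1, keepdims=True)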
Example #2
def couplings_gradients_h(W, X1_p, X1_n, X2_p, X2_n, n_c1, n_c2, l1=0, l1b=0, l1c=0, l2=0, l1_custom=None, l1b_custom=None, weights=None, weights_neg=None):
    update = utilities.average_product(X1_p, X2_p, c1=n_c1, c2=n_c2, mean1=True, mean2=False, weights=weights) - \
        utilities.average_product(
            X1_n, X2_n, c1=n_c1, c2=n_c2, mean1=False, mean2=True, weights=weights_neg)
    if l2 > 0:
        update -= l2 * W
    if l1 > 0:
        update -= l1 * np.sign(W)
    if l1b > 0:  # NOT SUPPORTED FOR POTTS
        if n_c2 > 1:  # Potts RBM.
            update -= l1b * \
                np.sign(W) * \
                np.abs(W).mean(-1).mean(-1)[:, np.newaxis, np.newaxis]
        else:
            update -= l1b * np.sign(W) * (np.abs(W).sum(1))[:, np.newaxis]
    if l1c > 0:  # NOT SUPPORTED FOR POTTS
        update -= l1c * np.sign(W) * ((np.abs(W).sum(1))**2)[:, np.newaxis]
    if l1_custom is not None:
        update -= l1_custom * np.sign(W)
    if l1b_custom is not None:
        update -= l1b_custom[0] * (l1b_custom[1] * np.abs(W)).mean(-1).mean(-1)[
            :, np.newaxis, np.newaxis] * np.sign(W)

    if weights is not None:
        update *= weights.sum() / X1_p.shape[0]
    return update
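
Each penalty above is subtracted directly from the moment-matching gradient before the parameter step. A toy sketch of the plain l1/l2 terms on a random coupling matrix (all values hypothetical):

import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 20))          # hypothetical coupling matrix
update = rng.normal(size=W.shape)     # stand-in for the moment-matching gradient
l1, l2 = 0.01, 0.1
update -= l2 * W                      # quadratic penalty: weight decay toward 0
update -= l1 * np.sign(W)             # l1 penalty: constant pull toward sparsity
W += 0.05 * update                    # ascent step with learning rate 0.05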
Example #3
    def minibatch_fit_symKL(self, data_PGM, PGM=None, data_MOI=None, F_PGM_dPGM=None, F_PGM_dMOI=None, F_MOI_dPGM=None, F_MOI_dMOI=None, cond_muh_dPGM=None, cond_muh_dMOI=None, weights=None):
        if data_MOI is None:
            data_MOI, _ = self.gen_data(data_PGM.shape[0])
        if F_PGM_dPGM is None:
            F_PGM_dPGM = PGM.free_energy(data_PGM)
        if F_PGM_dMOI is None:
            F_PGM_dMOI = PGM.free_energy(data_MOI)
        if (F_MOI_dPGM is None) | (cond_muh_dPGM is None):
            F_MOI_dPGM, cond_muh_dPGM = self.likelihood_and_expectation(
                data_PGM)
            F_MOI_dPGM *= -1
        if (F_MOI_dMOI is None) | (cond_muh_dMOI is None):
            F_MOI_dMOI, cond_muh_dMOI = self.likelihood_and_expectation(
                data_MOI)
            F_MOI_dMOI *= -1

        delta_lik = -F_PGM_dMOI + F_MOI_dMOI
        delta_lik -= delta_lik.mean()

        self.gradient = {}
        self.gradient['gh'] = utilities.average(
            cond_muh_dPGM, weights=weights) - self.muh + (delta_lik[:, np.newaxis] * cond_muh_dMOI).mean(0)
        if self.nature in ['Bernoulli', 'Spin']:
            self.gradient['weights'] = utilities.average_product(
                cond_muh_dPGM, data_PGM, mean1=True, weights=weights) + utilities.average_product(cond_muh_dMOI * delta_lik[:, np.newaxis], data_MOI, mean1=True)
            self.gradient['weights'] -= self.muh[:, np.newaxis] * self.cond_muv
        elif self.nature == 'Potts':
            self.gradient['weights'] = utilities.average_product(cond_muh_dPGM, data_PGM, mean1=True, c2=self.n_c, weights=weights) + utilities.average_product(
                cond_muh_dMOI * delta_lik[:, np.newaxis], data_MOI, mean1=True, c2=self.n_c)
            self.gradient['weights'] -= self.muh[:,
                                                 np.newaxis, np.newaxis] * self.cond_muv

        self.gh += self.learning_rate * self.gradient['gh']
        self.weights += self.learning_rate * self.gradient['weights']

        self.muh = np.exp(self.gh)
        self.muh /= self.muh.sum()
        self.cum_muh = np.cumsum(self.muh)
        if self.nature == 'Bernoulli':
            self.cond_muv = utilities.logistic(self.weights)
        elif self.nature == 'Spin':
            self.cond_muv = np.tanh(self.weights)
        elif self.nature == 'Potts':
            self.weights -= self.weights.mean(-1)[:, :, np.newaxis]
            self.cond_muv = np.exp(self.weights)
            self.cond_muv /= self.cond_muv.sum(-1)[:, :, np.newaxis]
            self.cum_cond_muv = np.cumsum(self.cond_muv, axis=-1)
        self.logpartition()
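
The post-update block above maps weights back to conditional means with the link function of each variable type. A self-contained sketch of those three maps (cond_means is an illustrative helper, not part of the original code):

import numpy as np

def cond_means(weights, nature):
    if nature == 'Bernoulli':
        return 1.0 / (1.0 + np.exp(-weights))           # logistic
    elif nature == 'Spin':
        return np.tanh(weights)
    elif nature == 'Potts':
        w = weights - weights.mean(-1, keepdims=True)   # zero-sum gauge
        p = np.exp(w)
        return p / p.sum(-1, keepdims=True)             # softmax over colors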
Example #4
    def split_merge_criterion(self, data, Cmax=5, weights=None):
        likelihood, cond_muh = self.likelihood_and_expectation(data)
        norm = np.sqrt(utilities.average(cond_muh**2, weights=weights))
        J_merge = utilities.average_product(
            cond_muh, cond_muh, weights=weights) / (1e-10 + norm[np.newaxis, :] * norm[:, np.newaxis])
        J_merge = np.triu(J_merge, 1)
        proposed_merge = np.argsort(J_merge.flatten())[::-1][:Cmax]
        proposed_merge = [(merge % self.M, merge // self.M)
                          for merge in proposed_merge]

        tmp = cond_muh / self.muh[np.newaxis, :]

        if weights is None:
            J_split = np.array(
                [utilities.average(likelihood, weights=tmp[:, m]) for m in range(self.M)])
        else:
            J_split = np.array([utilities.average(
                likelihood, weights=tmp[:, m] * weights) for m in range(self.M)])

        proposed_split = np.argsort(J_split)[:3]
        proposed_merge_split = []
        for merge1, merge2 in proposed_merge:
            if proposed_split[0] in [merge1, merge2]:
                if proposed_split[1] in [merge1, merge2]:
                    proposed_merge_split.append(
                        (merge1, merge2, proposed_split[2]))
                else:
                    proposed_merge_split.append(
                        (merge1, merge2, proposed_split[1]))
            else:
                proposed_merge_split.append(
                    (merge1, merge2, proposed_split[0]))
        return proposed_merge_split
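
The merge proposals above are recovered from flattened indices of the upper-triangular similarity matrix. A small sketch of that index arithmetic on a toy matrix:

import numpy as np

J = np.triu(np.arange(16, dtype=float).reshape(4, 4), 1)  # toy similarity matrix
top = np.argsort(J.flatten())[::-1][:3]                   # top-3 flattened indices
pairs = [(idx % J.shape[1], idx // J.shape[1]) for idx in top]
# Row-major flattening means idx = row * M + col, so this yields (col, row) pairs.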
Example #5
    def maximization(self, data, cond_muh, weights=None, eps=1e-6):
        self.muh = utilities.average(cond_muh, weights=weights)
        self.cum_muh = np.cumsum(self.muh)
        self.gh = np.log(self.muh)
        self.gh -= self.gh.mean()
        if self.nature == 'Bernoulli':
            self.cond_muv = utilities.average_product(
                cond_muh, data, mean1=True, weights=weights) / self.muh[:, np.newaxis]
            self.weights = np.log((self.cond_muv + eps) /
                                  (1 - self.cond_muv + eps))
        elif self.nature == 'Spin':
            self.cond_muv = utilities.average_product(
                cond_muh, data, mean1=True, weights=weights) / self.muh[:, np.newaxis]
            self.weights = 0.5 * \
                np.log((1 + self.cond_muv + eps) / (1 - self.cond_muv + eps))

        elif self.nature == 'Potts':
            self.cond_muv = utilities.average_product(
                cond_muh, data, c2=self.n_c, mean1=True, weights=weights) / self.muh[:, np.newaxis, np.newaxis]
            self.cum_cond_muv = np.cumsum(self.cond_muv, axis=-1)
            self.weights = np.log(self.cond_muv + eps)
            self.weights -= self.weights.mean(-1)[:, :, np.newaxis]
        self.logpartition()
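
The eps in the M-step keeps the logit finite when a conditional mean saturates at 0 or 1. For example:

import numpy as np

mu = np.array([0.0, 0.5, 1.0])            # conditional means, with extremes
eps = 1e-6
g = np.log((mu + eps) / (1 - mu + eps))   # finite even at mu = 0 or mu = 1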
Example #6
    def minibatch_fit(self, data, weights=None, eps=1e-5, update=True):
        h = self.expectation(data)
        self.muh = self.learning_rate * \
            utilities.average(h, weights=weights) + \
            (1 - self.learning_rate) * self.muh
        self.cum_muh = np.cumsum(self.muh)
        if update:
            self.gh = np.log(self.muh + eps)
            self.gh -= self.gh.mean()
        if self.nature == 'Bernoulli':
            self.muvh = self.learning_rate * \
                utilities.average_product(
                    h, data, weights=weights) + (1 - self.learning_rate) * self.muvh
            if update:
                self.cond_muv = self.muvh / (self.muh[:, np.newaxis])
                self.weights = np.log(
                    (self.cond_muv + eps) / (1 - self.cond_muv + eps))
        elif self.nature == 'Spin':
            self.muvh = self.learning_rate * \
                utilities.average_product(h, data, weights=weights) + \
                (1 - self.learning_rate) * self.muvh
            if update:
                self.cond_muv = self.muvh / self.muh[:, np.newaxis]
                self.weights = 0.5 * \
                    np.log((1 + self.cond_muv + eps) /
                           (1 - self.cond_muv + eps))
        else:
            self.muvh = self.learning_rate * utilities.average_product(
                h, data, c2=self.n_c, weights=weights) + (1 - self.learning_rate) * self.muvh
            if update:
                self.cond_muv = self.muvh / self.muh[:, np.newaxis, np.newaxis]
                self.weights = np.log(self.cond_muv + eps)
                self.weights -= self.weights.mean(-1)[:, :, np.newaxis]

        if update:
            self.logpartition()
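
This online fit replaces full-batch averages with exponential moving averages of the sufficient statistics. A toy sketch of that update rule converging to a batch statistic (all values hypothetical):

import numpy as np

rng = np.random.default_rng(1)
muh = np.zeros(3)
lr = 0.1
for _ in range(200):
    batch_stat = rng.dirichlet(np.ones(3), size=32).mean(0)  # fake minibatch statistic
    muh = lr * batch_stat + (1 - lr) * muh                   # same EMA rule as above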
Example #7
def calculate_error(RBM,
                    data_tr,
                    N_sequences=800000,
                    Nstep=10,
                    background=None):
    N = RBM.n_v
    q = RBM.n_cv

    # Check how moments are reproduced
    # Means
    mudata = RBM.mu_data  # empirical averages
    #datav, datah = RBM.gen_data(Nchains = int(100), Lchains = int(N_sequences/100), Nthermalize=int(500), background= background)
    datav, datah = RBM.gen_data(Nchains=int(100),
                                Lchains=int(N_sequences / 100),
                                Nthermalize=int(500))
    mugen = utilities.average(datav, c=q, weights=None)

    # Correlations
    covgen = utilities.average_product(
        datav, datav, c1=q,
        c2=q) - mugen[:, np.newaxis, :, np.newaxis] * mugen[np.newaxis, :,
                                                            np.newaxis, :]
    covdata = utilities.average_product(
        data_tr, data_tr, c1=q,
        c2=q) - mudata[:, np.newaxis, :, np.newaxis] * mudata[np.newaxis, :,
                                                              np.newaxis, :]
    fdata = utilities.average_product(data_tr, data_tr, c1=q, c2=q)

    # Set the diagonal elements of the covariances to zero.
    for i in range(N):
        covdata[i, i, :, :] = np.zeros((q, q))
        fdata[i, i, :, :] = np.zeros((q, q))
        covgen[i, i, :, :] = np.zeros((q, q))

    M = len(data_tr)
    maxp = float(1) / float(M)
    ps = 0.00001  # pseudocount for fully conserved sites

    # error on frequency
    errm = 0
    neffm = 0
    for i in range(N):
        for a in range(q):
            neffm += 1
            if mudata[i, a] < maxp:
                errm += np.power((mugen[i, a] - mudata[i, a]),
                                 2) / (float(1 - maxp) * float(maxp))
            else:
                if mudata[i, a] != 1.0:
                    errm += np.power(
                        (mugen[i, a] - mudata[i, a]),
                        2) / (float(1 - mudata[i, a]) * float(mudata[i, a]))
                else:
                    errm += np.power((mugen[i, a] - mudata[i, a]), 2) / (
                        float(1 - mudata[i, a] - ps) * float(mudata[i, a]))

    errmt = np.sqrt(float(1) / (float(neffm) * float(maxp)) * float(errm))
    # Rigorously, the regularization term would also enter the difference errm!

    # error on correlations
    errc = 0
    neffc = 0
    for i in range(N):
        for j in range(i + 1, N):
            for a in range(q):
                for b in range(a + 1, q):
                    neffc += 1
                    if covdata[i, j, a, b] < maxp:
                        den = np.power(
                            np.sqrt(float(1 - maxp) * float(maxp)) +
                            mudata[i, a] * np.sqrt(mudata[j, b] *
                                                   (1 - mudata[j, b])) +
                            mudata[j, b] * np.sqrt(mudata[i, a] *
                                                   (1 - mudata[i, a])), 2)
                        errc += np.power(
                            (covgen[i, j, a, b] - covdata[i, j, a, b]),
                            2) / float(den)
                    else:
                        den = np.power(
                            np.sqrt(
                                float(1 - fdata[i, j, a, b]) *
                                float(fdata[i, j, a, b])) +
                            mudata[i, a] * np.sqrt(mudata[j, b] *
                                                   (1 - mudata[j, b])) +
                            mudata[j, b] * np.sqrt(mudata[i, a] *
                                                   (1 - mudata[i, a])), 2)
                        errc += np.power(
                            (covgen[i, j, a, b] - covdata[i, j, a, b]),
                            2) / float(den)

    errct = np.sqrt(float(1) / (float(neffc) * float(maxp)) * float(errc))
    return (errmt, errct)
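
The frequency-error loop above picks a per-entry variance normalization depending on how conserved each site is. A vectorized equivalent of that loop, under the same conventions (frequency_error is an illustrative helper name):

import numpy as np

def frequency_error(mugen, mudata, M, ps=1e-5):
    maxp = 1.0 / M
    var = np.where(mudata < maxp, (1 - maxp) * maxp,
                   np.where(mudata != 1.0, (1 - mudata) * mudata,
                            (1 - mudata - ps) * mudata))
    errm = ((mugen - mudata) ** 2 / var).sum()
    return np.sqrt(errm / (mudata.size * maxp))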
Example #8
def assess_moment_matching(RBM,
                           data,
                           data_gen,
                           datah_gen=None,
                           weights=None,
                           weights_neg=None,
                           with_reg=True,
                           show=True):
    h_data = RBM.mean_hiddens(data)
    if datah_gen is not None:
        h_gen = datah_gen
    else:
        h_gen = RBM.mean_hiddens(data_gen)
    mu = utilities.average(data, c=RBM.n_cv, weights=weights)
    if datah_gen is not None:
        condmu_gen = RBM.mean_visibles(datah_gen)
        mu_gen = utilities.average(condmu_gen, weights=weights_neg)
    else:
        mu_gen = utilities.average(data_gen, c=RBM.n_cv, weights=weights_neg)

    mu_h = utilities.average(h_data, weights=weights)
    mu_h_gen = utilities.average(h_gen, weights=weights_neg)

    if RBM.n_cv > 1:
        cov_vh = utilities.average_product(
            h_data, data, c2=RBM.n_cv, weights=weights
        ) - mu[np.newaxis, :, :] * mu_h[:, np.newaxis, np.newaxis]
    else:
        cov_vh = utilities.average_product(
            h_data, data, c2=RBM.n_cv,
            weights=weights) - mu[np.newaxis, :] * mu_h[:, np.newaxis]

    if datah_gen is not None:
        if RBM.n_cv > 1:
            cov_vh_gen = utilities.average_product(
                datah_gen,
                condmu_gen,
                mean2=True,
                c2=RBM.n_cv,
                weights=weights_neg
            ) - mu_gen[np.newaxis, :, :] * mu_h_gen[:, np.newaxis, np.newaxis]
        else:
            cov_vh_gen = utilities.average_product(
                datah_gen,
                condmu_gen,
                mean2=True,
                c2=RBM.n_cv,
                weights=weights_neg
            ) - mu_gen[np.newaxis, :] * mu_h_gen[:, np.newaxis]
    else:
        if RBM.n_cv > 1:
            cov_vh_gen = utilities.average_product(
                h_gen, data_gen, c2=RBM.n_cv, weights=weights_neg
            ) - mu_gen[np.newaxis, :, :] * mu_h_gen[:, np.newaxis, np.newaxis]
        else:
            cov_vh_gen = utilities.average_product(
                h_gen, data_gen, c2=RBM.n_cv, weights=weights_neg
            ) - mu_gen[np.newaxis, :] * mu_h_gen[:, np.newaxis]

    if RBM.hidden == 'dReLU':
        I_data = RBM.vlayer.compute_output(data, RBM.weights)
        I_gen = RBM.vlayer.compute_output(data_gen, RBM.weights)

        mu_p_pos, mu_n_pos, mu2_p_pos, mu2_n_pos = RBM.hlayer.mean12_pm_from_inputs(
            I_data)
        mu_p_pos = utilities.average(mu_p_pos, weights=weights)
        mu_n_pos = utilities.average(mu_n_pos, weights=weights)
        mu2_p_pos = utilities.average(mu2_p_pos, weights=weights)
        mu2_n_pos = utilities.average(mu2_n_pos, weights=weights)

        mu_p_neg, mu_n_neg, mu2_p_neg, mu2_n_neg = RBM.hlayer.mean12_pm_from_inputs(
            I_gen)
        mu_p_neg = utilities.average(mu_p_neg, weights=weights_neg)
        mu_n_neg = utilities.average(mu_n_neg, weights=weights_neg)
        mu2_p_neg = utilities.average(mu2_p_neg, weights=weights_neg)
        mu2_n_neg = utilities.average(mu2_n_neg, weights=weights_neg)

        a = RBM.hlayer.gamma
        eta = RBM.hlayer.eta
        theta = RBM.hlayer.delta

        moment_theta = -mu_p_pos / np.sqrt(1 + eta) + mu_n_pos / np.sqrt(1 -
                                                                         eta)
        moment_theta_gen = -mu_p_neg / np.sqrt(1 + eta) + mu_n_neg / np.sqrt(
            1 - eta)
        moment_eta = 0.5 * a / (1 + eta)**2 * mu2_p_pos - 0.5 * a / (
            1 - eta)**2 * mu2_n_pos + theta / (
                2 * np.sqrt(1 + eta)**3) * mu_p_pos - theta / (
                    2 * np.sqrt(1 - eta)**3) * mu_n_pos
        moment_eta_gen = 0.5 * a / (1 + eta)**2 * mu2_p_neg - 0.5 * a / (
            1 - eta)**2 * mu2_n_neg + theta / (
                2 * np.sqrt(1 + eta)**3) * mu_p_neg - theta / (
                    2 * np.sqrt(1 - eta)**3) * mu_n_neg

        moment_theta *= -1
        moment_theta_gen *= -1
        moment_eta *= -1
        moment_eta_gen *= -1

    W = RBM.weights
    if with_reg:
        l2 = RBM.l2
        l1 = RBM.l1
        l1b = RBM.l1b
        l1c = RBM.l1c
        l1_custom = RBM.l1_custom
        l1b_custom = RBM.l1b_custom
        n_c2 = RBM.n_cv
        if l2 > 0:
            cov_vh_gen += l2 * W
        if l1 > 0:
            cov_vh_gen += l1 * np.sign(W)
        if l1b > 0:  # NOT SUPPORTED FOR POTTS
            if n_c2 > 1:  # Potts RBM.
                cov_vh_gen += l1b * np.sign(W) * np.abs(W).mean(-1).mean(
                    -1)[:, np.newaxis, np.newaxis]
            else:
                cov_vh_gen += l1b * np.sign(W) * (np.abs(W).sum(1))[:,
                                                                    np.newaxis]
        if l1c > 0:  # NOT SUPPORTED FOR POTTS
            cov_vh_gen += l1c * np.sign(W) * (
                (np.abs(W).sum(1))**2)[:, np.newaxis]

        if any([l1 > 0, l1b > 0, l1c > 0]):
            mask_cov = np.abs(W) > 1e-3
        else:
            mask_cov = np.ones(W.shape, dtype='bool')
    else:
        mask_cov = np.ones(W.shape, dtype='bool')

    if RBM.n_cv > 1:
        if RBM.n_cv == 21:
            list_aa = Proteins_utils.aa
        else:
            list_aa = Proteins_utils.aa[:-1]
        colors_template = np.array([
            matplotlib.colors.to_rgba(aa_color_scatter(letter))
            for letter in list_aa
        ])
        color = np.repeat(colors_template[np.newaxis, :, :],
                          data.shape[1],
                          axis=0).reshape([data.shape[1] * RBM.n_cv, 4])
    else:
        color = 'C0'

    s2 = 14

    if RBM.hidden == 'dReLU':
        fig, ax = plt.subplots(3, 2)
        fig.set_figheight(3 * 5)
        fig.set_figwidth(2 * 5)
    else:
        fig, ax = plt.subplots(2, 2)
        fig.set_figheight(2 * 5)
        fig.set_figwidth(2 * 5)
    clean_ax(ax[1, 1])

    ax_ = ax[0, 0]
    ax_.scatter(mu.flatten(), mu_gen.flatten(), c=color)
    ax_.plot([mu.min(), mu.max()], [mu.min(), mu.max()])
    ax_.set_xlabel(r'$<v_i>_d$', fontsize=s2)
    ax_.set_ylabel(r'$<v_i>_m$', fontsize=s2)
    r2_mu = np.corrcoef(mu.flatten(), mu_gen.flatten())[0, 1]**2
    error_mu = np.sqrt(((mu - mu_gen)**2 / (mu * (1 - mu) + 1e-4)).mean())
    mini = mu.min()
    maxi = mu.max()
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_mu,
             fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_mu,
             fontsize=s2)
    ax_.set_title('Mean visibles', fontsize=s2)

    ax_ = ax[0, 1]
    ax_.scatter(mu_h, mu_h_gen)
    ax_.plot([mu_h.min(), mu_h.max()], [mu_h.min(), mu_h.max()])
    ax_.set_xlabel(r'$<h_\mu>_d$', fontsize=s2)
    ax_.set_ylabel(r'$<h_\mu>_m$', fontsize=s2)
    r2_muh = np.corrcoef(mu_h, mu_h_gen)[0, 1]**2
    error_muh = np.sqrt(((mu_h - mu_h_gen)**2).mean())
    mini = mu_h.min()
    maxi = mu_h.max()
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_muh,
             fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_muh,
             fontsize=s2)
    ax_.set_title('Mean hiddens', fontsize=s2)

    ax_ = ax[1, 0]
    if RBM.n_cv > 1:
        color = np.repeat(np.repeat(colors_template[np.newaxis,
                                                    np.newaxis, :, :],
                                    RBM.n_h,
                                    axis=0),
                          data.shape[1],
                          axis=1).reshape([RBM.n_v * RBM.n_h * RBM.n_cv, 4])
        color = color[mask_cov.flatten()]
    else:
        color = 'C0'

    cov_vh = cov_vh[mask_cov].flatten()
    cov_vh_gen = cov_vh_gen[mask_cov].flatten()

    ax_.scatter(cov_vh, cov_vh_gen, c=color)
    ax_.plot([cov_vh.min(), cov_vh.max()], [cov_vh.min(), cov_vh.max()])
    ax_.set_xlabel(r'Cov$(v_i \;, h_\mu)_d$', fontsize=s2)
    ax_.set_ylabel(r'Cov$(v_i \;, h_\mu)_m + \nabla_{w_{\mu i}} \mathcal{R}$',
                   fontsize=s2)
    r2_vh = np.corrcoef(cov_vh, cov_vh_gen)[0, 1]**2
    error_vh = np.sqrt(((cov_vh - cov_vh_gen)**2).mean())
    mini = cov_vh.min()
    maxi = cov_vh.max()
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_vh,
             fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini,
             0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_vh,
             fontsize=s2)
    ax_.set_title('Hiddens-Visibles correlations', fontsize=s2)

    if RBM.hidden == 'dReLU':
        ax_ = ax[2, 0]
        ax_.scatter(moment_theta, moment_theta_gen, c=theta)
        ax_.plot([moment_theta.min(), moment_theta.max()],
                 [moment_theta.min(), moment_theta.max()])
        ax_.set_xlabel(r'$<-\frac{\partial E}{\partial \theta}>_d$',
                       fontsize=s2)
        ax_.set_ylabel(r'$<-\frac{\partial E}{\partial \theta}>_m$',
                       fontsize=s2)
        r2_theta = np.corrcoef(moment_theta, moment_theta_gen)[0, 1]**2
        error_theta = np.sqrt(((moment_theta - moment_theta_gen)**2).mean())
        mini = moment_theta.min()
        maxi = moment_theta.max()
        ax_.text(0.6 * maxi + 0.4 * mini,
                 0.25 * maxi + 0.75 * mini,
                 r'$R^2 = %.2f$' % r2_theta,
                 fontsize=s2)
        ax_.text(0.6 * maxi + 0.4 * mini,
                 0.35 * maxi + 0.65 * mini,
                 r'$Err = %.2e$' % error_theta,
                 fontsize=s2)
        ax_.set_title('Moment theta', fontsize=s2)

        ax_ = ax[2, 1]
        ax_.scatter(moment_eta, moment_eta_gen, c=np.abs(eta))
        ax_.plot([moment_eta.min(), moment_eta.max()],
                 [moment_eta.min(), moment_eta.max()])
        ax_.set_xlabel(r'$<-\frac{\partial E}{\partial \eta}>_d$', fontsize=s2)
        ax_.set_ylabel(r'$<-\frac{\partial E}{\partial \eta}>_m$', fontsize=s2)
        r2_eta = np.corrcoef(moment_eta, moment_eta_gen)[0, 1]**2
        error_eta = np.sqrt(((moment_eta - moment_eta_gen)**2).mean())
        mini = moment_eta.min()
        maxi = moment_eta.max()
        ax_.text(0.6 * maxi + 0.4 * mini,
                 0.25 * maxi + 0.75 * mini,
                 r'$R^2 = %.2f$' % r2_eta,
                 fontsize=s2)
        ax_.text(0.6 * maxi + 0.4 * mini,
                 0.35 * maxi + 0.65 * mini,
                 r'$Err = %.2e$' % error_eta,
                 fontsize=s2)
        ax_.set_title('Moment eta', fontsize=s2)

    plt.tight_layout()
    if show:
        fig.show()

    if RBM.hidden == 'dReLU':
        errors = [error_mu, error_muh, error_vh, error_theta, error_eta]
        r2s = [r2_mu, r2_muh, r2_vh, r2_theta, r2_eta]
    else:
        errors = [error_mu, error_muh, error_vh]
        r2s = [r2_mu, r2_muh, r2_vh]

    return fig, errors, r2s
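
Each scatter panel above reports the same two statistics. A compact helper matching those computations (fit_quality is illustrative, not from the original code):

import numpy as np

def fit_quality(x_data, x_model):
    r2 = np.corrcoef(x_data, x_model)[0, 1] ** 2     # squared Pearson correlation
    rmse = np.sqrt(((x_data - x_model) ** 2).mean())  # root-mean-square error
    return r2, rmse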
Example #9
def get_cross_derivatives_ReLU(V_pos, psi_pos, hlayer, n_cv, weights=None):
    db_dw = average(V_pos, c=n_cv, weights=weights)
    a = hlayer.gamma[np.newaxis, :]
    theta = hlayer.delta[np.newaxis, :]
    b = hlayer.theta[np.newaxis, :]

    psi = psi_pos

    psi_plus = (-(psi - b) + theta) / np.sqrt(a)
    psi_minus = ((psi - b) + theta) / np.sqrt(a)

    Phi_plus = erf_times_gauss(psi_plus)
    Phi_minus = erf_times_gauss(psi_minus)

    p_plus = 1 / (1 + Phi_minus / Phi_plus)
    p_minus = 1 - p_plus

    e = (psi - b) - theta * (p_plus - p_minus)
    v = p_plus * p_minus * (2 * theta / np.sqrt(a)) * \
        (2 * theta / np.sqrt(a) - 1 / Phi_plus - 1 / Phi_minus)

    dpsi_plus_dpsi = -1 / np.sqrt(a)
    dpsi_minus_dpsi = 1 / np.sqrt(a)
    dpsi_plus_dtheta = 1 / np.sqrt(a)
    dpsi_minus_dtheta = 1 / np.sqrt(a)
    dpsi_plus_da = -1.0 / (2 * a) * psi_plus
    dpsi_minus_da = -1.0 / (2 * a) * psi_minus

    d2psi_plus_dadpsi = 0.5 / np.sqrt(a**3)
    d2psi_plus_dthetadpsi = 0
    d2psi_minus_dadpsi = -0.5 / np.sqrt(a**3)
    d2psi_minus_dthetadpsi = 0

    dp_plus_dpsi = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dp_plus_dtheta = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    dp_plus_da = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_da -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_da)

    d2p_plus_dpsi2 = -(p_plus - p_minus) * p_plus * p_minus * ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi - (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)**2 \
        + p_plus * p_minus * ((dpsi_plus_dpsi)**2 * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) - (
            dpsi_minus_dpsi)**2 * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus))

    d2p_plus_dadpsi = -(p_plus - p_minus) * ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi - (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_da)\
        + p_plus * p_minus * ((dpsi_plus_dpsi * dpsi_plus_da) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) - (dpsi_minus_dpsi * dpsi_minus_da) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus)
                              + (d2psi_plus_dadpsi) * (psi_plus - 1 / Phi_plus) - (d2psi_minus_dadpsi) * (psi_minus - 1 / Phi_minus))

    d2p_plus_dthetadpsi = -(p_plus - p_minus) * ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi - (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_dtheta)\
        + p_plus * p_minus * ((dpsi_plus_dpsi * dpsi_plus_dtheta) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) - (dpsi_minus_dpsi * dpsi_minus_dtheta) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus)
                              + (d2psi_plus_dthetadpsi) * (psi_plus - 1 / Phi_plus) - (d2psi_minus_dthetadpsi) * (psi_minus - 1 / Phi_minus))

    # dlogZ_dpsi = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi +
    #               p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    # dlogZ_dtheta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta +
    #                 p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dlogZ_da = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_da +
    #             p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_da)

    de_dpsi = (1 + v)
    de_db = -de_dpsi
    de_da = 2 * (-theta) * dp_plus_da
    de_dtheta = -(p_plus - p_minus) + 2 * (-theta) * dp_plus_dtheta

    dv_dpsi = 2 * (-theta) * d2p_plus_dpsi2
    dv_db = -dv_dpsi

    dv_da = +2 * (-theta) * d2p_plus_dadpsi

    dv_dtheta = - 2 * dp_plus_dpsi \
        + 2 * (- theta) * d2p_plus_dthetadpsi

    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)

    dmean_v_da = average(dv_da, weights=weights)
    dmean_v_db = average(dv_db, weights=weights)
    dmean_v_dtheta = average(dv_dtheta, weights=weights)

    dvar_e_da = 2 * (
        average(e * de_da, weights=weights) -
        average(e, weights=weights) * average(de_da, weights=weights))
    dvar_e_db = 2 * (
        average(e * de_db, weights=weights) -
        average(e, weights=weights) * average(de_db, weights=weights))
    dvar_e_dtheta = 2 * (
        average(e * de_dtheta, weights=weights) -
        average(e, weights=weights) * average(de_dtheta, weights=weights))

    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / \
        (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    da_dtheta = (dvar_e_dtheta + 0.5 * dmean_v_dtheta *
                 (1 + mean_v + tmp)) / (tmp - dvar_e_da - 0.5 * dmean_v_da *
                                        (1 + mean_v + tmp))

    dmean_v_dw = average_product(dv_dpsi,
                                 V_pos,
                                 c1=1,
                                 c2=n_cv,
                                 weights=weights)

    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
            - average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]) / (
                     tmp - dvar_e_da - 0.5 * dmean_v_da *
                     (1 + mean_v + tmp))[:, np.newaxis, np.newaxis]

    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis]) / (
                     tmp - dvar_e_da - 0.5 * dmean_v_da *
                     (1 + mean_v + tmp))[:, np.newaxis]

    return db_dw, da_db, da_dtheta, da_dw
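
erf_times_gauss is used here but not defined in this excerpt. A plausible implementation consistent with its use, computing sqrt(pi/2) * exp(x**2/2) * erfc(x/sqrt(2)) with an asymptotic series for large x to avoid overflow (treat this as an assumption about the helper, not its actual source):

import numpy as np
from scipy.special import erfc

def erf_times_gauss(x):
    # sqrt(pi/2) * exp(x^2/2) * erfc(x/sqrt(2)); series expansion for large x.
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    small = x < 6
    out[small] = (np.sqrt(np.pi / 2) * np.exp(0.5 * x[small] ** 2)
                  * erfc(x[small] / np.sqrt(2)))
    xl = x[~small]
    out[~small] = 1 / xl - 1 / xl ** 3 + 3 / xl ** 5  # avoids exp overflow
    return out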
Example #10
def get_cross_derivatives_ReLU_plus(V_pos,
                                    psi_pos,
                                    hlayer,
                                    n_cv,
                                    weights=None):
    db_dw = average(V_pos, c=n_cv, weights=weights)

    a = hlayer.gamma[np.newaxis, :]
    b = hlayer.theta[np.newaxis, :]

    psi = psi_pos
    psi_plus = -(psi - b) / np.sqrt(a)

    Phi_plus = erf_times_gauss(psi_plus)

    e = (psi - b) + np.sqrt(a) / Phi_plus
    v = (psi_plus - 1 / Phi_plus) / Phi_plus

    dpsi_plus_dpsi = -1 / np.sqrt(a)
    dpsi_plus_da = -1.0 / (2 * a) * psi_plus

    de_dpsi = 1 + v
    de_db = -de_dpsi
    de_da = np.sqrt(a) * (1.0 / (2 * a * Phi_plus) -
                          (psi_plus - 1 / Phi_plus) / Phi_plus * dpsi_plus_da)

    dv_dpsi = dpsi_plus_dpsi * \
        (1 + psi_plus / Phi_plus - 1 / Phi_plus **
         2 - (psi_plus - 1 / Phi_plus)**2) / Phi_plus
    dv_db = -dv_dpsi
    dv_da = dpsi_plus_da * (1 + psi_plus / Phi_plus - 1 / Phi_plus**2 -
                            (psi_plus - 1 / Phi_plus)**2) / Phi_plus

    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)

    dmean_v_da = average(dv_da, weights=weights)
    dmean_v_db = average(dv_db, weights=weights)

    dvar_e_da = 2 * (
        average(e * de_da, weights=weights) -
        average(e, weights=weights) * average(de_da, weights=weights))
    dvar_e_db = 2 * (
        average(e * de_db, weights=weights) -
        average(e, weights=weights) * average(de_db, weights=weights))

    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    denominator = (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    denominator = np.maximum(denominator, 0.5)  # For numerical stability.

    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / denominator

    dmean_v_dw = average_product(dv_dpsi,
                                 V_pos,
                                 c1=1,
                                 c2=n_cv,
                                 weights=weights)

    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
            - average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]
                 ) / denominator[:, np.newaxis, np.newaxis]

    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis]) / denominator[:,
                                                                  np.newaxis]

    return db_dw, da_db, da_dw
Example #11
    def minibatch_fit(self, X, Y, weights=None):
        self.count_updates += 1
        grad = {}
        I = self.input_layer.compute_output(X, self.weights, direction='down')

        prediction = self.output_layer.mean_from_inputs(I)
        grad['weights'] = self.moments_XY - utilities.average_product(
            X,
            prediction,
            c1=self.n_cin,
            c2=self.n_cout,
            mean2=True,
            weights=weights)
        if self.nature in ['Bernoulli', 'Spin', 'Potts']:
            grad['output_layer'] = self.output_layer.internal_gradients(
                self.moments_Y,
                prediction,
                value='moments',
                value_neg='mean',
                weights_neg=weights)
        else:
            grad['output_layer'] = self.output_layer.internal_gradients(
                self.moments_Y,
                I,
                value='moments',
                value_neg='input',
                weights_neg=weights)

        for regtype, regtarget, regvalue in self.regularizers:
            if regtarget == 'weights':
                target_gradient = grad['weights']
                target = self.weights
            else:
                target_gradient = grad['output_layer'][regtarget]
                target = self.output_layer.__dict__[regtarget]
            if regtype == 'l1':
                target_gradient -= regvalue * np.sign(target)
            elif regtype == 'l2':
                target_gradient -= regvalue * target
            else:
                print(regtype, 'not supported')

        for key, gradient in grad['output_layer'].items():
            if self.output_layer.do_grad_updates[key]:
                if self.optimizer == 'SGD':
                    self.output_layer.__dict__[
                        key] += self.learning_rate * gradient
                elif self.optimizer == 'ADAM':
                    self.gradient_moment1['output_layer'][key] *= self.beta1
                    self.gradient_moment1['output_layer'][key] += (
                        1 - self.beta1) * gradient
                    self.gradient_moment2['output_layer'][key] *= self.beta2
                    self.gradient_moment2['output_layer'][key] += (
                        1 - self.beta2) * gradient**2

                    self.output_layer.__dict__[key] += self.learning_rate / (
                        1 - self.beta1
                    ) * (self.gradient_moment1['output_layer'][key] /
                         (1 - self.beta1**self.count_updates)) / (
                             self.epsilon + np.sqrt(
                                 self.gradient_moment2['output_layer'][key] /
                                 (1 - self.beta2**self.count_updates)))

        if self.optimizer == 'SGD':
            self.weights += self.learning_rate * grad['weights']
        elif self.optimizer == 'ADAM':
            self.gradient_moment1['weights'] *= self.beta1
            self.gradient_moment1['weights'] += (1 -
                                                 self.beta1) * grad['weights']
            self.gradient_moment2['weights'] *= self.beta2
            self.gradient_moment2['weights'] += (
                1 - self.beta2) * grad['weights']**2

            self.weights += self.learning_rate / (1 - self.beta1) * (
                self.gradient_moment1['weights'] /
                (1 - self.beta1**self.count_updates)
            ) / (self.epsilon + np.sqrt(self.gradient_moment2['weights'] /
                                        (1 - self.beta2**self.count_updates)))

        if self.symmetric:
            if self.n_cout > 1:
                self.weights += np.swapaxes(np.swapaxes(self.weights, 0, 1), 2,
                                            3)
                self.weights /= 2
            else:
                self.weights += self.weights.T
                self.weights /= 2
        if self.zero_diag:
            self.weights[np.arange(self.Nout), np.arange(self.Nout)] *= 0
        return
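
The ADAM branch above folds an extra 1/(1 - beta1) factor into the step size; for reference, a textbook bias-corrected ADAM ascent step looks like this (a sketch, not the repository's optimizer):

import numpy as np

def adam_step(param, grad, m1, m2, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    m1 = beta1 * m1 + (1 - beta1) * grad           # first-moment EMA
    m2 = beta2 * m2 + (1 - beta2) * grad ** 2      # second-moment EMA
    mhat = m1 / (1 - beta1 ** t)                   # bias corrections at step t
    vhat = m2 / (1 - beta2 ** t)
    param = param + lr * mhat / (eps + np.sqrt(vhat))  # ascent on the likelihood
    return param, m1, m2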
Example #12
    def fit(self,
            X,
            Y,
            weights=None,
            batch_size=100,
            learning_rate=None,
            lr_final=None,
            lr_decay=True,
            decay_after=0.5,
            extra_params=None,
            optimizer='ADAM',
            n_iter=10,
            verbose=1,
            regularizers=[]):

        self.batch_size = batch_size
        self.optimizer = optimizer
        self.n_iter = n_iter
        if self.n_iter <= 1:
            lr_decay = False
        if learning_rate is None:
            if self.optimizer == 'SGD':
                learning_rate = 0.01
            elif self.optimizer == 'ADAM':
                learning_rate = 5e-4
            else:
                print('Need to specify learning rate for optimizer.')
        if self.optimizer == 'ADAM':
            if extra_params is None:
                extra_params = [0.9, 0.99, 1e-3]
            self.beta1 = extra_params[0]
            self.beta2 = extra_params[1]
            self.epsilon = extra_params[2]
            if self.n_cout > 1:
                out0 = np.zeros([1, self.Nout, self.n_cout], dtype=curr_float)
            else:
                out0 = np.zeros([1, self.Nout], dtype=curr_float)

            grad = {
                'weights':
                np.zeros_like(self.weights),
                'output_layer':
                self.output_layer.internal_gradients(out0,
                                                     out0,
                                                     value='input',
                                                     value_neg='input')
            }

            for key in grad['output_layer'].keys():
                grad['output_layer'][key] *= 0

            self.gradient_moment1 = copy.deepcopy(grad)
            self.gradient_moment2 = copy.deepcopy(grad)

        self.learning_rate_init = copy.copy(learning_rate)
        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        if self.lr_decay:
            self.decay_after = decay_after
            self.start_decay = int(self.n_iter * self.decay_after)
            if lr_final is None:
                self.lr_final = 1e-2 * self.learning_rate
            else:
                self.lr_final = lr_final
            self.decay_gamma = (float(self.lr_final) /
                                float(self.learning_rate))**(
                                    1 / float(self.n_iter *
                                              (1 - self.decay_after)))
        else:
            self.decay_gamma = 1
        self.regularizers = regularizers

        n_samples = X.shape[0]
        n_batches = int(np.ceil(float(n_samples) / self.batch_size))
        batch_slices = list(
            utilities.gen_even_slices(n_batches * self.batch_size, n_batches,
                                      n_samples))

        X = np.asarray(X, dtype=self.input_layer.type, order='c')
        Y = np.asarray(Y, dtype=self.output_layer.type, order='c')
        if weights is not None:
            weights = weights.astype(curr_float)

        self.moments_Y = self.output_layer.get_moments(Y,
                                                       weights=weights,
                                                       value='data')
        self.moments_XY = utilities.average_product(X,
                                                    Y,
                                                    c1=self.n_cin,
                                                    c2=self.n_cout,
                                                    mean1=False,
                                                    mean2=False,
                                                    weights=weights)

        self.count_updates = 0

        for epoch in range(1, n_iter + 1):
            if verbose:
                begin = time.time()
            if self.lr_decay:
                if (epoch > self.start_decay):
                    self.learning_rate *= self.decay_gamma

            permutation = np.argsort(np.random.randn(n_samples))
            X = X[permutation, :]
            Y = Y[permutation, :]
            if weights is not None:
                weights = weights[permutation]

            if verbose:
                print('Starting epoch %s' % (epoch))
            for batch_slice in batch_slices:
                if weights is not None:
                    self.minibatch_fit(X[batch_slice],
                                       Y[batch_slice],
                                       weights=weights[batch_slice])
                else:
                    self.minibatch_fit(X[batch_slice],
                                       Y[batch_slice],
                                       weights=None)

            if verbose:
                end = time.time()
                lik = utilities.average(self.likelihood(X, Y), weights=weights)
                regularization = 0
                for regtype, regtarget, regvalue in self.regularizers:
                    if regtarget == 'weights':
                        target = self.weights
                    else:
                        target = self.output_layer.__dict__[regtarget]
                    if regtype == 'l1':
                        regularization += (regvalue * np.abs(target)).sum()
                    elif regtype == 'l2':
                        regularization += 0.5 * (regvalue * target**2).sum()
                    else:
                        print(regtype, 'not supported')
                regularization /= self.Nout
                message = "Iteration %d, time = %.2fs, likelihood = %.2f, regularization  = %.2e, loss = %.2f" % (
                    epoch, end - begin, lik, regularization,
                    -lik + regularization)
                print(message)
        return 'done'
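
The decay_gamma above is chosen so that multiplying the learning rate by it once per epoch over the decay phase lands exactly on lr_final. A quick check with hypothetical values:

import numpy as np

n_iter, decay_after = 10, 0.5
lr, lr_final = 5e-4, 5e-6
n_decay_epochs = n_iter * (1 - decay_after)
gamma = (lr_final / lr) ** (1 / n_decay_epochs)
assert np.isclose(lr * gamma ** n_decay_epochs, lr_final)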
Example #13
    def fit(self,
            data,
            batch_size=100,
            nchains=100,
            learning_rate=None,
            extra_params=None,
            init='independent',
            optimizer='SGD',
            N_PT=1,
            N_MC=1,
            n_iter=10,
            lr_decay=True,
            lr_final=None,
            decay_after=0.5,
            l1=0,
            l1b=0,
            l1c=0,
            l2=0,
            l2_fields=0,
            no_fields=False,
            batch_norm=False,
            update_betas=None,
            record_acceptance=None,
            epsilon=1e-6,
            verbose=1,
            record=[],
            record_interval=100,
            p=[1, 0, 0],
            pseudo_count=0,
            weights=None):

        self.nchains = nchains
        self.optimizer = optimizer
        self.record_swaps = False
        self.batch_norm = batch_norm
        self.layer.batch_norm = batch_norm

        self.n_iter = n_iter

        if learning_rate is None:
            if self.nature in ['Bernoulli', 'Spin', 'Potts']:
                learning_rate = 0.1
            else:
                learning_rate = 0.01

            if self.optimizer == 'ADAM':
                learning_rate *= 0.1

        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        if self.lr_decay:
            self.decay_after = decay_after
            self.start_decay = self.n_iter * self.decay_after
            if lr_final is None:
                self.lr_final = 1e-2 * self.learning_rate
            else:
                self.lr_final = lr_final
            self.decay_gamma = (float(self.lr_final) /
                                float(self.learning_rate))**(
                                    1 / float(self.n_iter *
                                              (1 - self.decay_after)))

        self.gradient = self.initialize_gradient_dictionary()

        if self.optimizer == 'momentum':
            if extra_params is None:
                extra_params = 0.9
            self.momentum = extra_params
            self.previous_update = self.initialize_gradient_dictionary()

        elif self.optimizer == 'ADAM':
            if extra_params is None:
                extra_params = [0.9, 0.999, 1e-8]
            self.beta1 = extra_params[0]
            self.beta2 = extra_params[1]
            self.epsilon = extra_params[2]

            self.gradient_moment1 = self.initialize_gradient_dictionary()
            self.gradient_moment2 = self.initialize_gradient_dictionary()

        if weights is not None:
            weights = np.asarray(weights, dtype=float)

        mean = utilities.average(data, c=self.n_c, weights=weights)
        covariance = utilities.average_product(data,
                                               data,
                                               c1=self.n_c,
                                               c2=self.n_c,
                                               weights=weights)
        if pseudo_count > 0:
            p = data.shape[0] / float(data.shape[0] + pseudo_count)
            covariance = p**2 * covariance + p * \
                (1 - p) * (mean[np.newaxis, :, np.newaxis, :] * mean[:,
                                                                     np.newaxis, :, np.newaxis]) / self.n_c + (1 - p)**2 / self.n_c**2
            mean = p * mean + (1 - p) / self.n_c

        iter_per_epoch = data.shape[0] // batch_size
        if init != 'previous':
            norm_init = 0
            self.init_couplings(norm_init)
            if init == 'independent':
                self.layer.init_params_from_data(data,
                                                 eps=epsilon,
                                                 value='data')

        self.N_PT = N_PT
        self.N_MC = N_MC

        self.l1 = l1
        self.l1b = l1b
        self.l1c = l1c
        self.l2 = l2
        self.tmp_l2_fields = l2_fields
        self.no_fields = no_fields

        if self.N_PT > 1:
            if record_acceptance is None:
                record_acceptance = True
            self.record_acceptance = record_acceptance

            if update_betas is None:
                update_betas = True

            self._update_betas = update_betas

            if self.record_acceptance:
                self.mavar_gamma = 0.95
                self.acceptance_rates = np.zeros(N_PT - 1)
                self.mav_acceptance_rates = np.zeros(N_PT - 1)
            self.count_swaps = 0

            if self._update_betas:
                record_acceptance = True
                self.update_betas_lr = 0.1
                self.update_betas_lr_decay = 1

            if self._update_betas | (not hasattr(self, 'betas')):
                self.betas = np.arange(N_PT) / float(N_PT - 1)
                self.betas = self.betas[::-1]
            if (len(self.betas) != N_PT):
                self.betas = np.arange(N_PT) / float(N_PT - 1)
                self.betas = self.betas[::-1]

        if self.nature == 'Potts':
            (self.fantasy_x,
             self.fantasy_fields_eff) = self.layer.sample_from_inputs(np.zeros(
                 [self.N_PT * self.nchains, self.N, self.n_c]),
                                                                      beta=0)
        else:
            (self.fantasy_x,
             self.fantasy_fields_eff) = self.layer.sample_from_inputs(np.zeros(
                 [self.N_PT * self.nchains, self.N]),
                                                                      beta=0)
        if self.N_PT > 1:
            self.fantasy_x = self.fantasy_x.reshape(
                [self.N_PT, self.nchains, self.N])
            if self.nature == 'Potts':
                self.fantasy_fields_eff = self.fantasy_fields_eff.reshape(
                    [self.N_PT, self.nchains, self.N, self.n_c])
            else:
                self.fantasy_fields_eff = self.fantasy_fields_eff.reshape(
                    [self.N_PT, self.nchains, self.N])
            self.fantasy_E = np.zeros([self.N_PT, self.nchains])

        self.count_updates = 0
        if verbose:
            if weights is not None:
                lik = (self.pseudo_likelihood(data) *
                       weights).sum() / weights.sum()
            else:
                lik = self.pseudo_likelihood(data).mean()
            print('Iteration number 0, pseudo-likelihood: %.2f' % lik)

        result = {}
        if 'J' in record:
            result['J'] = []
        if 'F' in record:
            result['F'] = []

        count = 0

        for epoch in range(1, n_iter + 1):
            if verbose:
                begin = time.time()
            if self.lr_decay:
                if (epoch > self.start_decay):
                    self.learning_rate *= self.decay_gamma

            if verbose:
                print('Starting epoch %s' % (epoch))
            for _ in range(iter_per_epoch):
                self.minibatch_fit(mean, covariance)

                if (count % record_interval == 0):
                    if 'J' in record:
                        result['J'].append(self.layer.couplings.copy())
                    if 'F' in record:
                        result['F'].append(self.layer.fields.copy())

                count += 1

            if verbose:
                end = time.time()
                if weights is not None:
                    lik = (self.pseudo_likelihood(data) *
                           weights).sum() / weights.sum()
                else:
                    lik = self.pseudo_likelihood(data).mean()

                print("[%s] Iteration %d, pseudo-likelihood = %.2f,"
                      " time = %.2fs" %
                      (type(self).__name__, epoch, lik, end - begin))

        return result
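
The pseudo_count block above shrinks the empirical moments toward the uniform distribution over colors, with a strength that vanishes as the sample size grows. The first-moment part in isolation (hypothetical numbers):

import numpy as np

n_samples, pseudo_count, n_c = 1000, 10.0, 21
p = n_samples / (n_samples + pseudo_count)
mean = np.full(n_c, 1.0 / n_c)          # stand-in empirical frequencies
mean_pc = p * mean + (1 - p) / n_c      # shrinkage toward uniform, as above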
Example #14
    def fit(self,
            data,
            weights=None,
            pseudo_count=1e-4,
            verbose=1,
            zero_diag=True):

        fi = utilities.average(data, c=self.n_c, weights=weights)
        fij = utilities.average_product(data,
                                        data,
                                        c1=self.n_c,
                                        c2=self.n_c,
                                        weights=weights)
        for i in range(self.N):
            fij[i, i] = np.diag(fi[i])

        fi_PC = (1 - pseudo_count) * fi + pseudo_count / float(self.n_c)
        fij_PC = (1 - pseudo_count) * fij + pseudo_count / float(self.n_c)**2

        for i in range(self.N):
            fij_PC[i, i] = np.diag(fi_PC[i])

        Cij = fij_PC - fi_PC[
            np.newaxis, :, np.newaxis, :] * fi_PC[:, np.newaxis, :, np.newaxis]

        D = np.zeros([self.N, self.n_c - 1, self.n_c - 1])
        invD = np.zeros([self.N, self.n_c - 1, self.n_c - 1])

        for n in range(self.N):
            D[n] = scipy.linalg.sqrtm(Cij[n, n, :-1, :-1])
            invD[n] = np.linalg.inv(D[n])

        Gamma = np.zeros([self.N, self.n_c - 1, self.N, self.n_c - 1])
        for n1 in range(self.N):
            for n2 in range(self.N):
                Gamma[n1, :,
                      n2, :] = np.dot(invD[n1],
                                      np.dot(Cij[n1, n2, :-1, :-1], invD[n2]))

        Gamma_bin = Gamma.reshape(
            [self.N * (self.n_c - 1), self.N * (self.n_c - 1)])
        Gamma_bin = (Gamma_bin + Gamma_bin.T) / 2
        lam, v = np.linalg.eigh(Gamma_bin)
        order = np.argsort(lam)[::-1]

        v_ordered = np.rollaxis(
            v.reshape([self.N, self.n_c - 1, self.N * (self.n_c - 1)]), 2,
            0)[order, :, :]
        lam_ordered = lam[order]
        DeltaL = 0.5 * (lam_ordered - 1 - np.log(lam_ordered))
        xi = np.zeros(v_ordered.shape)
        for n in range(self.N):
            xi[:, n, :] = np.dot(v_ordered[:, n, :], invD[n])
        xi = np.sqrt(np.abs(1 - 1 / lam_ordered))[:, np.newaxis,
                                                  np.newaxis] * xi

        xi = np.concatenate(
            (xi, np.zeros([self.N * (self.n_c - 1), self.N, 1])),
            axis=2)  # Write in zero-sum gauge.
        xi -= xi.mean(-1)[:, :, np.newaxis]
        top_M_contrib = np.argsort(DeltaL)[::-1][:self.M]

        self.xi = xi[top_M_contrib]
        self.lam = lam_ordered[top_M_contrib]
        self.DeltaL = DeltaL[top_M_contrib]

        couplings = np.tensordot(
            self.xi[self.lam > 1], self.xi[self.lam > 1], axes=[
                (0), (0)
            ]) - np.tensordot(
                self.xi[self.lam < 1], self.xi[self.lam < 1], axes=[(0), (0)])
        couplings = np.asarray(np.swapaxes(couplings, 1, 2), order='c')
        if zero_diag:  # Zeroing the diagonal works much better; kept as a sanity check.
            for n in range(self.N):
                couplings[n, n, :, :] *= 0

        fields = np.log(fi_PC) - np.tensordot(
            couplings, fi_PC, axes=[(1, 3), (0, 1)])
        fields -= fields.mean(-1)[:, np.newaxis]

        self.layer.couplings = couplings
        self.layer.fields = fields
        if verbose:
            fig, ax = plt.subplots()
            ax2 = ax.twinx()
            ax.plot(self.DeltaL)
            ax2.semilogy(self.lam, c='red')
            ax.set_ylabel(r'$\Delta L$', color='blue')
            ax2.set_ylabel('Mode variance', color='red')
            for tl in ax.get_yticklabels():
                tl.set_color('blue')
            for tl in ax2.get_yticklabels():
                tl.set_color('red')
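
A minimal usage sketch for this fit, assuming an instance of the class that defines it; the constructor name and arguments below are illustrative, not taken from the source:

import numpy as np

N, n_c = 50, 21
data = np.random.randint(0, n_c, size=(5000, N))  # one categorical sequence per row
weights = np.ones(len(data))                      # optional per-sample weights
# model = SomeLowRankPottsModel(N=N, n_c=n_c, M=100)  # hypothetical constructor
# model.fit(data, weights=weights, pseudo_count=1e-4, verbose=0)
# The learned parameters end up in model.layer.fields and model.layer.couplings.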
Example #15
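import numpy as np
# `average`, `average_product` and `erf_times_gauss` are assumed to come from
# the same `utilities` module used by the other examples, e.g.:
# from utilities import average, average_product, erf_times_gauss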
def get_cross_derivatives_dReLU(V_pos, psi_pos, hlayer, n_cv, weights=None):
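    # Returns db_dw together with the sensitivities of the effective scale
    # parameter `a` to (b, theta, eta, w): derivatives of the per-sample
    # quantities e and v are propagated through their weighted mean and
    # variance to obtain da_db, da_dtheta, da_deta and da_dw.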
    # a = 2.0/(1.0/hlayer.a_plus + 1.0/hlayer.a_minus)
    # eta = 0.5* (a/hlayer.a_plus - a/hlayer.a_minus)
    # theta = (1.-eta**2)/2. * (hlayer.theta_plus+hlayer.theta_minus)
    # b = (1.+eta)/2. * hlayer.theta_plus - (1.-eta)/2. * hlayer.theta_minus
    db_dw = average(V_pos, c=n_cv, weights=weights)
    a = hlayer.a[np.newaxis, :]
    eta = hlayer.eta[np.newaxis, :]
    theta = hlayer.theta[np.newaxis, :]
    b = hlayer.b[np.newaxis, :]

    psi = psi_pos

    psi_plus = (-np.sqrt(1 + eta) *
                (psi - b) + theta / np.sqrt(1 + eta)) / np.sqrt(a)
    psi_minus = (np.sqrt(1 - eta) *
                 (psi - b) + theta / np.sqrt(1 - eta)) / np.sqrt(a)

    Phi_plus = erf_times_gauss(psi_plus)
    Phi_minus = erf_times_gauss(psi_minus)

    Z = Phi_plus * np.sqrt(1 + eta) + Phi_minus * np.sqrt(1 - eta)

    p_plus = 1 / (1 + (Phi_minus * np.sqrt(1 - eta)) /
                  (Phi_plus * np.sqrt(1 + eta)))
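    # Guard against 0/0 when both Phi_plus and Phi_minus underflow: assign
    # the probability to the branch whose argument dominates.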
    nans = np.isnan(p_plus)
    p_plus[nans] = 1.0 * (np.abs(psi_plus[nans]) > np.abs(psi_minus[nans]))
    p_minus = 1 - p_plus

    e = (psi - b) * (1 + eta * (p_plus - p_minus)) - theta * (
        p_plus - p_minus) + 2 * eta * np.sqrt(a) / Z
    v = eta * (p_plus - p_minus) + p_plus * p_minus * (
        2 * theta / np.sqrt(a) - 2 * eta * (psi - b) / np.sqrt(a)) * (
            2 * theta / np.sqrt(a) - 2 * eta *
            (psi - b) / np.sqrt(a) - np.sqrt(1 + eta) / Phi_plus -
            np.sqrt(1 - eta) / Phi_minus) - 2 * eta * e / (np.sqrt(a) * Z)

    dpsi_plus_dpsi = -np.sqrt((1 + eta) / a)
    dpsi_minus_dpsi = np.sqrt((1 - eta) / a)
    dpsi_plus_dtheta = 1 / np.sqrt(a * (1 + eta))
    dpsi_minus_dtheta = 1 / np.sqrt(a * (1 - eta))
    # dpsi_plus_da = -1.0/(2*a) * psi_plus
    # dpsi_minus_da = -1.0/(2*a) * psi_minus

    dpsi_plus_deta = -1.0 / (2 * np.sqrt(a * (1 + eta))) * ((psi - b) + theta /
                                                            (1 + eta))
    dpsi_minus_deta = -1.0 / (2 * np.sqrt(a *
                                          (1 - eta))) * ((psi - b) - theta /
                                                         (1 - eta))

    # d2psi_plus_dadpsi = 0.5 * np.sqrt((1+eta)/a**3 )
    d2psi_plus_dthetadpsi = 0
    d2psi_plus_detadpsi = -0.5 / np.sqrt((1 + eta) * a)
    # d2psi_minus_dadpsi = -0.5 * np.sqrt((1-eta)/a**3 )
    d2psi_minus_dthetadpsi = 0
    d2psi_minus_detadpsi = -0.5 / np.sqrt((1 - eta) * a)

    dp_plus_dpsi = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dp_plus_dtheta = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dp_plus_da = p_plus * p_minus * ( (psi_plus-1/Phi_plus) * dpsi_plus_da  - (psi_minus-1/Phi_minus) * dpsi_minus_da )
    dp_plus_deta = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_deta -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_deta + 1 / (1 - eta**2))

    d2p_plus_dpsi2 = (
        -(p_plus - p_minus) * p_plus * p_minus *
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)**2
        + p_plus * p_minus *
        (dpsi_plus_dpsi**2 * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         dpsi_minus_dpsi**2 * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus)))

    # d2p_plus_dadpsi = -(p_plus-p_minus) * ( (psi_plus-1/Phi_plus) * dpsi_plus_dpsi  - (psi_minus-1/Phi_minus) * dpsi_minus_dpsi ) * (dp_plus_da)\
    # + p_plus * p_minus * ( (dpsi_plus_dpsi* dpsi_plus_da) *  (1+ (psi_plus-1/Phi_plus)/Phi_plus) - (dpsi_minus_dpsi *dpsi_minus_da) * (1+ (psi_minus-1/Phi_minus)/Phi_minus) \
    # + (d2psi_plus_dadpsi) * (psi_plus-1/Phi_plus) - (d2psi_minus_dadpsi) * (psi_minus-1/Phi_minus) )

    d2p_plus_dthetadpsi = (
        -(p_plus - p_minus) * dp_plus_dtheta *
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
        + p_plus * p_minus *
        (dpsi_plus_dpsi * dpsi_plus_dtheta *
         (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         dpsi_minus_dpsi * dpsi_minus_dtheta *
         (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
         d2psi_plus_dthetadpsi * (psi_plus - 1 / Phi_plus) -
         d2psi_minus_dthetadpsi * (psi_minus - 1 / Phi_minus)))

    d2p_plus_detadpsi = (
        -(p_plus - p_minus) * dp_plus_deta *
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
        + p_plus * p_minus *
        (dpsi_plus_dpsi * dpsi_plus_deta *
         (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         dpsi_minus_dpsi * dpsi_minus_deta *
         (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
         d2psi_plus_detadpsi * (psi_plus - 1 / Phi_plus) -
         d2psi_minus_detadpsi * (psi_minus - 1 / Phi_minus)))

    dlogZ_dpsi = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi +
                  p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dlogZ_dtheta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta +
                    p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dlogZ_da = (p_plus * (psi_plus-1/Phi_plus)* dpsi_plus_da + p_minus * (psi_minus-1/Phi_minus) * dpsi_minus_da )
    dlogZ_deta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_deta +
                  p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_deta +
                  0.5 * (p_plus / (1 + eta) - p_minus / (1 - eta)))

    de_dpsi = (1 + v)
    de_db = -de_dpsi
    # de_da = 2*((psi-b) * eta - theta) * dp_plus_da + eta/(Z*np.sqrt(a)) - 2*eta*np.sqrt(a)/Z * dlogZ_da
    de_dtheta = -(p_plus - p_minus) + 2 * (
        (psi - b) * eta -
        theta) * dp_plus_dtheta - 2 * eta * np.sqrt(a) / Z * dlogZ_dtheta
    de_deta = (psi - b) * (p_plus - p_minus) + 2 * (
        (psi - b) * eta - theta) * dp_plus_deta + 2 * np.sqrt(
            a) / Z - 2 * eta * np.sqrt(a) / Z * dlogZ_deta

    dv_dpsi = (4 * eta * dp_plus_dpsi
               + 2 * ((psi - b) * eta - theta) * d2p_plus_dpsi2
               - 2 * eta / (np.sqrt(a) * Z) * (de_dpsi - e * dlogZ_dpsi))

    dv_db = -dv_dpsi

    # dv_da = eta * 2 * dp_plus_da \
    # + 2 * ((psi-b)*eta - theta) * d2p_plus_dadpsi \
    # -2 * eta/(Z * np.sqrt(a)) * ( -e/(2*a) - e*dlogZ_da + de_da )

    dv_dtheta = (2 * eta * dp_plus_dtheta
                 - 2 * dp_plus_dpsi
                 + 2 * ((psi - b) * eta - theta) * d2p_plus_dthetadpsi
                 - 2 * eta / (Z * np.sqrt(a)) * (de_dtheta - e * dlogZ_dtheta))

    dv_deta = ((p_plus - p_minus)
               + 2 * eta * dp_plus_deta
               + 2 * (psi - b) * dp_plus_dpsi
               + 2 * ((psi - b) * eta - theta) * d2p_plus_detadpsi
               - 2 / (Z * np.sqrt(a)) * (e - e * eta * dlogZ_deta + eta * de_deta))

    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)

    # dmean_v_da = average(dv_da,weights=weights)
    dmean_v_db = average(dv_db, weights=weights)
    dmean_v_dtheta = average(dv_dtheta, weights=weights)
    dmean_v_deta = average(dv_deta, weights=weights)

    # dvar_e_da = 2* (average(e*de_da,weights=weights) -average(e,weights=weights) * average(de_da,weights=weights) )
    dvar_e_db = 2 * (
        average(e * de_db, weights=weights) -
        average(e, weights=weights) * average(de_db, weights=weights))
    dvar_e_dtheta = 2 * (
        average(e * de_dtheta, weights=weights) -
        average(e, weights=weights) * average(de_dtheta, weights=weights))
    dvar_e_deta = 2 * (
        average(e * de_deta, weights=weights) -
        average(e, weights=weights) * average(de_deta, weights=weights))

    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    denominator = tmp
    # denominator = (tmp - dvar_e_da- 0.5 * dmean_v_da * (1+mean_v+tmp))
    # denominator = np.maximum( denominator, 0.5) # For numerical stability.

    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / denominator
    da_dtheta = (dvar_e_dtheta + 0.5 * dmean_v_dtheta *
                 (1 + mean_v + tmp)) / denominator
    da_deta = (dvar_e_deta + 0.5 * dmean_v_deta *
               (1 + mean_v + tmp)) / denominator

    dmean_v_dw = average_product(dv_dpsi,
                                 V_pos,
                                 c1=1,
                                 c2=n_cv,
                                 weights=weights)

    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
            - average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]
                 ) / denominator[:, np.newaxis, np.newaxis]

    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis]) / denominator[:,
                                                                  np.newaxis]

    return db_dw, da_db, da_dtheta, da_deta, da_dw
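
A minimal smoke-test sketch; `SimpleNamespace` stands in for the real hidden-layer object, and the shapes and values below are illustrative only:

import numpy as np
from types import SimpleNamespace

n_samples, n_h, n_v = 200, 10, 5
hlayer = SimpleNamespace(a=np.ones(n_h), eta=np.zeros(n_h),
                         theta=np.zeros(n_h), b=np.zeros(n_h))
V_pos = np.random.randn(n_samples, n_v)    # visible configurations (n_cv == 1)
psi_pos = np.random.randn(n_samples, n_h)  # hidden-unit inputs
# With the utilities helpers importable, the call would be:
# db_dw, da_db, da_dtheta, da_deta, da_dw = get_cross_derivatives_dReLU(
#     V_pos, psi_pos, hlayer, n_cv=1)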