Exemplo n.º 1
0
    def M_step(self, anneal, model_params, my_suff_stat, my_data):
        """Discrete Sparse Coding M-Step

        This function is responsible for finding the optimal model parameters
        given an approximation of the posterior distribution.

        Parameters
        ----------
        anneal : Annealing object
            Annealing type object containing training schedule information
                anneal['T'] :           Temperature for det. annealing
                anneal['N_cut_factor']: 0. no truncation; 1. trunc. according to model
                anneal.crit_params:     Python expressions evaluated in this
                                        method's scope and passed to
                                        anneal.dyn_param
        model_params : dict
            dictionary containing model parameters
                model_params['W']:     (D,H) ndarray
                    linear dictionary (the method works on its (H,D) transpose)
                model_params['pi']:    (K,) ndarray
                    prior parameters
                model_params['sigma']:  float
                    standard deviation of noise model
        my_suff_stat : dict
            dictionary containing information about the joint distribution
                my_suff_stat['logpj']:  (my_N,no_states) ndarray
                    logarithm of joint of data and latent variable states
        my_data : dict
            data dictionary
                my_data['y']:     (my_N,D) ndarray
                    datapoints
                my_data['candidates']: (my_N,Hprime)
                    Candidate H's according to selection func.

        Returns
        -------
        dict
            dictionary containing updated model parameters
                dict['W']:     (D,H) ndarray
                    linear dictionary
                dict['pi']:    (K,) ndarray
                    prior parameters
                dict['sigma']:  float
                    standard deviation of noise model
                dict['Q']:      float
                    always 0. (not computed by this model)

        Parameters not listed in ``self.to_learn`` are returned unchanged.
        """
        comm = self.comm
        H, K = self.H, self.K
        W = model_params['W'].T  # working copy, shape (H, D)
        pi = model_params['pi']
        sigma = model_params['sigma']

        # Read in data:
        my_y = my_data['y'].copy()
        candidates = my_data['candidates']
        logpj_all = my_suff_stat['logpj']
        all_denoms = np.exp(logpj_all).sum(axis=1)
        my_N, D = my_y.shape
        N = comm.allreduce(my_N)

        A_pi_gamma = self.get_scaling_factors(pi)
        dlog.append("prior_mass", A_pi_gamma)

        # Truncate data
        N_use, my_y, candidates, logpj_all = self._get_sorted_data(
            N, anneal, A_pi_gamma, all_denoms, candidates, logpj_all, my_y)
        my_N, D = my_y.shape  # update my_N after truncation

        # Posterior weights shifted by the per-row max so exp() cannot
        # overflow; the shift cancels when dividing by pjb.sum() below.
        corr_all = logpj_all.max(axis=1)  # shape: (my_N,)
        pjb_all = np.exp(logpj_all - corr_all[:, None])  # (my_N, no_states)

        # Log-Likelihood:
        L = self.get_likelihood(D, sigma, logpj_all, N_use)
        dlog.append('L', L)

        SM = self.state_matrix
        SSM = self.single_state_matrix
        # Column layout of pjb: index 0 is the state with no active cause,
        # then (K - 1) * H single-cause states, then the multi-cause states
        # that are paired with SM.
        single = slice(1, (K - 1) * H + 1)
        multi = slice((K - 1) * H + 1, None)

        # Allocate
        my_Wp = np.zeros_like(W)  # shape (H, D)
        my_Wq = np.zeros((H, H))  # shape (H, H)
        my_pi = np.zeros_like(pi)  # shape (K,)
        my_sigma = 0.0

        # Iterate over all datapoints
        for n in range(my_N):
            y = my_y[n, :]  # length D
            cand = candidates[n, :]  # length Hprime
            pjb = pjb_all[n, :]
            pjb_single = pjb[single]

            this_Wp = np.zeros_like(my_Wp)  # numerator for W, (H, D)
            this_Wq = np.zeros_like(my_Wq)  # denominator for W, (H, H)
            this_pi = np.zeros_like(pi)  # numerator for pi update

            # Hidden state with 0 active causes
            this_pi[self._K_0] = H * pjb[0]
            this_sigma = pjb[0] * (y**2).sum()

            # Hidden states with 1 active cause
            # FIX (original author): "I am sure I need to multiply with pi
            # somewhere here" -- still unresolved.
            c = 0
            for state in range(K):
                if state == self._K_0:
                    continue
                sspjb = pjb[c * H + 1:(c + 1) * H + 1]
                this_pi[state] += sspjb.sum()

                recons = self.states[state] * W
                sqe = ((recons - y)**2).sum(1)
                this_sigma += (sspjb * sqe).sum()

                c += 1
            this_pi[self._K_0] += ((H - 1) * pjb_single).sum()
            this_Wp += np.dot(np.outer(y, pjb_single), SSM).T
            this_Wq += np.dot(pjb_single * SSM.T, SSM)

            if self.gamma > 1:
                # Hidden states with more than 1 active cause; only the
                # candidate rows/columns are touched.
                pjb_multi = pjb[multi]
                this_Wp[cand] += np.dot(np.outer(y, pjb_multi), SM).T
                this_Wq_tmp = np.zeros_like(my_Wq[cand])
                this_Wq_tmp[:, cand] = np.dot(pjb_multi * SM.T, SM)
                this_Wq[cand] += this_Wq_tmp

                this_pi += np.inner(pjb_multi, self.state_abs)

                Wbar = np.dot(SM, W[cand])  # W[cand] is (Hprime, D)
                this_sigma += (pjb_multi * ((Wbar - y)**2).sum(axis=1)).sum()

            # Scale down by this datapoint's (shifted) posterior mass
            denom = pjb.sum()
            my_Wp += this_Wp / denom
            my_Wq += this_Wq / denom
            my_pi += this_pi / denom
            my_sigma += this_sigma / denom / D

        # Calculate updated W: least-squares solution of Wq @ W_new = Wp
        Wp = np.empty_like(my_Wp)
        Wq = np.empty_like(my_Wq)
        comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
        comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])
        # NumPy >= 1.14 accepts rcond=None (the new default); older releases
        # need -1. Compare parsed versions: slicing the version string breaks
        # on NumPy 2.x and on pre-release tags such as "1.26.0rc1".
        rcond = None if np.lib.NumpyVersion(np.__version__) >= '1.14.0' else -1
        W_new = np.linalg.lstsq(Wq, Wp, rcond=rcond)[0]

        # Calculate updated pi: globally summed responsibilities, normalized.
        # Two collectives instead of two per prior component.
        pi_total = comm.allreduce(my_pi.sum())
        pi_new = np.empty_like(pi)
        pi_new[:] = comm.allreduce(my_pi) / pi_total

        # Keep every component at least eps, taking exactly the added mass
        # from the remaining components so pi_new still sums to 1.
        eps = 1e-6
        which_lo = pi_new < eps
        if np.any(which_lo):
            which_hi = pi_new >= eps
            deficit = (eps - pi_new[which_lo]).sum()
            pi_new[which_lo] = eps
            pi_new[which_hi] -= deficit / np.sum(which_hi)

        # Optional instance attribute acting as a floor on pi_new[self._K_0]
        if 'penalty' in self.__dict__:
            if self.penalty > pi_new[self._K_0]:
                r = (1 - self.penalty) / (1 - pi_new[self._K_0])
                pi_new[pi_new != 0] = pi_new[pi_new != 0] * r
                pi_new[self._K_0] = self.penalty
                pi_new /= pi_new.sum()

        # Calculate updated sigma
        sigma_new = np.sqrt(comm.allreduce(my_sigma) / N_use)

        if 'W' not in self.to_learn:
            W_new = W
        if 'pi' not in self.to_learn:
            pi_new = pi
        if 'sigma' not in self.to_learn:
            sigma_new = sigma

        for param in anneal.crit_params:
            # Each entry is an expression over this method's locals. In
            # Python 3 exec() cannot bind a new function local, so the former
            # exec('this_param = ...') left this_param undefined; eval()
            # returns the value directly.
            # NOTE(review): eval() runs arbitrary code -- crit_params must come
            # from trusted configuration.
            this_param = eval(param, globals(), locals())
            anneal.dyn_param(param, this_param)

        dlog.append('N_use', N_use)

        return {
            'W': W_new.transpose(),
            'pi': pi_new,
            'sigma': sigma_new,
            'Q': 0.
        }
Exemplo n.º 2
0
    def M_step(self, anneal, model_params, my_suff_stat, my_data):
        """ BSC M_step

        my_data variables used:

            my_data['y']           Datapoints
            my_data['candidates']  Candidate H's according to selection func.
            my_data['data_noise']  Added to the data when anneal['data_noise'] > 0

        Annealing variables used:

            anneal['Ncut_factor']  0.: no truncation; 1. trunc. according to model
            anneal['data_noise']   Data noise scale; > 0 enables the noise
            anneal.crit_params     Expressions evaluated in this method's scope
                                   and passed to anneal.dyn_param

        Returns a dict with the updated 'W', 'pi', 'sigma' and 'mu'; parameters
        not listed in self.to_learn are returned unchanged.
        """
        comm = self.comm
        H = self.H
        gamma = self.gamma
        W = model_params['W'].T  # working copy, shape (H, D)
        pies = model_params['pi']
        sigma = model_params['sigma']
        mu = model_params['mu']

        # Read in data:
        my_y = my_data['y'].copy()
        candidates = my_data['candidates']
        logpj_all = my_suff_stat['logpj']
        all_denoms = np.exp(logpj_all).sum(axis=1)

        my_N, D = my_y.shape
        N = comm.allreduce(my_N)

        # Joerg's data noise idea
        data_noise_scale = anneal['data_noise']
        if data_noise_scale > 0:
            my_y += my_data['data_noise']

        SM = self.state_matrix  # shape: (no_states, Hprime)

        # Precompute factor for pi update. A_pi_gamma is the binomial prior
        # mass of states with at most gamma active causes; B_pi_gamma weights
        # each term by its number of active causes.
        A_pi_gamma = 0
        B_pi_gamma = 0
        for gamma_p in range(gamma + 1):
            A_pi_gamma += comb(H, gamma_p) * (pies**gamma_p) * (
                (1 - pies)**(H - gamma_p))
            B_pi_gamma += gamma_p * comb(H, gamma_p) * (pies**gamma_p) * (
                (1 - pies)**(H - gamma_p))
        E_pi_gamma = pies * H * A_pi_gamma / B_pi_gamma

        # Truncate data
        if anneal['Ncut_factor'] > 0.0:
            tracing.tracepoint("M_step:truncating")
            N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
            cut_denom = parallel.allsort(all_denoms)[-N_use]
            which = np.array(all_denoms >= cut_denom)
            candidates = candidates[which]
            logpj_all = logpj_all[which]
            my_y = my_y[which]
            my_N, D = my_y.shape
            N_use = comm.allreduce(my_N)
        else:
            N_use = N
        dlog.append('N', N_use)

        # Posterior weights shifted by the per-row max so exp() cannot
        # overflow; the shift cancels when dividing by pjb.sum() below.
        corr_all = logpj_all.max(axis=1)  # shape: (my_N,)
        pjb_all = np.exp(logpj_all - corr_all[:, None])  # (my_N, no_states)

        # Calculate truncated Likelihood
        L = H * np.log(1 - pies) - 0.5 * D * np.log(
            2 * np.pi * sigma**2) - np.log(A_pi_gamma)
        # Per-datapoint log(sum(exp(logpj))) via log-sum-exp; a direct exp()
        # underflows to 0 (log -> -inf) for strongly negative log-joints.
        Fs = (np.log(pjb_all.sum(axis=1)) + corr_all).sum()
        L += comm.allreduce(Fs) / N_use
        dlog.append('L', L)

        # Allocate
        my_Wp = np.zeros_like(W)  # shape (H, D)
        my_Wq = np.zeros((H, H))  # shape (H, H)
        my_pi = 0.0
        my_sigma = 0.0
        my_mus = np.zeros(H)  # shape (H,)
        data_sum = my_y.sum(axis=0)  # shape (D,); local (truncated) data only

        # Iterate over all datapoints
        tracing.tracepoint("M_step:iterating")
        for n in range(my_N):
            y = my_y[n, :] - mu  # length D
            cand = candidates[n, :]  # length Hprime
            pjb = pjb_all[n, :]
            pjb_one = pjb[1:(H + 1)]  # states with exactly one active cause
            pjb_multi = pjb[(1 + H):]  # states paired with the rows of SM

            # The zero-active-cause state (pjb[0]) contributes nothing to the
            # W, pi or mu statistics.

            # One active hidden cause
            this_Wp = np.outer(pjb_one, y)  # (H, D)
            this_Wq = pjb_one * np.identity(H)  # (H, H), diagonal
            this_pi = pjb_one.sum()
            this_mus = pjb_one.copy()

            # Handle hidden states with more than 1 active cause
            this_Wp[cand] += np.dot(np.outer(y, pjb_multi), SM).T
            this_Wq_tmp = np.zeros_like(my_Wq[cand])
            this_Wq_tmp[:, cand] = np.dot(pjb_multi * SM.T, SM)
            this_Wq[cand] += this_Wq_tmp
            this_pi += np.inner(pjb_multi, SM.sum(axis=1))
            this_mus[cand] += np.inner(SM.T, pjb_multi)

            denom = pjb.sum()
            my_Wp += this_Wp / denom
            my_Wq += this_Wq / denom
            my_pi += this_pi / denom
            my_mus += this_mus / denom

        # Calculate updated W: least-squares solution of Wq @ W_new = Wp
        if 'W' in self.to_learn:
            tracing.tracepoint("M_step:update W")
            Wp = np.empty_like(my_Wp)
            Wq = np.empty_like(my_Wq)
            comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
            comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])
            # NumPy >= 1.14 accepts rcond=None (the new default); older
            # releases need -1. Compare parsed versions: slicing the version
            # string breaks on NumPy 2.x and on pre-release tags.
            rcond = (None if np.lib.NumpyVersion(np.__version__) >= '1.14.0'
                     else -1)
            W_new = np.linalg.lstsq(Wq, Wp, rcond=rcond)[0]
        else:
            W_new = W

        # Calculate updated pi
        if 'pi' in self.to_learn:
            tracing.tracepoint("M_step:update pi")
            pi_new = E_pi_gamma * comm.allreduce(my_pi) / H / N_use
        else:
            pi_new = pies

        # Calculate updated sigma
        if 'sigma' in self.to_learn:
            tracing.tracepoint("M_step:update sigma")
            for n in range(my_N):
                y = my_y[n, :] - mu  # length D
                cand = candidates[n, :]  # length Hprime
                pjb = pjb_all[n, :]  # == exp(logpj - logpj.max()) for this row

                # Zero active hidden causes
                this_sigma = pjb[0] * (y**2).sum()

                # Hidden states with one active cause
                this_sigma += (pjb[1:(H + 1)] * ((W - y)**2).sum(axis=1)).sum()

                # Handle hidden states with more than 1 active cause
                Wbar = np.dot(SM, W[cand])  # W[cand] is (Hprime, D)
                this_sigma += (pjb[(H + 1):] *
                               ((Wbar - y)**2).sum(axis=1)).sum()

                my_sigma += this_sigma / pjb.sum()

            sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
        else:
            sigma_new = sigma

        # Calculate updated mu:
        if 'mu' in self.to_learn:
            tracing.tracepoint("M_step:update mu")
            mus = np.empty_like(my_mus)
            all_data_sum = np.empty_like(data_sum)
            comm.Allreduce([my_mus, MPI.DOUBLE], [mus, MPI.DOUBLE])
            comm.Allreduce([data_sum, MPI.DOUBLE], [all_data_sum, MPI.DOUBLE])
            # Both sums are reduced over all ranks, so normalize by the global
            # number of used datapoints (N_use), not the local my_N.
            mu_new = all_data_sum / N_use - np.inner(W_new.T / N_use, mus)
        else:
            mu_new = mu

        for param in anneal.crit_params:
            # Each entry is an expression over this method's locals. In
            # Python 3 exec() cannot bind a new function local, so the former
            # exec('this_param = ...') left this_param undefined; eval()
            # returns the value directly.
            # NOTE(review): eval() runs arbitrary code -- crit_params must come
            # from trusted configuration.
            this_param = eval(param, globals(), locals())
            anneal.dyn_param(param, this_param)

        dlog.append('N_use', N_use)

        return {'W': W_new.T, 'pi': pi_new, 'sigma': sigma_new, 'mu': mu_new}
Exemplo n.º 3
0
    def M_step(self, anneal, model_params, my_suff_stat, my_data):
        """ MCA M_step

        my_data variables used:

            my_data['y']           Datapoints
            my_data['candidates']  Candidate H's according to selection func.

        Annealing variables used:

            anneal['T']            Temperature for det. annealing AND softmax
            anneal['Ncut_factor']  0.: no truncation; 1. trunc. according to model

        Returns a dict with the updated 'W', 'pi', 'sigma' and 'Q', where 'Q'
        is the ET approximation of the log-likelihood (always computed).
        Parameters not listed in self.to_learn are returned unchanged.
        """
        comm      = self.comm
        H         = self.H
        W         = model_params['W'].T
        pies      = model_params['pi']
        sigma     = model_params['sigma']

        # Read in data:
        my_y      = my_data['y']
        my_cand   = my_data['candidates']
        my_logpj  = my_suff_stat['logpj']
        my_N, D   = my_y.shape
        N         = comm.allreduce(my_N)

        state_mtx = self.state_matrix        # shape: (no_states, Hprime)
        state_abs = self.state_abs           # shape: (no_states,)

        # Accumulates log(sum_s exp(logpj)) over the used datapoints (et_loglike)
        my_ldenom_sum = 0.0

        # Precompute
        T        = anneal['T']
        T_rho    = np.maximum(T, self.rho_temp_bound)
        rho      = 1./(1.-1./T_rho)
        beta     = 1./T
        pil_bar  = np.log( pies/(1.-pies) )
        Wl       = np.log(W)
        Wrho     = np.exp(rho * Wl)
        Wsquared = W*W

        # Sanity checks (pil_bar is finite only for 0 < pi < 1, Wl only for W > 0)
        assert np.isfinite(pil_bar).all()
        assert np.isfinite(Wl).all()
        assert np.isfinite(Wrho).all()
        assert (Wrho > 1e-86).all()

        # Tempered (beta = 1/T) posterior weights, shifted by the per-row max
        my_corr  = beta*((my_logpj).max(axis=1))            # shape: (my_N,)
        my_pjb   = np.exp(beta*my_logpj - my_corr[:, None]) # shape: (my_N, no_states)

        # Precompute factor for pi/gamma update
        A_pi_gamma = 0.; B_pi_gamma = 0.
        for gp in range(0, self.gamma+1):
            a = comb(H, gp, exact=1) * pies**gp * (1.-pies)**(H-gp)
            A_pi_gamma += a
            B_pi_gamma += gp * a

        # Truncate data
        if anneal['Ncut_factor'] > 0.0:
            tracing.tracepoint("M_step:truncating")
            my_denoms = np.log(my_pjb.sum(axis=1)) + my_corr
            N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))

            cut_denom = parallel.allsort(my_denoms)[-N_use]
            which     = np.array(my_denoms >= cut_denom)

            my_y     = my_y[which]
            my_cand  = my_cand[which]
            my_logpj = my_logpj[which]
            my_pjb   = my_pjb[which]
            my_corr  = my_corr[which]
            my_N, D  = my_y.shape
            N_use    = comm.allreduce(my_N)
        else:
            N_use = N
        dlog.append('N_use', N_use)

        # Allocate suff-stat arrays
        my_Wp    = np.zeros_like(W)  # shape (H, D)
        my_Wq    = np.zeros_like(W)  # shape (H, D)
        my_pi    = 0.0
        my_sigma = 0.0

        # Iterate over all datapoints
        for n in range(my_N):
            tracing.tracepoint("M_step:iterating")
            y     = my_y[n,:]             # shape (D,)
            cand  = my_cand[n,:]          # shape (Hprime,)
            logpj = my_logpj[n,:]         # shape (no_states,)
            pjb   = my_pjb[n,:]           # shape (no_states,)
            corr  = my_corr[n]            # scalar

            this_Wp    = np.zeros_like(W) # numerator for W (current datapoint)   (H, D)
            this_Wq    = np.zeros_like(W) # denominator for W (current datapoint) (H, D)
            this_pi    = 0.0              # numerator for pi update (current datapoint)
            this_sigma = 0.0              # numerator for sigma update (current datapoint)

            # Zero active hidden causes: contributes to sigma only
            this_sigma += pjb[0] * (y**2).sum()

            # One active hidden cause
            this_Wp    += (pjb[1:(H+1),None] * Wsquared[:,:]) * y[None, :]
            this_Wq    += (pjb[1:(H+1),None] * Wsquared[:,:])
            this_pi    += pjb[1:(H+1)].sum()
            this_sigma += (pjb[1:(H+1)] * ((W-y)**2).sum(axis=1)).sum()

            # Handle hidden states with more than 1 active cause
            W_    = W[cand]                                    # is (Hprime, D)
            Wl_   = Wl[cand]                                   # is (   "    ")
            Wrho_ = Wrho[cand]                                 # is (   "    ")

            Wlrhom1 = (rho-1)*Wl_                              # is (Hprime, D)
            Wlbar   = np.log(np.dot(state_mtx,Wrho_)) / rho    # is (no_states, D)
            Wbar    = np.exp(Wlbar)                            # is (no_states, D)
            blpj    = beta*logpj[1+H:] - corr                  # is (no_states,)

            Aid  = (state_mtx[:,:, None] * np.exp(blpj[:,None,None] + (1-rho)*Wlbar[:, None, :] + Wlrhom1[None, :, :])).sum(axis=0)

            assert np.isfinite(Wlbar).all()
            assert np.isfinite(Wbar).all()
            assert np.isfinite(pjb).all()
            assert np.isfinite(Aid).all()

            this_Wp[cand] += Aid * y[None, :]
            this_Wq[cand] += Aid
            this_pi       += (pjb[1+H:] * state_abs).sum()
            this_sigma    += (pjb[1+H:] * ((Wbar-y)**2).sum(axis=1)).sum()

            denom     = pjb.sum()
            my_Wp    += this_Wp / denom
            my_Wq    += this_Wq / denom
            my_pi    += this_pi / denom
            my_sigma += this_sigma / denom

            # log(sum(exp(logpj))) via log-sum-exp: a direct exp() underflows
            # to 0 (giving -inf) when every log-joint is strongly negative.
            lmax = logpj.max()
            my_ldenom_sum += lmax + np.log(np.exp(logpj - lmax).sum())


        # Calculate updated W
        if 'W' in self.to_learn:
            tracing.tracepoint("M_step:update W")

            Wp = np.empty_like(my_Wp)
            Wq = np.empty_like(my_Wq)

            assert np.isfinite(my_Wp).all()
            assert np.isfinite(my_Wq).all()

            comm.Allreduce( [my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE] )
            comm.Allreduce( [my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE] )

            # Make sure we do not divide by zero
            tiny = np.finfo(Wq.dtype).tiny
            Wp[Wq < tiny] = 0.
            Wq[Wq < tiny] = tiny

            W_new = (Wp / Wq).T
        else:
            W_new = W.T

        # Calculate updated pi
        if 'pi' in self.to_learn:
            tracing.tracepoint("M_step:update pi")

            assert np.isfinite(my_pi).all()
            pi_new = A_pi_gamma / B_pi_gamma * pies * comm.allreduce(my_pi) / N_use
        else:
            pi_new = pies

        # Calculate updated sigma
        if 'sigma' in self.to_learn:               # TODO: XXX see LinCA XXX (merge!)
            tracing.tracepoint("M_step:update sigma")

            assert np.isfinite(my_sigma).all()
            sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
        else:
            sigma_new = sigma

        # Put all together and compute (always) et_approx_likelihood.
        # np.pi is used explicitly: the local prior is named `pies`, and a
        # bare `pi` would depend on an unrelated module-level name.
        ldenom_sum = comm.allreduce(my_ldenom_sum)
        lAi = (H * np.log(1. - pi_new)) - ((D/2) * np.log(2*np.pi)) - (D * np.log(sigma_new))

        # For practical and et approx reasons we use: sum of restricted responsibilities = 1
        loglike_et = (lAi * N_use) + ldenom_sum

        return { 'W': W_new, 'pi': pi_new, 'sigma': sigma_new , 'Q':loglike_et}
Exemplo n.º 4
0
    def M_step(self, anneal, model_params, my_suff_stat, my_data):
        """Ternary Sparse Coding M-Step

        This function is responsible for finding the optimal model parameters
        given an approximation of the posterior distribution.

        Parameters
        ----------
        anneal : Annealing object
            Annealing type object containing training schedule information
                anneal['Ncut_factor']: 0. no truncation; 1. trunc. according to model
                anneal.crit_params:    Python expressions evaluated in this
                                       method's scope and passed to
                                       anneal.dyn_param
        model_params : dict
            dictionary containing model parameters
                model_params['W']:     (D,H) ndarray
                    linear dictionary (the method works on its (H,D) transpose)
                model_params['pi']:    float
                    prior parameter (used as a scalar in the truncation
                    formula)
                model_params['sigma']:  float
                    standard deviation of noise model
        my_suff_stat : dict
            dictionary containing information about the joint distribution
                my_suff_stat['logpj']:  (my_N,no_states) ndarray
                    logarithm of joint of data and latent variable states
        my_data : dict
            data dictionary
                my_data['y']:     (my_N,D) ndarray
                    datapoints
                my_data['candidates']: (my_N,Hprime)
                    Candidate H's according to selection func.

        Returns
        -------
        dict
            dictionary containing updated model parameters
                dict['W']:     (D,H) ndarray
                    linear dictionary
                dict['pi']:    float
                    prior parameter
                dict['sigma']:  float
                    standard deviation of noise model
                dict['Q']:      float
                    always 0. (not computed by this model)

        Parameters not listed in ``self.to_learn`` are returned unchanged.
        """
        comm = self.comm
        H = self.H
        gamma = self.gamma
        W = model_params['W'].T  # working copy, shape (H, D)
        pi = model_params['pi']
        sigma = model_params['sigma']

        # Read in data:
        my_y = my_data['y'].copy()
        candidates = my_data['candidates']
        logpj_all = my_suff_stat['logpj']
        all_denoms = np.exp(logpj_all).sum(axis=1)
        my_N, D = my_y.shape
        N = comm.allreduce(my_N)

        SM = self.state_matrix  # shape: (no_states, Hprime)
        state_abs = np.abs(SM).sum(axis=1)  # L1 norm of each state row

        # Precompute factor for pi update
        A_pi_gamma = 0.0
        B_pi_gamma = 0.0

        for gam1 in range(gamma + 1):
            for gam2 in range(gamma - gam1 + 1):
                cmb = comb(gam1, gam1) * comb(gam1 + gam2, gam2) * comb(
                    H, H - gam1 - gam2)
                A_pi_gamma += cmb * ((pi / 2)**(gam1 + gam2)) * (
                    (1 - pi)**(H - gam1 - gam2))
                B_pi_gamma += (gam1 + gam2) * cmb * (
                    (pi / 2)**(gam1 + gam2)) * ((1 - pi)**(H - gam1 - gam2))
        E_pi_gamma = pi * H * A_pi_gamma / B_pi_gamma

        # Truncate data
        if anneal['Ncut_factor'] > 0.0:
            N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
            cut_denom = parallel.allsort(all_denoms)[-N_use]
            which = np.array(all_denoms >= cut_denom)
            candidates = candidates[which]
            logpj_all = logpj_all[which]
            my_y = my_y[which]
            my_N, D = my_y.shape
            N_use = comm.allreduce(my_N)
        else:
            N_use = N

        # Posterior weights shifted by the per-row max so exp() cannot
        # overflow; the shift cancels when dividing by pjb.sum() below.
        corr_all = logpj_all.max(axis=1)  # shape: (my_N,)
        pjb_all = np.exp(logpj_all - corr_all[:, None])  # (my_N, no_states)

        # Log-Likelihood:
        L = -0.5 * D * np.log(2 * np.pi * sigma**2) - np.log(A_pi_gamma)
        # Per-datapoint log(sum(exp(logpj))) via log-sum-exp; a direct exp()
        # underflows to 0 (log -> -inf) for strongly negative log-joints.
        Fs = (np.log(pjb_all.sum(axis=1)) + corr_all).sum()
        L += comm.allreduce(Fs) / N_use
        dlog.append('L', L)

        # Allocate
        my_Wp = np.zeros_like(W)  # shape (H, D)
        my_Wq = np.zeros((H, H))  # shape (H, H)
        my_pi = 0.0
        my_sigma = 0.0

        # Iterate over all datapoints
        for n in range(my_N):
            y = my_y[n, :]  # length D
            cand = candidates[n, :]  # length Hprime
            pjb = pjb_all[n, :]
            this_Wp = np.zeros_like(my_Wp)  # numerator for W, (H, D)
            this_Wq = np.zeros_like(my_Wq)  # denominator for W, (H, H)
            this_pi = np.zeros_like(pi)  # numerator for pi update

            # Every hidden state is a row of SM (pjb spans all states); only
            # the candidate rows/columns are touched.
            this_Wp[cand] += np.dot(np.outer(y, pjb), SM).T
            this_Wq_tmp = np.zeros_like(my_Wq[cand])
            this_Wq_tmp[:, cand] = np.dot(pjb * SM.T, SM)
            this_Wq[cand] += this_Wq_tmp
            this_pi += np.inner(pjb, state_abs)

            denom = pjb.sum()
            my_Wp += this_Wp / denom
            my_Wq += this_Wq / denom
            my_pi += this_pi / denom

        # Calculate updated W
        if 'W' in self.to_learn:
            Wp = np.empty_like(my_Wp)
            Wq = np.empty_like(my_Wq)
            comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
            comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])
            W_new = np.dot(np.linalg.pinv(Wq), Wp)
        else:
            W_new = W

        # Calculate updated pi
        if 'pi' in self.to_learn:
            pi_new = E_pi_gamma * comm.allreduce(my_pi) / H / N_use
        else:
            pi_new = pi

        # Calculate updated sigma
        if 'sigma' in self.to_learn:
            for n in range(my_N):
                y = my_y[n, :]  # length D
                cand = candidates[n, :]  # length Hprime
                pjb = pjb_all[n, :]  # == exp(logpj - logpj.max()) for this row

                # All hidden states are rows of SM
                Wbar = np.dot(SM, W[cand])  # W[cand] is (Hprime, D)
                this_sigma = (pjb * ((Wbar - y)**2).sum(axis=1)).sum()

                my_sigma += this_sigma / pjb.sum()

            sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
        else:
            sigma_new = sigma

        for param in anneal.crit_params:
            # Each entry is an expression over this method's locals. In
            # Python 3 exec() cannot bind a new function local, so the former
            # exec('this_param = ...') left this_param undefined; eval()
            # returns the value directly.
            # NOTE(review): eval() runs arbitrary code -- crit_params must come
            # from trusted configuration.
            this_param = eval(param, globals(), locals())
            anneal.dyn_param(param, this_param)

        dlog.append('N_use', N_use)

        return {
            'W': W_new.transpose(),
            'pi': pi_new,
            'sigma': sigma_new,
            'Q': 0.
        }
Exemplo n.º 5
0
    def M_step(self, anneal, model_params, my_suff_stat, my_data):
        """ MCA M_step

        Compute updated model parameters (W, pi, sigma) from the truncated
        posterior approximation, together with an approximate (truncated)
        log-likelihood ``Q``. Sufficient statistics are accumulated over this
        rank's data points and summed across ranks via ``self.comm``.

        Column layout of ``my_suff_stat['logpj']`` as implied by the indexing
        below: column 0 is the state with no active cause, columns 1..H are
        the states with exactly one active cause (one per row of W), and
        columns H+1.. correspond to the rows of ``self.state_matrix``
        (multi-cause states over each data point's candidates).

        my_data variables used:

            my_data['y']           Datapoints, shape (my_N, D)
            my_data['candidates']  Candidate H's according to selection func.,
                                   one row per data point

        my_suff_stat variables used:

            my_suff_stat['logpj']  Log-joint of each data point with each
                                   hidden state, one row per data point

        Annealing variables used:

            anneal['T']            Temperature for det. annealing AND softmax
            anneal['Ncut_factor']  0.: no truncation; 1. trunc. according to model

        Returns
        -------
        dict
            'W'     : updated dictionary, same layout as model_params['W']
            'pi'    : updated prior
            'sigma' : updated noise standard deviation
            'Q'     : truncated approximation of the log-likelihood
        """
        comm  = self.comm
        H     = self.H
        W     = model_params['W'].T          # shape (H, D)
        pies  = model_params['pi']
        sigma = model_params['sigma']

        # Read in data:
        my_y     = my_data['y']
        my_cand  = my_data['candidates']
        my_logpj = my_suff_stat['logpj']
        my_N, D  = my_y.shape
        N        = comm.allreduce(my_N)

        state_mtx = self.state_matrix        # one row per multi-cause state, Hprime columns
        state_abs = self.state_abs           # number of active causes per multi-cause state

        # Ignore divide-by-zero and underflow warnings for the whole update.
        # np.errstate restores the previous numpy error state on exit, even
        # when one of the asserts below fails.
        with np.errstate(divide='ignore', under='ignore'):
            # Precompute
            T       = anneal['T']
            T_rho   = np.maximum(T, self.rho_T_bound)
            rho     = 1./(1.-1./T_rho)       # T_rho == 1 yields inf; clipped below
            rho     = np.maximum(np.minimum(rho, self.rho_ubound), self.rho_lbound)
            beta    = 1./T
            pil_bar = np.log(pies/(1.-pies))
            Wl      = accel.log(np.abs(W))
            Wrho    = accel.exp(rho * Wl)
            Wrhos   = np.sign(W) * Wrho

            # Sanity checks on the current parameters
            assert np.isfinite(pil_bar).all()
            assert np.isfinite(Wl).all()
            assert np.isfinite(Wrho).all()
            assert (Wrho > 1e-86).all()

            # Tempered log-joint, shifted by its per-datapoint maximum so the
            # exponentials below stay in range
            my_corr   = beta * my_logpj.max(axis=1)          # shape: (my_N,)
            my_logpjb = beta * my_logpj - my_corr[:, None]   # same shape as my_logpj
            my_pjb    = accel.exp(my_logpjb)                 # same shape as my_logpj

            # Precompute factor for pi update and ET cutting
            A_pi_gamma = 0.
            B_pi_gamma = 0.
            for gp in range(0, self.gamma+1):
                a = comb(H, gp) * pies**gp * (1.-pies)**(H-gp)
                A_pi_gamma += a
                B_pi_gamma += gp * a

            # Truncate data: keep only the data points with the largest
            # (tempered) evidence when Ncut_factor > 0
            if anneal['Ncut_factor'] > 0.0:
                tracing.tracepoint("M_step:truncating")
                my_logdenoms = accel.log(my_pjb.sum(axis=1)) + my_corr
                N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))

                cut_denom = parallel.allsort(my_logdenoms)[-N_use]
                my_sel,   = np.where(my_logdenoms >= cut_denom)
                my_N,     = my_sel.shape
                N_use     = comm.allreduce(my_N)
            else:
                my_sel = np.arange(my_N)
                N_use  = N

            # Allocate suff-stat arrays
            my_Wp    = np.zeros_like(W)  # shape (H, D)
            my_Wq    = np.zeros_like(W)  # shape (H, D)
            my_pi    = 0.0
            my_sigma = 0.0
            my_ldenom_sum = 0.0          # for the log-likelihood estimate

            # Iterate over all selected datapoints
            tracing.tracepoint("M_step:iterating...")
            dlog.append('N_use', N_use)
            for n in my_sel:
                y      = my_y[n, :]          # shape (D,)
                cand   = my_cand[n, :]       # shape (Hprime,)
                logpj  = my_logpj[n, :]
                logpjb = my_logpjb[n, :]
                pjb    = my_pjb[n, :]

                this_Wp = np.zeros_like(W)   # numerator for W (current datapoint)   (H, D)
                this_Wq = np.zeros_like(W)   # denominator for W (current datapoint) (H, D)

                # Zero active hidden causes: contributes to sigma only
                this_sigma = 0.0 + pjb[0] * (y**2).sum()

                # One active hidden cause (states 1..H, one per row of W)
                this_Wp    += pjb[1:(H+1), None] * y[None, :]
                this_Wq    += pjb[1:(H+1), None]
                this_pi     = 0.0 + pjb[1:(H+1)].sum()
                this_sigma += (pjb[1:(H+1)] * ((W-y)**2).sum(axis=1)).sum()

                # More than one active cause: states H+1.. are the rows of
                # state_mtx, combined over this datapoint's candidates
                Wl_    = Wl[cand]            # is (Hprime, D)
                Wrhos_ = Wrhos[cand]         # is (Hprime, D)

                t0    = np.dot(state_mtx, Wrhos_)
                Wlbar = accel.log(np.abs(t0)) / rho
                Wbar  = np.sign(t0) * accel.exp(Wlbar)

                t   = np.maximum(Wlbar[:, None, :] - Wl_[None, :, :], 0.)
                Aid = state_mtx[:, :, None] * accel.exp(logpjb[H+1:, None, None] - (rho-1)*t)
                Aid = Aid.sum(axis=0)

                this_Wp[cand] += Aid * y[None, :]
                this_Wq[cand] += Aid
                this_pi       += (pjb[1+H:] * state_abs).sum()
                this_sigma    += (pjb[1+H:] * ((Wbar-y)**2).sum(axis=1)).sum()

                # Normalise by the (shifted) evidence of this datapoint
                denom     = pjb.sum()
                my_Wp    += this_Wp / denom
                my_Wq    += this_Wq / denom
                my_pi    += this_pi / denom
                my_sigma += this_sigma / denom

                # log(sum(exp(logpj))) computed with the max-shift trick so
                # very negative log-joints do not underflow to log(0) = -inf
                lmax = logpj.max()
                my_ldenom_sum += lmax + accel.log(np.sum(accel.exp(logpj - lmax)))

            # Calculate updated W
            if 'W' in self.to_learn:
                tracing.tracepoint("M_step:update W")

                Wp = np.empty_like(my_Wp)
                Wq = np.empty_like(my_Wq)

                assert np.isfinite(my_Wp).all()
                assert np.isfinite(my_Wq).all()

                comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
                comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])

                # Make sure we do not divide by zero
                tiny = self.tol
                Wq[Wq < tiny] = tiny

                W_new = Wp / Wq

                # Add inertia depending on Wq: entries with small Wq move only
                # partially (but at least 20%) toward the new estimate
                alpha = 2.5
                inertia = np.maximum(1. - accel.exp(-Wq / alpha), 0.2)
                W_new = inertia*W_new + (1-inertia)*W
                W_new = W_new.T
            else:
                W_new = W.T

            # Calculate updated pi
            if 'pi' in self.to_learn:
                tracing.tracepoint("M_step:update pi")

                assert np.isfinite(my_pi).all()
                pi_new = A_pi_gamma / B_pi_gamma * pies * comm.allreduce(my_pi) / N_use
            else:
                pi_new = pies

            # Calculate updated sigma
            if 'sigma' in self.to_learn:               # TODO: see LinCA (merge!)
                tracing.tracepoint("M_step:update sigma")

                assert np.isfinite(my_sigma).all()
                sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
            else:
                sigma_new = sigma

            # Put all together and compute (always) et_approx_likelihood
            ldenom_sum = comm.allreduce(my_ldenom_sum)
            lAi = (H * np.log(1. - pi_new)) - ((D / 2.) * np.log(2. * np.pi)) - (D * np.log(sigma_new))

            # For practical and ET-approximation reasons the truncated
            # responsibilities are treated as summing to one
            loglike_et = (lAi * N_use) + ldenom_sum

        return {'W': W_new, 'pi': pi_new, 'sigma': sigma_new, 'Q': loglike_et}
Exemplo n.º 6
0
    pprint("  size of bars images:    %d x %d" % (size, size))
    pprint("  number of hiddens:      %d" % H)
    pprint("  saving results to:      %s" % output_path)
    pprint()

    my_data = model.generate_data(params_gt, N_train // comm.size)
    my_test_data = model.generate_data(params_gt, N_test // comm.size)

    # Configure DataLogger
    store_list = ('*')
    print_list = ('T', 'Q', 'pi', 'sigma', 'N', 'MAE')
    dlog.set_handler(print_list, TextPrinter)
    dlog.set_handler(print_list, StoreToTxt, output_path + '/terminal.txt')
    dlog.set_handler(store_list, StoreToH5, output_path + '/result.h5')

    dlog.append('Hprime_start', model.Hprime)
    dlog.append('gamma_start', model.gamma)

    model_params = model.standard_init(my_data)

    if 'anneal' in params:
        anneal = params.get('anneal')
    else:
        # Choose annealing schedule
        anneal = LinearAnnealing(50)
        anneal['T'] = [(0, 2.), (.7, 1.)]
        anneal['Ncut_factor'] = [(0, 0.), (2. / 3, 1.)]
        anneal['anneal_prior'] = False

    # Create and start EM annealing
    em = EM(model=model, anneal=anneal)