Beispiel #1
0
    def update_terminals(self):
        """Recompute ``self.terminals`` from the current edge matrix.

        ``self.terminals`` is reset to an int array of zeros with the shape of
        ``self.vertices``.

        * Row 0 is set to 1 for every vertex whose row of
          ``tril(self.edges)`` sums to zero. With non-negative 0/1 edges,
          that means the vertex has no edge to a vertex at the same or a
          lower index.
        * Row 1 is set to 1 for every vertex whose row of
          ``triu(self.edges)`` sums to zero (no edge to the same or a higher
          index).
        """
        # NumPy directly: the NumPy aliases in the top-level scipy namespace
        # (sp.zeros, sp.where, sp.tril, ...) are deprecated/removed.
        import numpy as np

        self.terminals = np.zeros(self.vertices.shape, dtype='int')
        no_lower_edge = np.sum(np.tril(self.edges), axis=1) == 0
        no_upper_edge = np.sum(np.triu(self.edges), axis=1) == 0
        self.terminals[0, np.where(no_lower_edge)[0]] = 1
        self.terminals[1, np.where(no_upper_edge)[0]] = 1
Beispiel #2
0
def Calculate_NCG(X, Pi, alpha, beta1, beta2, d_old, g_norm_old):
    """Return the natural-conjugate-gradient search direction for ``Pi``.

    Parameters
    ----------
    X : array, shape (N, N)
        Adjacency matrix. ``Y = 1 - I - X`` is used as the non-edge matrix,
        so ``X`` is assumed to be 0/1 with a zero diagonal (TODO confirm).
    Pi : array, shape (N, k)
        Membership probabilities. Must be strictly positive, because
        ``log(Pi)`` is taken.
    alpha : array, shape (k,)
        Dirichlet parameters; only ``psi(alpha)`` is used.
    beta1, beta2 : array
        Beta parameters. Only the diagonal of ``psi(beta1) - psi(beta1 +
        beta2)`` (resp. ``psi(beta2) - ...``) is used, so these are
        presumably (k, k).
    d_old : array broadcastable to (N, k - 1)
        Previous search direction.
    g_norm_old : float
        ``g_norm`` from the previous call. It is the denominator of the
        conjugate-gradient ratio, so it must be nonzero.

    Returns
    -------
    d : array, shape (N, k - 1)
        ``g + g_norm / g_norm_old * d_old``, where ``g = L_prime @ u``.
    g_norm : float
        ``sum_i l_i^T (diag(p_i) - p_i p_i^T) l_i``, where ``l_i`` and
        ``p_i`` are the rows of ``L_prime`` and ``Pi``.
    """
    import numpy as np
    from scipy.special import psi

    # Probability of observing an edge between clusters
    eps = 1e-10

    N = X.shape[0]
    k = Pi.shape[1]

    Y = np.ones((N, N)) - np.identity(N) - X

    # Only within-cluster (diagonal) terms depend on beta; the off-diagonal
    # entries are the constants log(eps) and log(1 - eps).
    off_diag = np.ones((k, k)) - np.identity(k)
    d_log_B1 = psi(beta1) - psi(beta1 + beta2)
    d_log_B2 = psi(beta2) - psi(beta1 + beta2)
    d_log_B1 = np.diag(np.diag(d_log_B1)) + np.log(eps) * off_diag
    d_log_B2 = np.diag(np.diag(d_log_B2)) + np.log(1 - eps) * off_diag

    # L_prime[i, j] = dL / dPi[i, j]
    L_prime = (X.dot(Pi.dot(d_log_B1.T)) + Y.dot(Pi.dot(d_log_B2.T))
               + np.outer(np.ones(N), psi(alpha)) - np.log(Pi)
               - np.outer(np.ones(N), np.ones(k)))

    # Columns of u are e_j - e_k (j < k). Moving along them keeps every row
    # of Pi summing to 1; g is the gradient along those k - 1 directions.
    u = np.concatenate([np.identity(k - 1), -np.ones((1, k - 1))], axis=0)

    # Natural gradient
    g = L_prime.dot(u)

    # Norm of the natural gradient, vectorised over rows:
    #   l^T diag(p) l = sum_j p_j l_j^2  and  l^T p p^T l = (p . l)^2
    weighted = Pi * L_prime
    g_norm = np.sum(weighted * L_prime) - np.sum(np.sum(weighted, axis=1) ** 2)

    # Natural conjugate gradient (ratio of successive gradient norms)
    d = g + g_norm / g_norm_old * d_old

    return d, g_norm
Beispiel #3
0
 def __init__(self,
              matA,
              matB,
              xguess=None,
              eps_flux=EPS_INNER,
              eps_source=EPS_OUTER,
              eps_k=EPS_OUTER):
     """Initialise the solver and split ``matA`` into triangular parts.

     All arguments are forwarded unchanged to the parent ``__init__``.

     * ``self.matL`` is the lower triangle of ``matA``, diagonal included.
     * ``self.matU = matA - self.matL`` is the strictly upper triangle.

     ``matA`` is assumed to be a dense 2-D array; ``numpy.tril`` does not
     accept scipy.sparse matrices (TODO confirm).
     """
     # numpy.tril rather than scipy.tril: the NumPy aliases in the scipy
     # namespace are deprecated and removed in recent SciPy releases.
     import numpy

     super().__init__(matA, matB, xguess, eps_flux, eps_source, eps_k)
     self.matL = numpy.tril(matA)
     self.matU = matA - self.matL
Beispiel #4
0
def factLU(A, tol=1e-20):
    """LU factorisation without pivoting (Doolittle form).

    Computes a unit lower-triangular ``L`` and an upper-triangular ``U``
    with ``L @ U == A``. Elimination runs on a floating-point copy of ``A``.
    As a result, ``A`` is left unmodified, and integer input does not
    truncate the multipliers.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Square matrix (nested lists or a NumPy array).
    tol : float, optional
        A pivot whose absolute value is below ``tol`` aborts the
        factorisation. Default ``1e-20``.

    Returns
    -------
    L, U : numpy.ndarray
        The factors. If a pivot ``|A[k, k]| < tol`` is met for ``k < n - 1``,
        the message 'Pivote cercano a cero en k=<k>' ("pivot close to zero")
        is printed, and two ``n x n`` lists of lists filled with ``0.0`` are
        returned instead. The last diagonal entry is never tested.
    """
    import numpy as np

    A = np.asarray(A)
    # Work on a copy promoted to at least float64 (complex is kept complex).
    LU = A.astype(np.result_type(A.dtype, np.float64))
    n = len(LU)
    for k in range(n - 1):
        pivot = LU[k, k]
        if np.absolute(pivot) < tol:
            print('Pivote cercano a cero en k=%d' % k)
            L = [[0.0] * n for _ in range(n)]
            U = [[0.0] * n for _ in range(n)]
            return L, U
        # Multipliers go below the pivot (they form L); the trailing
        # submatrix is updated in place (it becomes U).
        LU[k + 1:, k] /= pivot
        LU[k + 1:, k + 1:] -= np.outer(LU[k + 1:, k], LU[k, k + 1:])

    U = np.triu(LU)
    L = np.tril(LU, -1)
    np.fill_diagonal(L, 1.0)
    return L, U
Beispiel #5
0
def Log_Likelihood(X, Pi, alpha, beta1, beta2):
    """Return the objective value for the membership matrix ``Pi``.

    The result is the sum of four terms:

    * the strictly lower-triangular entries (cluster pairs a > b) of
      ``log(eps) * Pi^T X Pi + log(1 - eps) * Pi^T Y Pi``, i.e. soft edge
      and non-edge counts between different clusters, weighted by the fixed
      between-cluster edge probability ``eps``;
    * ``sum(Pi * log(Pi))``;
    * ``log(nbeta(alpha) / nbeta(ones))``, where ``nbeta`` is defined
      elsewhere (presumably the multivariate Beta function, TODO confirm);
    * the sum of the diagonal of ``log B(beta1, beta2) - log B(1, 1)``.

    ``Y = 1 - I - X`` is used as the non-edge matrix, so ``X`` is assumed to
    be a 0/1 adjacency matrix with zero diagonal. ``Pi`` must be strictly
    positive.

    NOTE(review): the entropy-like term enters with a + sign here, while
    ``Calculate_NCG``'s gradient contains ``-log(Pi) - 1``, which is the
    derivative of ``-sum(Pi log Pi)``. Confirm the intended sign.
    """
    import numpy as np
    from scipy.special import betaln

    # Probability of observing an edge between clusters
    eps = 1e-10

    N = X.shape[0]
    Y = np.ones((N, N)) - np.identity(N) - X
    Z = (np.log(eps) * Pi.T.dot(X.dot(Pi))
         + np.log(1 - eps) * Pi.T.dot(Y.dot(Pi)))
    Z = np.tril(Z, -1)

    A = np.sum(Pi * np.log(Pi))
    B_alpha = np.log(nbeta(alpha) / nbeta(np.ones(alpha.shape[0])))

    # betaln instead of log(beta(...)): beta() underflows to 0 (log -> -inf)
    # for large parameters, whereas betaln stays finite.
    B_beta = np.diag(betaln(beta1, beta2) - betaln(1, 1))

    return np.sum(Z) + A + B_alpha + np.sum(B_beta)
Beispiel #6
0
def gen_unrelated_eur_1k_data(
        input_file='/home/bjarni/TheHonestGene/faststorage/1Kgenomes/phase3/1k_genomes_hg.hdf5',
        out_file='/home/bjarni/PCMA/faststorage/1_DATA/1k_genomes/1K_genomes_phase3_EUR_unrelated.hdf5',
        maf_thres=0.01,
        max_relatedness=0.05,
        K_thinning_frac=0.1,
        debug=False):
    """Write the unrelated European subset of a 1000 Genomes HDF5 file.

    1. Estimate a kinship matrix for all individuals whose
       ``indivs/continent`` is ``'EUR'``. Only bi-allelic SNPs are used, and
       only those whose REF and ALT are in ``ok_nts`` (a name defined outside
       this function) and whose genotype standard deviation exceeds
       ``sqrt(2 * maf_thres * (1 - maf_thres))``. When
       ``K_thinning_frac < 1``, a random fraction of those SNPs is drawn.
    2. Prune greedily in index order: each retained individual removes every
       later individual whose kinship with it exceeds ``max_relatedness``.
    3. For the retained individuals, write to ``out_file``:

       * a top-level ``indiv_ids`` dataset;
       * per chromosome (``chr1`` .. ``chr22``), the bi-allelic, polymorphic
         SNPs with known nucleotides as ``snps``, ``snp_means``,
         ``snp_stds``, ``snp_ids``, ``positions`` and ``nts``.

    Genotypes are presumably coded 0/1/2, as the standard-deviation
    threshold assumes (TODO confirm).

    Parameters
    ----------
    input_file : str
        Source HDF5 file; opened read-only.
    out_file : str
        Destination HDF5 file; created or truncated.
    maf_thres : float
        Minor-allele-frequency cut-off for the kinship SNPs.
    max_relatedness : float
        Largest kinship allowed between two retained individuals.
    K_thinning_frac : float
        Fraction of SNPs randomly sampled (``numpy.random``) for the kinship
        estimate. Values ``>= 1`` use every SNP.
    debug : bool
        Re-estimate a per-chromosome kinship among the retained individuals.
        Raise ``Exception('Kinship looks wrong?')`` if its mean diagonal is
        not ~1 or an off-diagonal entry reaches ``1.5 * max_relatedness``.
    """
    # NumPy directly: the NumPy aliases in the top-level scipy namespace
    # (sp.zeros, sp.in1d, ...) are deprecated/removed.
    import numpy as np

    std_thres = np.sqrt(2.0 * (1 - maf_thres) * maf_thres)

    with h5py.File(input_file, 'r') as h5f:
        num_indivs = len(h5f['indivs']['continent'])
        eur_filter = h5f['indivs']['continent'][...] == 'EUR'
        num_eur_indivs = np.sum(eur_filter)
        print('Number of European individuals: %d' % num_eur_indivs)

        # float64 accumulator, so single-precision rounding cannot undermine
        # the np.isclose trace check below.
        K = np.zeros((num_eur_indivs, num_eur_indivs), dtype='float64')
        num_snps = 0

        print('Calculating kinship')
        for chrom in range(1, 23):
            print('Working on Chromosome %d' % chrom)
            chrom_str = 'chr%d' % chrom
            variants = h5f[chrom_str]['variants']

            print('Loading SNPs and data')
            snps = np.array(h5f[chrom_str]['calldata']['snps'][...], dtype='int8')

            print('Loading NTs')
            ref_nts = variants['REF'][...]
            alt_nts = variants['ALT'][...]

            print('Filtering multi-allelic SNPs')
            # logical_not, not negative: NumPy rejects unary minus on booleans.
            bi_allelic = np.logical_not(variants['MULTI_ALLELIC'][...])
            snps = snps[bi_allelic]
            ref_nts = ref_nts[bi_allelic]
            alt_nts = alt_nts[bi_allelic]

            if K_thinning_frac < 1:
                print('Thinning SNPs for kinship calculation')
                thinning_filter = np.random.random(len(snps)) < K_thinning_frac
                snps = snps[thinning_filter]
                ref_nts = ref_nts[thinning_filter]
                alt_nts = alt_nts[thinning_filter]

            print('Filter SNPs with missing NT information')
            nt_filter = np.isin(ref_nts, ok_nts) & np.isin(alt_nts, ok_nts)
            snps = snps[nt_filter]

            print('Filtering non-European individuals')
            snps = snps[:, eur_filter]

            print('Filtering SNPs with MAF < %s' % maf_thres)
            snp_stds = np.std(snps, 1)
            maf_filter = snp_stds > std_thres
            snps = snps[maf_filter]
            snp_stds = snp_stds[maf_filter]

            print('%d SNPs remaining after all filtering steps.' % len(snps))

            print('Normalizing SNPs')
            snp_means = np.mean(snps, 1)
            norm_snps = (snps - snp_means[:, np.newaxis]) / snp_stds[:, np.newaxis]

            print('Updating kinship')
            K += np.dot(norm_snps.T, norm_snps)
            num_snps += len(norm_snps)
            # np.std uses ddof=0, so each standardised SNP adds exactly
            # num_eur_indivs to the trace.
            assert np.isclose(np.sum(np.diag(K)) / (num_snps * num_eur_indivs), 1.0)

        K = K / float(num_snps)
        print('Kinship calculation done using %d SNPs\n' % num_snps)

        # Greedy pruning in index order: a still-retained individual removes
        # every later individual it is too closely related to.
        print('Filtering individuals')
        keep = np.ones(num_eur_indivs, dtype='bool')
        for i in range(num_eur_indivs):
            if keep[i]:
                keep[i + 1:] &= ~(K[i, i + 1:] > max_relatedness)
        keep_indivs = np.flatnonzero(keep)
        print('Retained %d individuals\n' % len(keep_indivs))

        # Checking that everything is ok!
        K_ok = K[np.ix_(keep_indivs, keep_indivs)]
        assert (K_ok - np.tril(K_ok)).max() < max_relatedness

        # Map the retained EUR positions back to indices of all individuals.
        indiv_filter = np.zeros(num_indivs, dtype='bool')
        indiv_filter[np.flatnonzero(eur_filter)[keep_indivs]] = True
        assert np.sum(indiv_filter) == len(keep_indivs)

        print('Now storing data.')
        with h5py.File(out_file, 'w') as oh5f:
            indiv_ids = h5f['indivs']['indiv_ids'][indiv_filter]
            oh5f.create_dataset('indiv_ids', data=indiv_ids)
            for chrom in range(1, 23):
                print('Working on Chromosome %d' % chrom)
                chrom_str = 'chr%d' % chrom
                variants = h5f[chrom_str]['variants']

                print('Loading SNPs and data')
                snps = np.array(h5f[chrom_str]['calldata']['snps'][...], dtype='int8')
                snp_ids = variants['ID'][...]
                positions = variants['POS'][...]

                print('Loading NTs')
                ref_nts = variants['REF'][...]
                alt_nts = variants['ALT'][...]

                print('Filtering multi-allelic SNPs')
                bi_allelic = np.logical_not(variants['MULTI_ALLELIC'][...])
                snps, ref_nts, alt_nts, positions, snp_ids = [
                    a[bi_allelic] for a in (snps, ref_nts, alt_nts, positions, snp_ids)]

                print('Filter individuals')
                snps = snps[:, indiv_filter]

                print('Filter SNPs with missing NT information')
                nt_filter = np.isin(ref_nts, ok_nts) & np.isin(alt_nts, ok_nts)
                snps, ref_nts, alt_nts, positions, snp_ids = [
                    a[nt_filter] for a in (snps, ref_nts, alt_nts, positions, snp_ids)]

                print('filter monomorphic SNPs')
                snp_stds = np.std(snps, 1)
                polymorphic = snp_stds > 0
                snps, ref_nts, alt_nts, positions, snp_ids, snp_stds = [
                    a[polymorphic]
                    for a in (snps, ref_nts, alt_nts, positions, snp_ids, snp_stds)]

                snp_means = np.mean(snps, 1)

                if debug:
                    # Re-estimate kinship among the retained individuals from
                    # this chromosome alone; K and num_snps are left untouched.
                    k_snps = snps
                    k_snp_stds = snp_stds
                    if K_thinning_frac < 1:
                        print('Thinning SNPs for kinship calculation')
                        thinning_filter = np.random.random(len(snps)) < K_thinning_frac
                        k_snps = k_snps[thinning_filter]
                        k_snp_stds = k_snp_stds[thinning_filter]

                    print('Filtering SNPs with MAF < %s' % maf_thres)
                    maf_filter = k_snp_stds > std_thres
                    k_snps = k_snps[maf_filter]
                    k_snp_stds = k_snp_stds[maf_filter]
                    k_snp_means = np.mean(k_snps, 1)

                    print('Verifying that the Kinship makes sense')
                    norm_snps = (k_snps - k_snp_means[:, np.newaxis]) / k_snp_stds[:, np.newaxis]
                    chrom_K = np.dot(norm_snps.T, norm_snps) / float(len(norm_snps))
                    trace_ok = np.isclose(np.sum(np.diag(chrom_K)) / len(keep_indivs), 1.0)
                    max_off_diag = (chrom_K - np.tril(chrom_K)).max()
                    if trace_ok and max_off_diag < (max_relatedness * 1.5):
                        print('It looks OK!')
                    else:
                        raise Exception('Kinship looks wrong?')

                nts = np.array([[nt1, nt2] for nt1, nt2 in zip(ref_nts, alt_nts)])

                print('Writing to disk')
                cg = oh5f.create_group(chrom_str)
                cg.create_dataset('snps', data=snps)
                cg.create_dataset('snp_means', data=snp_means[:, np.newaxis])
                cg.create_dataset('snp_stds', data=snp_stds[:, np.newaxis])
                cg.create_dataset('snp_ids', data=snp_ids)
                cg.create_dataset('positions', data=positions)
                cg.create_dataset('nts', data=nts)
                oh5f.flush()
                print('Done writing to disk')
    print('Done')
Beispiel #7
0
def gen_unrelated_eur_1k_data(
        input_file='/home/bjarni/TheHonestGene/faststorage/1Kgenomes/phase3/1k_genomes_hg.hdf5',
        out_file='/home/bjarni/PCMA/faststorage/1_DATA/1k_genomes/1K_genomes_phase3_EUR_unrelated.hdf5',
        maf_thres=0.01,
        max_relatedness=0.05,
        K_thinning_frac=0.1,
        debug=False):
    """Write the unrelated European subset of a 1000 Genomes HDF5 file.

    1. Estimate a kinship matrix for all individuals whose
       ``indivs/continent`` is ``'EUR'``. Only bi-allelic SNPs are used, and
       only those whose REF and ALT are in ``ok_nts`` (a name defined outside
       this function) and whose genotype standard deviation exceeds
       ``sqrt(2 * maf_thres * (1 - maf_thres))``. When
       ``K_thinning_frac < 1``, a random fraction of those SNPs is drawn.
    2. Prune greedily in index order: each retained individual removes every
       later individual whose kinship with it exceeds ``max_relatedness``.
    3. For the retained individuals, write to ``out_file``:

       * a top-level ``indiv_ids`` dataset;
       * per chromosome (``chr1`` .. ``chr22``), the bi-allelic, polymorphic
         SNPs with known nucleotides as ``snps``, ``snp_means``,
         ``snp_stds``, ``snp_ids``, ``positions`` and ``nts``.

    Genotypes are presumably coded 0/1/2, as the standard-deviation
    threshold assumes (TODO confirm).

    Parameters
    ----------
    input_file : str
        Source HDF5 file; opened read-only.
    out_file : str
        Destination HDF5 file; created or truncated.
    maf_thres : float
        Minor-allele-frequency cut-off for the kinship SNPs.
    max_relatedness : float
        Largest kinship allowed between two retained individuals.
    K_thinning_frac : float
        Fraction of SNPs randomly sampled (``numpy.random``) for the kinship
        estimate. Values ``>= 1`` use every SNP.
    debug : bool
        Re-estimate a per-chromosome kinship among the retained individuals.
        Raise ``Exception('Kinship looks wrong?')`` if its mean diagonal is
        not ~1 or an off-diagonal entry reaches ``1.5 * max_relatedness``.
    """
    # NumPy directly: the NumPy aliases in the top-level scipy namespace
    # (sp.zeros, sp.in1d, ...) are deprecated/removed.
    import numpy as np

    std_thres = np.sqrt(2.0 * (1 - maf_thres) * maf_thres)

    with h5py.File(input_file, 'r') as h5f:
        num_indivs = len(h5f['indivs']['continent'])
        eur_filter = h5f['indivs']['continent'][...] == 'EUR'
        num_eur_indivs = np.sum(eur_filter)
        print('Number of European individuals: %d' % num_eur_indivs)

        K = np.zeros((num_eur_indivs, num_eur_indivs), dtype='float64')
        num_snps = 0

        print('Calculating kinship')
        for chrom in range(1, 23):
            print('Working on Chromosome %d' % chrom)
            chrom_str = 'chr%d' % chrom
            variants = h5f[chrom_str]['variants']

            print('Loading SNPs and data')
            snps = np.array(h5f[chrom_str]['calldata']['snps'][...], dtype='int8')

            print('Loading NTs')
            ref_nts = variants['REF'][...]
            alt_nts = variants['ALT'][...]

            print('Filtering multi-allelic SNPs')
            # logical_not, not negative: NumPy rejects unary minus on booleans.
            bi_allelic = np.logical_not(variants['MULTI_ALLELIC'][...])
            snps = snps[bi_allelic]
            ref_nts = ref_nts[bi_allelic]
            alt_nts = alt_nts[bi_allelic]

            if K_thinning_frac < 1:
                print('Thinning SNPs for kinship calculation')
                thinning_filter = np.random.random(len(snps)) < K_thinning_frac
                snps = snps[thinning_filter]
                ref_nts = ref_nts[thinning_filter]
                alt_nts = alt_nts[thinning_filter]

            print('Filter SNPs with missing NT information')
            nt_filter = np.isin(ref_nts, ok_nts) & np.isin(alt_nts, ok_nts)
            snps = snps[nt_filter]

            print('Filtering non-European individuals')
            snps = snps[:, eur_filter]

            print('Filtering SNPs with MAF < %s' % maf_thres)
            snp_stds = np.std(snps, 1)
            maf_filter = snp_stds > std_thres
            snps = snps[maf_filter]
            snp_stds = snp_stds[maf_filter]

            print('%d SNPs remaining after all filtering steps.' % len(snps))

            print('Normalizing SNPs')
            snp_means = np.mean(snps, 1)
            norm_snps = (snps - snp_means[:, np.newaxis]) / snp_stds[:, np.newaxis]

            print('Updating kinship')
            K += np.dot(norm_snps.T, norm_snps)
            num_snps += len(norm_snps)
            # np.std uses ddof=0, so each standardised SNP adds exactly
            # num_eur_indivs to the trace.
            assert np.isclose(np.sum(np.diag(K)) / (num_snps * num_eur_indivs), 1.0)

        K = K / float(num_snps)
        print('Kinship calculation done using %d SNPs\n' % num_snps)

        # Greedy pruning in index order: a still-retained individual removes
        # every later individual it is too closely related to.
        print('Filtering individuals')
        keep = np.ones(num_eur_indivs, dtype='bool')
        for i in range(num_eur_indivs):
            if keep[i]:
                keep[i + 1:] &= ~(K[i, i + 1:] > max_relatedness)
        keep_indivs = np.flatnonzero(keep)
        print('Retained %d individuals\n' % len(keep_indivs))

        # Checking that everything is ok!
        K_ok = K[np.ix_(keep_indivs, keep_indivs)]
        assert (K_ok - np.tril(K_ok)).max() < max_relatedness

        # Map the retained EUR positions back to indices of all individuals.
        indiv_filter = np.zeros(num_indivs, dtype='bool')
        indiv_filter[np.flatnonzero(eur_filter)[keep_indivs]] = True
        assert np.sum(indiv_filter) == len(keep_indivs)

        print('Now storing data.')
        with h5py.File(out_file, 'w') as oh5f:
            indiv_ids = h5f['indivs']['indiv_ids'][indiv_filter]
            oh5f.create_dataset('indiv_ids', data=indiv_ids)
            for chrom in range(1, 23):
                print('Working on Chromosome %d' % chrom)
                chrom_str = 'chr%d' % chrom
                variants = h5f[chrom_str]['variants']

                print('Loading SNPs and data')
                snps = np.array(h5f[chrom_str]['calldata']['snps'][...], dtype='int8')
                snp_ids = variants['ID'][...]
                positions = variants['POS'][...]

                print('Loading NTs')
                ref_nts = variants['REF'][...]
                alt_nts = variants['ALT'][...]

                print('Filtering multi-allelic SNPs')
                bi_allelic = np.logical_not(variants['MULTI_ALLELIC'][...])
                snps, ref_nts, alt_nts, positions, snp_ids = [
                    a[bi_allelic] for a in (snps, ref_nts, alt_nts, positions, snp_ids)]

                print('Filter individuals')
                snps = snps[:, indiv_filter]

                print('Filter SNPs with missing NT information')
                nt_filter = np.isin(ref_nts, ok_nts) & np.isin(alt_nts, ok_nts)
                snps, ref_nts, alt_nts, positions, snp_ids = [
                    a[nt_filter] for a in (snps, ref_nts, alt_nts, positions, snp_ids)]

                print('filter monomorphic SNPs')
                snp_stds = np.std(snps, 1)
                polymorphic = snp_stds > 0
                snps, ref_nts, alt_nts, positions, snp_ids, snp_stds = [
                    a[polymorphic]
                    for a in (snps, ref_nts, alt_nts, positions, snp_ids, snp_stds)]

                snp_means = np.mean(snps, 1)

                if debug:
                    # Re-estimate kinship among the retained individuals from
                    # this chromosome alone; K and num_snps are left untouched.
                    k_snps = snps
                    k_snp_stds = snp_stds
                    if K_thinning_frac < 1:
                        print('Thinning SNPs for kinship calculation')
                        thinning_filter = np.random.random(len(snps)) < K_thinning_frac
                        k_snps = k_snps[thinning_filter]
                        k_snp_stds = k_snp_stds[thinning_filter]

                    print('Filtering SNPs with MAF < %s' % maf_thres)
                    maf_filter = k_snp_stds > std_thres
                    k_snps = k_snps[maf_filter]
                    k_snp_stds = k_snp_stds[maf_filter]
                    k_snp_means = np.mean(k_snps, 1)

                    print('Verifying that the Kinship makes sense')
                    norm_snps = (k_snps - k_snp_means[:, np.newaxis]) / k_snp_stds[:, np.newaxis]
                    chrom_K = np.dot(norm_snps.T, norm_snps) / float(len(norm_snps))
                    trace_ok = np.isclose(np.sum(np.diag(chrom_K)) / len(keep_indivs), 1.0)
                    max_off_diag = (chrom_K - np.tril(chrom_K)).max()
                    if trace_ok and max_off_diag < (max_relatedness * 1.5):
                        print('It looks OK!')
                    else:
                        raise Exception('Kinship looks wrong?')

                nts = np.array([[nt1, nt2] for nt1, nt2 in zip(ref_nts, alt_nts)])

                print('Writing to disk')
                cg = oh5f.create_group(chrom_str)
                cg.create_dataset('snps', data=snps)
                cg.create_dataset('snp_means', data=snp_means[:, np.newaxis])
                cg.create_dataset('snp_stds', data=snp_stds[:, np.newaxis])
                cg.create_dataset('snp_ids', data=snp_ids)
                cg.create_dataset('positions', data=positions)
                cg.create_dataset('nts', data=nts)
                oh5f.flush()
                print('Done writing to disk')
    print('Done')
Beispiel #8
0
    def from_gene(self, gene):
        """Add ``gene``'s exons and exon-exon junctions to the graph.

        For each transcript ``i``, ``gene.exons[i]`` is read as rows of
        ``[start, end]``. Vertices are the columns of ``self.vertices``
        (row 0 = start, row 1 = end), and ``self.edges`` is a symmetric 0/1
        matrix.

        * Single-exon transcript: the exon is appended as a new vertex
          without an identity check.
        * Multi-exon transcript: every pair of consecutive exons is linked.
          An exon already present (matched on start and end, last match
          wins) is reused; otherwise it is appended.

        Finally, vertices are sorted by start (edges permuted accordingly).
        ``self.terminals`` is then rebuilt:

        * row 0 flags vertices whose row of ``tril(self.edges)`` sums to 0;
        * row 1 flags vertices whose row of ``triu(self.edges)`` sums to 0.

        Assumes ``self.vertices`` has 2 rows (zero columns when the graph is
        empty), and that ``self.new_edge()`` grows ``self.edges`` by one
        zero row and column (TODO confirm against ``new_edge``).
        """
        import numpy as np

        def _vertex_index(start, end):
            # Column of the last vertex equal to (start, end), or -1 if none.
            hits = np.flatnonzero((self.vertices[0, :] == start)
                                  & (self.vertices[1, :] == end))
            return hits[-1] if hits.size else -1

        def _append_vertex(start, end):
            # Add (start, end) as the new last column and grow the edges.
            self.vertices = np.c_[self.vertices, [start, end]]
            self.new_edge()

        for transcript_idx in range(len(gene.transcripts)):
            exon_start_end = gene.exons[transcript_idx]

            # Only one exon in the transcript.
            if exon_start_end.shape[0] == 1:
                exon1_start = exon_start_end[0, 0]
                exon1_end = exon_start_end[0, 1]
                if self.vertices.shape[1] == 0:
                    self.vertices = np.array([[exon1_start], [exon1_end]], dtype='int')
                    self.edges = np.array([[0]], dtype='int')
                else:
                    _append_vertex(exon1_start, exon1_end)
                continue

            # More than one exon: link every pair of consecutive exons.
            for exon_idx in range(exon_start_end.shape[0] - 1):
                exon1_start = exon_start_end[exon_idx, 0]
                exon1_end = exon_start_end[exon_idx, 1]
                exon2_start = exon_start_end[exon_idx + 1, 0]
                exon2_end = exon_start_end[exon_idx + 1, 1]

                if self.vertices.shape[1] == 0:
                    self.vertices = np.array([[exon1_start, exon2_start],
                                              [exon1_end, exon2_end]], dtype='int')
                    self.edges = np.array([[0, 1], [1, 0]], dtype='int')
                    continue

                exon1_idx = _vertex_index(exon1_start, exon1_end)
                exon2_idx = _vertex_index(exon2_start, exon2_end)

                if exon1_idx != -1 and exon2_idx != -1:
                    # Both exons already occurred: only add the edge.
                    self.edges[exon1_idx, exon2_idx] = 1
                    self.edges[exon2_idx, exon1_idx] = 1
                elif exon2_idx != -1:
                    # Only the 2nd exon occurred: add the 1st and link it.
                    _append_vertex(exon1_start, exon1_end)
                    self.edges[exon2_idx, -1] = 1
                    self.edges[-1, exon2_idx] = 1
                elif exon1_idx != -1:
                    # Only the 1st exon occurred: add the 2nd and link it.
                    _append_vertex(exon2_start, exon2_end)
                    self.edges[exon1_idx, -1] = 1
                    self.edges[-1, exon1_idx] = 1
                else:
                    # Neither exon occurred: add both, then grow the edges.
                    self.vertices = np.c_[self.vertices, [exon1_start, exon1_end]]
                    self.vertices = np.c_[self.vertices, [exon2_start, exon2_end]]
                    self.new_edge()
                    self.new_edge()
                    self.edges[-2, -1] = 1
                    self.edges[-1, -2] = 1

        # Sort by exon start and permute the edge matrix to match.
        order = np.argsort(self.vertices[0, :])
        self.vertices = self.vertices[:, order]
        self.edges = self.edges[order, :][:, order]
        self.terminals = np.zeros(self.vertices.shape, dtype='int')
        self.terminals[0, np.where(np.tril(self.edges).sum(axis=1) == 0)[0]] = 1
        self.terminals[1, np.where(np.triu(self.edges).sum(axis=1) == 0)[0]] = 1
Beispiel #9
0
 def update_terminals(self):
     """Recompute ``self.terminals`` from the current edge matrix.

     ``self.terminals`` is reset to an int array of zeros with the shape of
     ``self.vertices``.

     * Row 0 is set to 1 for every vertex whose row of ``tril(self.edges)``
       sums to zero. With non-negative 0/1 edges, that means the vertex has
       no edge to a vertex at the same or a lower index.
     * Row 1 is set to 1 for every vertex whose row of ``triu(self.edges)``
       sums to zero (no edge to the same or a higher index).
     """
     # NumPy directly: the NumPy aliases in the top-level scipy namespace
     # (sp.zeros, sp.where, sp.tril, ...) are deprecated/removed.
     import numpy as np

     self.terminals = np.zeros(self.vertices.shape, dtype='int')
     no_lower_edge = np.sum(np.tril(self.edges), axis=1) == 0
     no_upper_edge = np.sum(np.triu(self.edges), axis=1) == 0
     self.terminals[0, np.where(no_lower_edge)[0]] = 1
     self.terminals[1, np.where(no_upper_edge)[0]] = 1
Beispiel #10
0
    def newton(self, v=None, maxit=100, tol=EPS, verbose=False,
               gauss_seidel=False):
        """Solve Bellman equations via Newton method (policy iteration).

        Each iteration does the following:

        1. Take values and the greedy policy from ``self.valmax(v)``.
        2. Get the policy's transition matrix ``pstar`` and reward ``fstar``
           from ``self.valpol(x)``.
        3. Form ``Q = pstar * self.discount`` and transform it in place with
           ``eyeminus`` (presumably ``Q <- I - Q``, TODO confirm).
        4. Update the values:

           * default: solve ``Q v = fstar`` exactly;
           * ``gauss_seidel=True``: one Gauss-Seidel step,
             ``v += solve(tril(Q), fstar - Q v)``.

        Iteration stops when the policy from ``valmax`` equals the previous
        one, or after ``maxit`` iterations.

        Parameters
        ----------
        v : array, shape (n, ), optional
            Initial guess for values; zeros if omitted.
        maxit : int, optional
            Maximum number of iterations.
        tol : float, optional
            Accepted for API compatibility. NOTE(review): unused; the
            stopping rule only compares successive policies.
        verbose : bool, optional
            Print the iteration index and ``relres`` every iteration.
        gauss_seidel : bool, optional
            Use a Gauss-Seidel step instead of an exact linear solve.

        Returns
        -------
        info : int
            0 if the policy stopped changing, -1 otherwise.
        t : int
            Number of iterations performed.
        relres : float
            Norm of the last value update (``inf`` if no iteration ran).
        v : array, shape (n, )
            Values after the last update.
        x : array
            Policy from the last ``valmax`` call (zeros if none).
        pstar : array
            Transition matrix from the last ``valpol`` call (``None`` if none).

        Notes
        -----
        Also called policy iteration.
        """
        import numpy as np

        if v is None:
            v = np.zeros(self.n)
        # Initial policy, only used for the first convergence comparison.
        x = np.zeros(self.n)
        info = -1
        t = 0
        # Bound up front so maxit < 1 returns instead of raising
        # UnboundLocalError.
        relres = np.inf
        pstar = None
        for it in range(maxit):
            t += 1
            xold = x.copy()
            v, x = self.valmax(v)
            pstar, fstar, _ = self.valpol(x)
            Q = pstar * self.discount
            eyeminus(Q)  # modifies Q in place; return value unused
            if gauss_seidel:
                # Gauss-Seidel: solve with the lower triangle of Q only.
                L = np.tril(Q)
                dv = la.solve(L, fstar - np.dot(Q, v))
                relres = la.norm(dv)
                v += dv
            else:
                vold = v.copy()
                v = la.solve(Q, fstar)
                relres = la.norm(v - vold)
            if verbose:
                print("%d, %f" % (it, relres))
            if np.all(x == xold):
                info = 0
                break
        return (info, t, relres, v, x, pstar)