示例#1
0
def finalize_splits(nz, n_splits, splitted, Dts, Trace, nh, ns, kernel):
    """Rebuild counts, stamp lists and kernel state after a split step.

    The last column of ``Trace`` is overwritten in place with ``splitted``
    and every count structure is recomputed from the updated trace.

    Parameters
    ----------
    nz : int
        Number of ``z`` ids before the split.
    n_splits : int
        Number of ids added by the split; the new total is
        ``nz + n_splits``.
    splitted : array-like
        New assignments, written to ``Trace[:, -1]``.
    Dts : 2-D array
        Indexed with the same row mask as ``Trace``; its last column is
        collected per id into the returned stamp lists.
    Trace : 2-D array
        Modified in place (last column only).
    nh, ns : int
        Sizes used for the second axis of ``Count_zh`` and the first axis
        of ``Count_sz`` respectively.
    kernel : object
        Must provide ``get_priors()`` and ``get_state()``.

    Returns
    -------
    tuple
        ``(Trace, Count_zh, Count_sz, count_z, stamps, P)``. ``P`` is the
        kernel state with one copy of the priors appended per split when
        the kernel has priors, otherwise the kernel state unchanged. The
        per-``h`` counts are computed but, as before, not returned.
    """
    new_nz = nz + n_splits

    if kernel.get_priors().shape[0] > 0:
        # Existing rows first, then one prior row for every new id.
        new_P = list(kernel.get_state())
        # ``range`` instead of ``xrange`` so this also runs on Python 3.
        for _ in range(n_splits):
            new_P.append(kernel.get_priors())
    else:
        new_P = kernel.get_state()

    Trace[:, -1] = splitted

    # Populate new counts
    Count_zh_new = np.zeros(shape=(new_nz, nh), dtype='i4')
    Count_sz_new = np.zeros(shape=(ns, new_nz), dtype='i4')
    count_z_new = np.zeros(new_nz, dtype='i4')
    count_h_new = np.zeros(nh, dtype='i4')

    _learn.fast_populate(Trace, Count_zh_new, Count_sz_new,
                         count_h_new, count_z_new)

    # Regroup the last column of Dts by the id each row is now assigned to.
    new_stamps = StampLists(new_nz)
    for z in range(new_nz):
        idx = Trace[:, -1] == z
        topic_stamps = Dts[idx]
        new_stamps._extend(z, topic_stamps[:, -1])

    return (Trace, Count_zh_new, Count_sz_new,
            count_z_new, new_stamps, np.array(new_P))
示例#2
0
def finalize_merge(nz, to_merge, Dts, Trace, nh, ns, kernel):
    """Apply merges to the trace and rebuild counts, stamps and kernel state.

    For each ``(z1, z2)`` pair every row of ``Trace`` assigned to ``z2`` is
    reassigned to ``z1``. The remaining ids are then renumbered to
    ``0..new_nz-1`` and all count structures are recomputed.

    Parameters
    ----------
    nz : int
        Number of ``z`` ids before merging; used as the new total only
        when ``to_merge`` is empty.
    to_merge : sequence of (z1, z2) pairs
        ``z2`` is absorbed into ``z1``. A ``z2`` that is not a row index of
        the kernel state (or that appears twice) raises ``KeyError`` when
        the kernel has priors, as before.
    Dts : 2-D array
        Indexed with the same row mask as ``Trace``; its last column is
        collected per id into the returned stamp lists.
    Trace : 2-D array
        Modified in place (last column only).
    nh, ns : int
        Sizes used for the second axis of ``Count_zh`` and the first axis
        of ``Count_sz`` respectively.
    kernel : object
        Must provide ``get_priors()`` and ``get_state()``.

    Returns
    -------
    tuple
        ``(Trace, Count_zh, Count_sz, count_z, stamps, P)``. The per-``h``
        counts are computed but not returned.
    """
    for z1, z2 in to_merge:
        idx = Trace[:, -1] == z2
        Trace[:, -1][idx] = z1

    if to_merge and kernel.get_priors().shape[0] > 0:
        # Drop the state row of every absorbed id; the survivors are kept
        # in ascending order of their old id.
        rows_by_z = dict(enumerate(kernel.get_state()))
        for _, z2 in to_merge:
            del rows_by_z[z2]
        new_P = [rows_by_z[z] for z in sorted(rows_by_z)]
    else:
        new_P = kernel.get_state()

    # Make sure new trace has contiguous ids.
    if to_merge:
        # np.unique returns the surviving ids sorted ascending, and the
        # inverse maps every row to its rank among them. The renumbering
        # must be ascending so that new id ``i`` lines up with row ``i`` of
        # ``new_P``; iterating a plain ``set`` gives no such guarantee.
        # NOTE(review): an id with no rows left in Trace is dropped here
        # but its row stays in new_P unless it was a ``z2`` -- confirm
        # callers never rely on that case.
        surviving, new_assign = np.unique(Trace[:, -1], return_inverse=True)
        new_nz = len(surviving)
        Trace[:, -1] = new_assign
    else:
        new_nz = nz

    # Populate new counts
    Count_zh_new = np.zeros(shape=(new_nz, nh), dtype='i4')
    Count_sz_new = np.zeros(shape=(ns, new_nz), dtype='i4')
    count_z_new = np.zeros(new_nz, dtype='i4')
    count_h_new = np.zeros(nh, dtype='i4')

    _learn.fast_populate(Trace, Count_zh_new, Count_sz_new,
                         count_h_new, count_z_new)

    # Regroup the last column of Dts by the id each row is now assigned to.
    # ``range`` instead of ``xrange`` so this also runs on Python 3.
    new_stamps = StampLists(new_nz)
    for z in range(new_nz):
        idx = Trace[:, -1] == z
        topic_stamps = Dts[idx]
        new_stamps._extend(z, topic_stamps[:, -1])

    return (Trace, Count_zh_new, Count_sz_new,
            count_z_new, new_stamps, np.array(new_P))
示例#3
0
def work():
    """
    For MPI Slave

    Blocks on messages from ``MASTER`` and dispatches on the message tag:

    * ``Msg.LEARN``: the payload is the number of iterations. The worker
      reports ``STARTED``, receives its workload, populates the counts,
      runs ``sample`` and reports ``FINISHED``.
    * ``Msg.SENDRESULTS``: sends the last trace column, the four count
      arrays and the kernel state back to ``MASTER``.
    * ``Msg.STOP``: leaves the loop.
    * anything else is printed and ignored.

    Raises
    ------
    RuntimeError
        If ``SENDRESULTS`` arrives before any ``LEARN`` request completed
        (there is nothing to send yet).
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    #pr = cProfile.Profile()
    #pr.enable()

    # Everything a SENDRESULTS request needs; None until a LEARN finishes.
    learned = None

    while True:
        status = MPI.Status()
        msg = comm.recv(source=MASTER, tag=MPI.ANY_TAG, status=status)
        event = status.Get_tag()

        if event == Msg.LEARN.value:
            # NOTE(review): the request returned by isend is discarded, as
            # in the original code -- confirm this is intended.
            comm.isend(rank, dest=MASTER, tag=Msg.STARTED.value)

            num_iter = msg

            (Dts, Trace, Count_zh, Count_sz, count_h, count_z,
             alpha_zh, beta_zs, kernel) = receive_workload(comm)
            fast_populate(Trace, Count_zh, Count_sz, count_h, count_z)
            sample(Dts, Trace, Count_zh, Count_sz, count_h,
                   count_z, alpha_zh, beta_zs, kernel, num_iter,
                   comm)
            learned = (Trace, Count_zh, Count_sz, count_h, count_z, kernel)

            comm.isend(rank, dest=MASTER, tag=Msg.FINISHED.value)
        elif event == Msg.SENDRESULTS.value:
            if learned is None:
                # Fail with a clear message instead of an UnboundLocalError.
                raise RuntimeError(
                    'SENDRESULTS received before any LEARN request completed')
            Trace, Count_zh, Count_sz, count_h, count_z, kernel = learned

            comm.Send([np.array(Trace[:, -1], order='C'), MPI.INT], dest=MASTER)
            comm.Send([Count_zh, MPI.INT], dest=MASTER)
            comm.Send([Count_sz, MPI.INT], dest=MASTER)
            comm.Send([count_h, MPI.INT], dest=MASTER)
            comm.Send([count_z, MPI.INT], dest=MASTER)
            comm.Send([kernel.get_state(), MPI.DOUBLE], dest=MASTER)
        elif event == Msg.STOP.value:
            break
        else:
            # Presumably Msg is an Enum (members are read through .value):
            # looking up a tag that is not a member raises ValueError, which
            # must not take the worker down while it is only trying to log.
            try:
                label = Msg(event)
            except ValueError:
                label = '<not a Msg member>'
            print('Unknown message received', msg, event, label)
示例#4
0
def work():
    """Run the MPI worker loop until ``MASTER`` sends a STOP message.

    Each message received from ``MASTER`` is dispatched on its tag:

    * ``Msg.LEARN``: the payload is the number of iterations. The worker
      reports ``STARTED``, receives its workload, populates the counts,
      runs ``sample`` and reports ``FINISHED``.
    * ``Msg.SENDRESULTS``: sends the last trace column, the four count
      arrays and the kernel state back to ``MASTER``.
    * ``Msg.STOP``: leaves the loop.
    * anything else is printed and ignored.

    Raises
    ------
    RuntimeError
        If ``SENDRESULTS`` arrives before any ``LEARN`` request completed
        (there is nothing to send yet).
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    #pr = cProfile.Profile()
    #pr.enable()

    # Results of the most recent LEARN request; None until one finishes.
    last_run = None

    while True:
        status = MPI.Status()
        msg = comm.recv(source=MASTER, tag=MPI.ANY_TAG, status=status)
        event = status.Get_tag()

        if event == Msg.LEARN.value:
            # NOTE(review): the request returned by isend is discarded, as
            # in the original code -- confirm this is intended.
            comm.isend(rank, dest=MASTER, tag=Msg.STARTED.value)

            num_iter = msg

            (Dts, Trace, Count_zh, Count_sz, count_h, count_z,
             alpha_zh, beta_zs, kernel) = receive_workload(comm)
            fast_populate(Trace, Count_zh, Count_sz, count_h, count_z)
            sample(Dts, Trace, Count_zh, Count_sz, count_h,
                   count_z, alpha_zh, beta_zs, kernel, num_iter,
                   comm)
            last_run = (Trace, Count_zh, Count_sz, count_h, count_z, kernel)

            comm.isend(rank, dest=MASTER, tag=Msg.FINISHED.value)
        elif event == Msg.SENDRESULTS.value:
            if last_run is None:
                # Fail with a clear message instead of an UnboundLocalError.
                raise RuntimeError(
                    'SENDRESULTS received before any LEARN request completed')
            Trace, Count_zh, Count_sz, count_h, count_z, kernel = last_run

            comm.Send([np.array(Trace[:, -1], order='C'), MPI.INT], dest=MASTER)
            comm.Send([Count_zh, MPI.INT], dest=MASTER)
            comm.Send([Count_sz, MPI.INT], dest=MASTER)
            comm.Send([count_h, MPI.INT], dest=MASTER)
            comm.Send([count_z, MPI.INT], dest=MASTER)
            comm.Send([kernel.get_state(), MPI.DOUBLE], dest=MASTER)
        elif event == Msg.STOP.value:
            break
        else:
            # Presumably Msg is an Enum (members are read through .value):
            # looking up a tag that is not a member raises ValueError, which
            # must not take the worker down while it is only trying to log.
            try:
                label = Msg(event)
            except ValueError:
                label = '<not a Msg member>'
            print('Unknown message received', msg, event, label)