def finalize_splits(nz, n_splits, splitted, Dts, Trace, nh, ns, kernel):
    '''
    Rebuild sampler state after `n_splits` new topics have been created.

    The topic column of `Trace` (its last column) is overwritten in place with
    `splitted`. All count tables and per-topic stamp lists are then rebuilt
    from scratch for ``nz + n_splits`` topics.

    Parameters
    ----------
    nz : int
        Number of topics before the split.
    n_splits : int
        Number of topics added by the split.
    splitted : array-like
        New topic label for every row of `Trace`.
    Dts : array
        Row-aligned with `Trace` (both are indexed by the same boolean mask).
        Its last column is what gets collected into the stamp lists.
    Trace : array
        Matrix whose last column holds topic labels; mutated in place.
    nh, ns : int
        Sizes of the `h` and `s` axes of the count tables.
    kernel : object
        When ``kernel.get_priors()`` is non-empty, the returned state contains
        the existing rows of ``kernel.get_state()`` followed by one copy of the
        priors per new topic. Otherwise the state is returned unchanged.

    Returns
    -------
    tuple
        ``(Trace, Count_zh, Count_sz, count_z, stamps, P)`` where `P` is the
        new kernel state as a numpy array. The per-`h` totals are computed
        by ``fast_populate`` but are not part of the returned tuple.
    '''
    new_nz = nz + n_splits

    if kernel.get_priors().shape[0] > 0:
        # Existing topics keep their state rows; every freshly split topic
        # starts from the kernel priors.
        new_P = list(kernel.get_state())
        new_P.extend(kernel.get_priors() for _ in xrange(n_splits))
    else:
        new_P = kernel.get_state()

    Trace[:, -1] = splitted
    labels = Trace[:, -1]  # view on the (just updated) topic column

    # Fresh count tables sized for the enlarged topic set.
    Count_zh_new = np.zeros((new_nz, nh), dtype='i4')
    Count_sz_new = np.zeros((ns, new_nz), dtype='i4')
    count_z_new = np.zeros(new_nz, dtype='i4')
    count_h_new = np.zeros(nh, dtype='i4')
    _learn.fast_populate(Trace, Count_zh_new, Count_sz_new,
                         count_h_new, count_z_new)

    # Regroup the last column of Dts by topic label.
    new_stamps = StampLists(new_nz)
    for topic in xrange(new_nz):
        new_stamps._extend(topic, Dts[labels == topic][:, -1])

    return (Trace, Count_zh_new, Count_sz_new, count_z_new, new_stamps,
            np.array(new_P))
def finalize_merge(nz, to_merge, Dts, Trace, nh, ns, kernel):
    '''
    Rebuild sampler state after merging pairs of topics.

    Each ``(z1, z2)`` pair in `to_merge` is applied in order: every row of
    `Trace` whose topic label (last column) equals `z2` is relabelled `z1`.
    When `to_merge` is non-empty:

    * if ``kernel.get_priors()`` is non-empty, the kernel state rows of all
      absorbed topics (`z2`) are dropped; survivors keep ascending id order;
    * the remaining labels are compacted to ``0 .. new_nz - 1`` in ascending
      order, where `new_nz` is the number of distinct labels left in `Trace`.

    When `to_merge` is empty, labels are untouched and ``new_nz == nz``.
    Count tables and per-topic stamp lists are rebuilt from scratch.

    Parameters
    ----------
    nz : int
        Number of topics before merging.
    to_merge : list of (int, int)
        ``(kept_topic, absorbed_topic)`` pairs.
    Dts : array
        Row-aligned with `Trace`. Its last column is collected into the stamp
        lists.
    Trace : array
        Matrix whose last column holds topic labels; mutated in place.
    nh, ns : int
        Sizes of the `h` and `s` axes of the count tables.
    kernel : object
        Provides ``get_priors()`` and ``get_state()``.

    Returns
    -------
    tuple
        ``(Trace, Count_zh, Count_sz, count_z, stamps, P)`` where `P` is the
        new kernel state as a numpy array. The per-`h` totals are computed
        by ``fast_populate`` but are not part of the returned tuple.
    '''
    for z1, z2 in to_merge:
        Trace[:, -1][Trace[:, -1] == z2] = z1

    if to_merge and kernel.get_priors().shape[0] > 0:
        # Keep the state rows of surviving topics in ascending id order.
        absorbed = set(z2 for _, z2 in to_merge)
        new_P = [row for i, row in enumerate(kernel.get_state())
                 if i not in absorbed]
    else:
        new_P = kernel.get_state()

    # Make sure the new trace has contiguous ids. np.unique returns the
    # distinct labels *sorted*, so surviving topics keep their relative order
    # and line up with the rows of new_P built above. (The previous version
    # iterated a Python set, whose order is not guaranteed to be ascending.)
    # NOTE(review): a surviving topic with no rows in Trace disappears here,
    # so new_nz can be smaller than len(new_P); confirm empty topics cannot
    # reach this point.
    if to_merge:
        present, new_assign = np.unique(Trace[:, -1], return_inverse=True)
        new_nz = len(present)
        Trace[:, -1] = new_assign
    else:
        new_nz = nz

    # Populate new counts
    Count_zh_new = np.zeros(shape=(new_nz, nh), dtype='i4')
    Count_sz_new = np.zeros(shape=(ns, new_nz), dtype='i4')
    count_z_new = np.zeros(new_nz, dtype='i4')
    count_h_new = np.zeros(nh, dtype='i4')
    _learn.fast_populate(Trace, Count_zh_new, Count_sz_new,
                         count_h_new, count_z_new)

    # Regroup the last column of Dts by (compacted) topic label.
    new_stamps = StampLists(new_nz)
    for z in xrange(new_nz):
        idx = Trace[:, -1] == z
        new_stamps._extend(z, Dts[idx][:, -1])

    return Trace, Count_zh_new, Count_sz_new, \
        count_z_new, new_stamps, np.array(new_P)
def work():
    """
    For MPI Slave.

    Event loop run on a worker rank. Messages are received from MASTER and
    dispatched on their MPI tag:

    * ``Msg.LEARN``: the payload is the number of sampling iterations. The
      worker acknowledges with ``Msg.STARTED``, receives its workload via
      ``receive_workload``, rebuilds its counts with ``fast_populate``, runs
      ``sample`` and then signals ``Msg.FINISHED``.
    * ``Msg.SENDRESULTS``: sends the topic column of the trace, the count
      tables and the kernel state from the most recent LEARN round to MASTER.
    * ``Msg.STOP``: leaves the loop.

    Any other tag is printed to stdout and the loop continues.

    NOTE(review): a second top-level ``def work()`` immediately follows this
    one in this file and rebinds the name, so this copy appears to be
    shadowed; confirm and keep only one.

    Raises
    ------
    RuntimeError
        If SENDRESULTS arrives before any LEARN workload has been received.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    # Filled in by the most recent LEARN message; read by SENDRESULTS.
    Trace = Count_zh = Count_sz = count_h = count_z = kernel = None

    while True:
        status = MPI.Status()
        msg = comm.recv(source=MASTER, tag=MPI.ANY_TAG, status=status)
        event = status.Get_tag()

        if event == Msg.LEARN.value:
            # Fire-and-forget: the Request returned by isend is not kept.
            comm.isend(rank, dest=MASTER, tag=Msg.STARTED.value)

            num_iter = msg
            Dts, Trace, Count_zh, Count_sz, count_h, count_z, \
                alpha_zh, beta_zs, kernel = receive_workload(comm)
            fast_populate(Trace, Count_zh, Count_sz, count_h, count_z)
            sample(Dts, Trace, Count_zh, Count_sz, count_h,
                   count_z, alpha_zh, beta_zs, kernel, num_iter, comm)

            comm.isend(rank, dest=MASTER, tag=Msg.FINISHED.value)
        elif event == Msg.SENDRESULTS.value:
            if Trace is None:
                raise RuntimeError('SENDRESULTS received before any LEARN')
            comm.Send([np.array(Trace[:, -1], order='C'), MPI.INT],
                      dest=MASTER)
            comm.Send([Count_zh, MPI.INT], dest=MASTER)
            comm.Send([Count_sz, MPI.INT], dest=MASTER)
            comm.Send([count_h, MPI.INT], dest=MASTER)
            comm.Send([count_z, MPI.INT], dest=MASTER)
            comm.Send([kernel.get_state(), MPI.DOUBLE], dest=MASTER)
        elif event == Msg.STOP.value:
            break
        else:
            # Msg(event) raises ValueError for tags that are not Msg values;
            # the diagnostic must not take the worker down.
            try:
                event_name = Msg(event)
            except ValueError:
                event_name = None
            print('Unknown message received', msg, event, event_name)
def work():
    """
    For MPI Slave.

    Event loop run on a worker rank. Messages are received from MASTER and
    dispatched on their MPI tag:

    * ``Msg.LEARN``: the payload is the number of sampling iterations. The
      worker acknowledges with ``Msg.STARTED``, receives its workload via
      ``receive_workload``, rebuilds its counts with ``fast_populate``, runs
      ``sample`` and then signals ``Msg.FINISHED``.
    * ``Msg.SENDRESULTS``: sends the topic column of the trace, the count
      tables and the kernel state from the most recent LEARN round to MASTER.
    * ``Msg.STOP``: leaves the loop.

    Any other tag is printed to stdout and the loop continues.

    Raises
    ------
    RuntimeError
        If SENDRESULTS arrives before any LEARN workload has been received.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    # Filled in by the most recent LEARN message; read by SENDRESULTS.
    Trace = Count_zh = Count_sz = count_h = count_z = kernel = None

    while True:
        status = MPI.Status()
        msg = comm.recv(source=MASTER, tag=MPI.ANY_TAG, status=status)
        event = status.Get_tag()

        if event == Msg.LEARN.value:
            # Fire-and-forget: the Request returned by isend is not kept.
            comm.isend(rank, dest=MASTER, tag=Msg.STARTED.value)

            num_iter = msg
            Dts, Trace, Count_zh, Count_sz, count_h, count_z, \
                alpha_zh, beta_zs, kernel = receive_workload(comm)
            fast_populate(Trace, Count_zh, Count_sz, count_h, count_z)
            sample(Dts, Trace, Count_zh, Count_sz, count_h,
                   count_z, alpha_zh, beta_zs, kernel, num_iter, comm)

            comm.isend(rank, dest=MASTER, tag=Msg.FINISHED.value)
        elif event == Msg.SENDRESULTS.value:
            if Trace is None:
                raise RuntimeError('SENDRESULTS received before any LEARN')
            comm.Send([np.array(Trace[:, -1], order='C'), MPI.INT],
                      dest=MASTER)
            comm.Send([Count_zh, MPI.INT], dest=MASTER)
            comm.Send([Count_sz, MPI.INT], dest=MASTER)
            comm.Send([count_h, MPI.INT], dest=MASTER)
            comm.Send([count_z, MPI.INT], dest=MASTER)
            comm.Send([kernel.get_state(), MPI.DOUBLE], dest=MASTER)
        elif event == Msg.STOP.value:
            break
        else:
            # Msg(event) raises ValueError for tags that are not Msg values;
            # the diagnostic must not take the worker down.
            try:
                event_name = Msg(event)
            except ValueError:
                event_name = None
            print('Unknown message received', msg, event, event_name)