def train_batch_fkn(lr, batch_count):
    """Run one single-process training step on the next queued batch.

    Fetches the next batch from ``it_tr_que``, builds the label matrix and
    the weights from the batch labels, and calls ``train_function``.  When
    the module-level ``piggyback`` flag is set, rows of ``tr_ivec_pb``
    selected by the batch's ``U`` are passed to ``train_function`` as well.

    Args:
        lr: Learning rate forwarded to ``train_function``.  It is also
            stored in the module-level ``lr_first``.
        batch_count: Batch number the caller expects; it must equal the
            batch number reported by the queue.

    Returns:
        The total loss ``L`` returned by ``train_function``.

    Raises:
        AssertionError: If ``batch_count`` differs from the batch number
            reported by ``it_tr_que.get_batch()``.
    """
    # We need this so that load and save functions can get it.
    global lr_first
    lr_first = lr

    log.info("Batch count " + str(batch_count))
    [[_, Y, U], _, [tr_feats, tr_idx], it_batch_nb, _] = it_tr_que.get_batch()

    # Debug output of the expected and the actual batch number.
    print("A")
    print(batch_count)
    print(it_batch_nb)
    print("B")
    assert (batch_count == it_batch_nb)

    YY = lab2matrix(Y.squeeze())
    WGT = labMat2weight(YY, P_eff)
    WGT_m = np.ones(Y.shape) / float(Y.shape[0])  # 1/N for every entry

    log.info("Global lr: " + str(lr))

    if piggyback:
        M_pb = tr_ivec_pb[U, :]
        _, _, L, _ = train_function(tr_feats, tr_idx, M_pb, M_pb, WGT, YY,
                                    tau, WGT_m, Y, lr)
    else:
        info, L, _ = train_function(tr_feats, tr_idx, WGT, YY, tau, WGT_m,
                                    Y, lr)

    print("Total loss: " + str(L))
    # ``info`` is only produced by the non-piggyback call; printing it
    # unconditionally would raise NameError in the piggyback case.
    if not piggyback:
        print("Info: " + str(info))
    return L
def check_dev_multi_loss_acc():
    """Evaluate the losses and the multi-class accuracy on the dev set.

    Loads the dev features, runs ``loss_``, ``loss_m_``, ``loss_b_`` and
    ``C_m_`` in ``sess`` with ``is_test_p`` set to True, logs the three
    losses together with the accuracy, and returns the value of ``loss_``.

    Returns:
        The value fetched for ``loss_``.
    """
    dev_x1, dev_c1 = load_feats_dev(dev_scp_info['utt2file'])
    spk_labels = dev_scp_info['utt2spk']

    # Every entry gets weight 1/N, N being the first dimension of the labels.
    n_labels = float(spk_labels.shape[0])
    multi_weights = np.ones(spk_labels.shape) / n_labels

    pair_labels = lab2matrix(spk_labels.squeeze())
    pair_weights = labMat2weight(pair_labels, P_eff)

    fetches = [loss_, loss_m_, loss_b_, C_m_]
    feed = {
        X1_p: dev_x1,
        C1_p: dev_c1,
        WGT_m_p: multi_weights,
        L_m_p: spk_labels,
        WGT_b_p: pair_weights,
        L_b_p: pair_labels,
        is_test_p: True,
    }
    total_loss, multi_loss, binary_loss, class_scores = sess.run(fetches, feed)

    # NOTE(review): the labels are squeezed before lab2matrix but not here;
    # if they are not 1-D this comparison broadcasts -- confirm the shape.
    predicted = np.argmax(class_scores, axis=1)
    n_correct = sum(predicted == spk_labels)
    accuracy = n_correct / float(len(spk_labels))

    print()
    log.info("Loss %f, (Binary, %f, Multi %f), Accuracy: %f", total_loss,
             binary_loss, multi_loss, accuracy)
    return total_loss
def train_batch_fkn_mpi(lr):
    """Run one training step with the work split over the MPI ranks.

    Steps, in order:

    1. Every rank fetches its batch from ``it_tr_que`` and the ranks check
       that batch number and control number agree with the master's.
    2. Every rank computes embeddings with ``feat2embd``; the workers send
       embeddings and labels to the master.
    3. The master runs ``train_function_i2s`` on the gathered data and sends
       each worker the loss and its slice of the gradient ``g``.
    4. Every rank computes ``train_function_grad_f2i``; the master sums the
       results, applies them with ``train_function_f2i`` and sends the
       values of the updated parameters to the workers, which assign them.

    ``time.time()`` / ``time.clock()`` measurements of the individual steps
    are accumulated in module-level counters and their per-batch averages
    are logged at debug level.

    Args:
        lr: Learning rate forwarded to the training functions.  It is also
            stored in the module-level ``lr_first``.

    Returns:
        The loss ``L``: computed on the master, received from the master on
        the workers.

    Raises:
        AssertionError: If the local batch number differs from the queue's,
            or if a worker's batch/control number differs from the master's.
    """
    global batch_count
    # We need this so that load and save functions can get it.
    global lr_first
    # Augmented assignment below would otherwise make this name local and
    # raise UnboundLocalError.
    global all_bad_utts
    global time_all, clock_all
    global time_batch_prep, clock_batch_prep
    global time_embd, clock_embd
    global time_train_dplda, clock_train_dplda
    global time_train_grad_f2i, clock_train_grad_f2i
    global time_train_f2i, clock_train_f2i

    lr_first = lr

    # NOTE(review): time.clock() was removed in Python 3.8; this file is
    # written for Python 2.
    time_b_all = time.time()
    clock_b_all = time.clock()

    batch_count += 1

    time_b_batch_prep = time.time()
    clock_b_batch_prep = time.clock()
    [[_, Y, _], bad_tr_files, [tr_feats, tr_idx], it_batch_nb,
     it_ctrl_nb] = it_tr_que.get_batch()
    time_a_batch_prep = time.time()
    clock_a_batch_prep = time.clock()

    # This is a check that all machines are processing the same batch.
    # And that the first utterance index is the same. So if seed of it_tr
    # differ, it will be detected.
    assert (batch_count == it_batch_nb)
    if is_master:
        log.info("Batch count " + str(batch_count))
        log.info("Global lr: " + str(lr))
        log.debug(" lr b_pool: " + str(lr * lr_scale_b_pool))
        log.debug(" lr a_pool: " + str(lr * lr_scale_a_pool))
        log.debug(" lr dplda: " + str(lr * lr_scale_dplda))
        log.debug(" lr multi: " + str(lr * lr_scale_multi))
        log.debug("Control number: " + str(it_ctrl_nb))
        for i in range(1, mpi_size):
            mpi_comm.send([batch_count, it_ctrl_nb], dest=i)
    else:
        batch_count_master, it_ctrl_nb_master = mpi_comm.recv(source=0)
        log.debug("Batch count " + str(batch_count))
        log.debug("Received batch number from Master: " +
                  str(batch_count_master))
        log.debug("Control number: " + str(it_ctrl_nb))
        log.debug("Received control number from Master: " +
                  str(it_ctrl_nb_master))
        assert (batch_count_master == batch_count)
        assert (it_ctrl_nb_master == it_ctrl_nb)

    all_bad_utts += bad_tr_files

    # Get the embeddings
    time_b_embd = time.time()
    clock_b_embd = time.clock()
    M = feat2embd(tr_feats, tr_idx)
    time_a_embd = time.time()
    clock_a_embd = time.clock()

    # If MPI, all except worker 0, sends their embeddings M to worker 0.
    if not is_master:
        mpi_comm.send([M, Y], dest=0)
    else:
        M_parts = [M]
        Y_parts = [Y]
        # Number of examples per rank, preceded by 0 so that the cumulative
        # sum gives [start_0, start_1, ..., end_last].  Only the first
        # dimension is used so the offsets stay correct when Y is not 1-D.
        job_sizes = [0, Y.shape[0]]
        for i in range(1, mpi_size):
            M_received, Y_received = mpi_comm.recv(source=i)
            log.debug("Received embeddings from worker %d" % i)
            M_parts.append(M_received)
            Y_parts.append(Y_received)
            job_sizes.append(Y_received.shape[0])
        M = np.concatenate(M_parts, axis=0)
        Y = np.concatenate(Y_parts, axis=0)
        job_indices = np.cumsum(job_sizes)

    # Train DPLDA part + get gradient of loss with respect to embeddings
    # This occurs only on worker 0.
    if is_master:
        YY = lab2matrix(Y.squeeze())
        WGT = labMat2weight(YY, P_eff)
        WGT_m = np.ones(Y.shape) / float(Y.shape[0])

        time_b_train_dplda = time.time()
        clock_b_train_dplda = time.clock()
        if piggyback:
            # NOTE(review): m_PB is not assigned in this function; it must
            # exist at module level for this branch to run -- confirm.
            [_, _, L, g] = train_function_i2s(M, m_PB, WGT, YY, tau, WGT_m,
                                              Y, lr)
        else:
            [_, _, L, g] = train_function_i2s(M, WGT, YY, tau, WGT_m, Y, lr)
        time_a_train_dplda = time.time()
        clock_a_train_dplda = time.clock()

        print("Total loss: " + str(L))

        # If MPI, worker 0 sends the gradients of embeddings wrt loss to the
        # other workers
        for i in range(1, mpi_size):
            start_i = int(job_indices[i])
            end_i = int(job_indices[i + 1])
            mpi_comm.send([L, g[start_i:end_i, :]], dest=i)

        # Now reduce g for Master
        g = g[int(job_indices[0]):int(job_indices[1]), :]
    else:
        # Workers receive the above gradients
        [L, g] = mpi_comm.recv(source=0)
        log.debug("Received grads from worker 0 with shape %s" % str(g.shape))

    # Calculate the gradients for the f2i parameters
    time_b_train_grad_f2i = time.time()
    clock_b_train_grad_f2i = time.clock()
    g_f2i = train_function_grad_f2i(tr_feats, tr_idx, g)
    time_a_train_grad_f2i = time.time()
    clock_a_train_grad_f2i = time.clock()

    # If MPI, all except worker 0, sends their f2i gradients, g_f2i, to
    # worker 0, which adds them up element by element.
    if not is_master:
        mpi_comm.send(g_f2i, dest=0)
    else:
        for i in range(1, mpi_size):
            g_f2i_received = mpi_comm.recv(source=i)
            log.debug("Received grad_f2i from worker %d" % i)
            g_f2i = [p1 + p2 for p1, p2 in zip(g_f2i, g_f2i_received)]

    # Even if MPI, this occurs only on worker 0.
    if is_master:
        time_b_train_f2i = time.time()
        clock_b_train_f2i = time.clock()
        train_function_f2i(g_f2i, lr)
        time_a_train_f2i = time.time()
        clock_a_train_f2i = time.clock()

        # If MPI, finally update the f2i part on each worker, except 0
        # where we just updated.  All values are fetched in one run call.
        para = sess.run(params_to_update_b_pool_ + params_to_update_a_pool_)
        for i in range(1, mpi_size):
            mpi_comm.send(para, dest=i)
    else:
        para = mpi_comm.recv(source=0)
        para_ = params_to_update_b_pool_ + params_to_update_a_pool_
        # NOTE(review): tf.assign creates a new op on every call, so the
        # graph grows with each batch; consider building the assign ops once.
        for variable, value in zip(para_, para):
            sess.run(tf.assign(variable, value))
        log.debug("Received model from worker 0")

    time_a_all = time.time()
    clock_a_all = time.clock()

    # Summarize the times
    time_all += time_a_all - time_b_all
    clock_all += clock_a_all - clock_b_all
    time_batch_prep += time_a_batch_prep - time_b_batch_prep
    clock_batch_prep += clock_a_batch_prep - clock_b_batch_prep
    time_embd += time_a_embd - time_b_embd
    clock_embd += clock_a_embd - clock_b_embd
    if is_master:
        time_train_dplda += time_a_train_dplda - time_b_train_dplda
        clock_train_dplda += clock_a_train_dplda - clock_b_train_dplda
    time_train_grad_f2i += time_a_train_grad_f2i - time_b_train_grad_f2i
    clock_train_grad_f2i += clock_a_train_grad_f2i - clock_b_train_grad_f2i
    if is_master:
        time_train_f2i += time_a_train_f2i - time_b_train_f2i
        clock_train_f2i += clock_a_train_f2i - clock_b_train_f2i

    n_batches = float(batch_count)
    log.debug("Average times so far: ")
    log.debug(" All steps time: " + str(time_all / n_batches) + "s")
    log.debug(" All steps clock: " + str(clock_all / n_batches) + "s")
    log.debug(" Batch preparation time: " +
              str(time_batch_prep / n_batches) + "s")
    log.debug(" Batch preparation clock: " +
              str(clock_batch_prep / n_batches) + "s")
    log.debug(" Extr. embd. time: " + str(time_embd / n_batches) + "s")
    log.debug(" Extr. embd clock: " + str(clock_embd / n_batches) + "s")
    if is_master:
        log.debug(" Train DPLDA time: " +
                  str(time_train_dplda / n_batches) + "s")
        log.debug(" Train DPLDA clock: " +
                  str(clock_train_dplda / n_batches) + "s")
    log.debug(" Calculate grad f2i time: " +
              str(time_train_grad_f2i / n_batches) + "s")
    log.debug(" Calculate grad f2i clock: " +
              str(clock_train_grad_f2i / n_batches) + "s")
    if is_master:
        log.debug(" Train (given grad f2i) f2i time: " +
                  str(time_train_f2i / n_batches) + "s")
        log.debug(" Train (given grad f2i) f2i clock: " +
                  str(clock_train_f2i / n_batches) + "s")

    return L