Code Example #1
File: train.py  Project: gulamungon/SEQUENS
    def train_batch_fkn(lr, batch_count):
        #global batch_count
        global lr_first  # We need this so that load and save functions can
        lr_first = lr  # get it

        #batch_count += 1
        log.info("Batch count " + str(batch_count))

        [[X, Y, U], bad_tr_files, [tr_feats, tr_idx], it_batch_nb,
         it_ctrl_nb] = it_tr_que.get_batch()
        print "A"
        print batch_count
        print it_batch_nb
        print "B"
        assert (batch_count == it_batch_nb)
        #all_bad_utts += bad_tr_files

        YY = lab2matrix(Y.squeeze())
        WGT = labMat2weight(YY, P_eff)
        WGT_m = np.ones(Y.shape) / float(Y.shape[0])

        # Adjust learn rate if the desired one has changed
        log.info("Global lr: " + str(lr))
        #log.debug( " lr  b_pool: " + str( lr * lr_scale_b_pool ) )
        #log.debug( " lr  a_pool: " + str( lr * lr_scale_a_pool ) )
        #log.debug( " lr  dplda: " + str( lr * lr_scale_dplda ) )
        #log.debug( " lr  multi: " + str( lr * lr_scale_multi ) )

        if (piggyback):
            M_pb = tr_ivec_pb[U, :]
            L_multi, L_dplda, L, _ = train_function(tr_feats, tr_idx, M_pb,
                                                    M_pb, WGT, YY, tau, WGT_m,
                                                    Y, lr)
        else:
            info, L, _ = train_function(tr_feats, tr_idx, WGT, YY, tau, WGT_m,
                                        Y, lr)
        print "Total loss: " + str(L)
        print "Info: " + str(info)
        #print "DPLDA loss: " + str(L_dplda)
        #print "grads: " + str(grads)
        return L
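The helpers lab2matrix and labMat2weight used above are defined elsewhere in the SEQUENS project and are not shown on this page. Below is a minimal, hypothetical sketch of what they are assumed to do (the real implementations may differ): lab2matrix is assumed to expand a label vector into a binary same-speaker target matrix, and labMat2weight to weight target and non-target trials according to the effective prior P_eff.

import numpy as np

# Hypothetical stand-ins for the project's helpers; sketch only.
def lab2matrix(labels):
    # Entry (i, j) is 1.0 if utterances i and j share a speaker label, else 0.0.
    labels = np.asarray(labels).squeeze()
    return (labels[:, None] == labels[None, :]).astype(np.float64)

def labMat2weight(target_matrix, p_eff):
    # Scale target and non-target trials so their total weights sum to
    # p_eff and 1 - p_eff respectively.
    n_tar = target_matrix.sum()
    n_non = target_matrix.size - n_tar
    return np.where(target_matrix > 0,
                    p_eff / n_tar,
                    (1.0 - p_eff) / n_non)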
Code Example #2
    def check_dev_multi_loss_acc():
        X1, C1 = load_feats_dev(dev_scp_info['utt2file'])
        L_m = dev_scp_info['utt2spk']
        WGT_m = np.ones(L_m.shape) / float(L_m.shape[0])

        L_b = lab2matrix(L_m.squeeze())
        WGT_b = labMat2weight(L_b, P_eff)

        l, l_m, l_b, C = sess.run(
            [loss_, loss_m_, loss_b_, C_m_], {
                X1_p: X1,
                C1_p: C1,
                WGT_m_p: WGT_m,
                L_m_p: L_m,
                WGT_b_p: WGT_b,
                L_b_p: L_b,
                is_test_p: True
            })
        P = np.argmax(C, axis=1)
        Acc = sum(P == L_m) / float(len(L_m))
        print()
        log.info("Loss %f, (Binary, %f, Multi %f), Accuracy: %f", l, l_b, l_m,
                 Acc)
        return l
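check_dev_multi_loss_acc returns the development loss, which the surrounding training script can use to drive the learning-rate schedule (compare the "Adjust learn rate if the desired one has changed" note in Example #1). Below is a minimal, hypothetical driver sketch assuming a simple halve-on-plateau schedule; the project's actual outer loop is not shown on this page.

def train_with_lr_halving(n_batches, check_every, lr_init, min_lr=1e-5):
    # Hypothetical outer loop: train batch by batch, check the dev loss
    # periodically, and halve the learning rate when it stops improving.
    lr = lr_init
    best_dev_loss = float("inf")
    for batch_count in range(1, n_batches + 1):
        train_batch_fkn(lr, batch_count)
        if batch_count % check_every == 0:
            dev_loss = check_dev_multi_loss_acc()
            if dev_loss < best_dev_loss:
                best_dev_loss = dev_loss
            else:
                lr = max(lr * 0.5, min_lr)
    return lr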
Code Example #3
File: train.py  Project: gulamungon/SEQUENS
    def train_batch_fkn_mpi(lr):

        global batch_count
        global all_bad_utts  # needed: this function appends to it below
        global lr_first  # We need this so that load and save functions can
        lr_first = lr  # get it

        global time_all
        global clock_all
        global time_batch_prep
        global clock_batch_prep
        global time_embd
        global clock_embd
        global time_train_dplda
        global clock_train_dplda
        global time_train_grad_f2i
        global clock_train_grad_f2i
        global time_train_f2i
        global clock_train_f2i

        time_b_all = time.time()
        clock_b_all = time.clock()

        batch_count += 1

        time_b_batch_prep = time.time()
        clock_b_batch_prep = time.clock()
        [[X, Y, U], bad_tr_files, [tr_feats, tr_idx], it_batch_nb,
         it_ctrl_nb] = it_tr_que.get_batch()
        time_a_batch_prep = time.time()
        clock_a_batch_prep = time.clock()

        # Check that all machines are processing the same batch and that the
        # first utterance index is the same, so a mismatch in the seed of
        # it_tr will be detected.
        assert (batch_count == it_batch_nb)
        if (is_master):
            log.info("Batch count " + str(batch_count))
            log.info("Global lr: " + str(lr))
            log.debug(" lr  b_pool: " + str(lr * lr_scale_b_pool))
            log.debug(" lr  a_pool: " + str(lr * lr_scale_a_pool))
            log.debug(" lr  dplda: " + str(lr * lr_scale_dplda))
            log.debug(" lr  multi: " + str(lr * lr_scale_multi))

            log.debug("Control number: " + str(it_ctrl_nb))
            for i in range(1, mpi_size):
                mpi_comm.send([batch_count, it_ctrl_nb], dest=i)
        else:
            batch_count_master, it_ctrl_nb_master = mpi_comm.recv(source=0)
            log.debug("Batch count " + str(batch_count))
            log.debug("Received batch number from Master: " +
                      str(batch_count_master))
            log.debug("Control number: " + str(it_ctrl_nb))
            log.debug("Received control number from Master: " +
                      str(it_ctrl_nb_master))
            assert (batch_count_master == batch_count)
            assert (it_ctrl_nb_master == it_ctrl_nb)

        all_bad_utts += bad_tr_files

        # Get the embeddings
        time_b_embd = time.time()
        clock_b_embd = time.clock()
        M = feat2embd(tr_feats, tr_idx)
        time_a_embd = time.time()
        clock_a_embd = time.clock()

        # If MPI, all workers except worker 0 send their embeddings M to worker 0.
        if (not is_master):
            mpi_comm.send([M, Y], dest=0)
        else:
            job_indices = np.array([0, Y.shape[0]])
            for i in range(1, mpi_size):
                [M_received, Y_received] = mpi_comm.recv(source=i)
                log.debug("Received embeddings from worker %d" % i)

                M = np.concatenate((M, M_received), axis=0)
                Y = np.concatenate((Y, Y_received), axis=0)
                job_indices = np.concatenate((job_indices, Y_received.shape))

            job_indices = np.cumsum(job_indices)
            YY = lab2matrix(Y.squeeze())
            WGT = labMat2weight(YY, P_eff)
            WGT_m = np.ones(Y.shape) / float(Y.shape[0])

        # Train DPLDA part + get gradient of loss with respect to embeddings
        # This occurs only on worker 0.
        if (is_master):

            time_b_train_dplda = time.time()
            clock_b_train_dplda = time.clock()
            if (piggyback):
                [L_multi, L_dplda, L,
                 g] = train_function_i2s(M, m_PB, WGT, YY, tau, WGT_m, Y, lr)
            else:
                [L_multi, L_dplda, L,
                 g] = train_function_i2s(M, WGT, YY, tau, WGT_m, Y, lr)
            time_a_train_dplda = time.time()
            clock_a_train_dplda = time.clock()

            print "Total loss: " + str(L)

            # If MPI, worker 0 sends the gradients of the loss w.r.t. the embeddings to the other workers
            for i in range(1, mpi_size):
                start_i = int(job_indices[i])
                end_i = int(job_indices[i + 1])
                mpi_comm.send([L, g[start_i:end_i, :]], dest=i)

            # Now reduce g for Master
            start_i = int(job_indices[0])
            end_i = int(job_indices[1])
            g = g[start_i:end_i, :]
        else:
            # Slaves receive the above gradients
            [L, g] = mpi_comm.recv(source=0)
            log.debug("Received grads from worker 0 with shape %s" %
                      str(g.shape))

        # Calculate the gradients for the f2i parameters
        time_b_train_grad_f2i = time.time()
        clock_b_train_grad_f2i = time.clock()
        g_f2i = train_function_grad_f2i(tr_feats, tr_idx, g)
        time_a_train_grad_f2i = time.time()
        clock_a_train_grad_f2i = time.clock()

        # If MPI, all workers except worker 0 send their f2i gradients, g_f2i, to worker 0.
        if (not is_master):
            mpi_comm.send(g_f2i, dest=0)
        else:
            for i in range(1, mpi_size):
                g_f2i_received = mpi_comm.recv(source=i)
                log.debug("Received grad_f2i from worker %d" % i)

                g_f2i = [p1 + p2 for p1, p2 in zip(g_f2i, g_f2i_received)
                         ]  # g_f2i + g_f2i_received

        # Even if MPI, this occurs only on worker 0.
        if (is_master):
            time_b_train_f2i = time.time()
            clock_b_train_f2i = time.clock()
            L_f2i = train_function_f2i(g_f2i, lr)
            time_a_train_f2i = time.time()
            clock_a_train_f2i = time.clock()

            # If MPI, finally update the f2i part on each worker, except worker 0 where we just updated.
            ####        if ( is_master ):
            para = []
            for p in params_to_update_b_pool_ + params_to_update_a_pool_:
                para.append(sess.run(p))

            for i in range(1, mpi_size):
                mpi_comm.send(para, dest=i)
        else:
            para = mpi_comm.recv(source=0)
            para_ = params_to_update_b_pool_ + params_to_update_a_pool_
            for i in range(len(para)):
                sess.run(tf.assign(para_[i], para[i]))
            log.debug("Received model from worker 0")

        time_a_all = time.time()
        clock_a_all = time.clock()
        # Summarize the times
        time_all += time_a_all - time_b_all
        clock_all += clock_a_all - clock_b_all
        time_batch_prep += time_a_batch_prep - time_b_batch_prep
        clock_batch_prep += clock_a_batch_prep - clock_b_batch_prep
        time_embd += time_a_embd - time_b_embd
        clock_embd += clock_a_embd - clock_b_embd
        if is_master:
            time_train_dplda += time_a_train_dplda - time_b_train_dplda
            clock_train_dplda += clock_a_train_dplda - clock_b_train_dplda
        time_train_grad_f2i += time_a_train_grad_f2i - time_b_train_grad_f2i
        clock_train_grad_f2i += clock_a_train_grad_f2i - clock_b_train_grad_f2i
        if is_master:
            time_train_f2i += time_a_train_f2i - time_b_train_f2i
            clock_train_f2i += clock_a_train_f2i - clock_b_train_f2i

        log.debug("Average times so far: ")
        log.debug(" All steps time:                   " +
                  str(time_all / float(batch_count)) + "s")
        log.debug(" All steps clock:                  " +
                  str(clock_all / float(batch_count)) + "s")
        log.debug(" Batch preparation time:           " +
                  str(time_batch_prep / float(batch_count)) + "s")
        log.debug(" Batch preparation clock:          " +
                  str(clock_batch_prep / float(batch_count)) + "s")
        log.debug(" Extr. embd. time:                 " +
                  str(time_embd / float(batch_count)) + "s")
        log.debug(" Extr. embd clock:                 " +
                  str(clock_embd / float(batch_count)) + "s")
        if is_master:
            log.debug(" Train DPLDA time:                 " +
                      str(time_train_dplda / float(batch_count)) + "s")
            log.debug(" Train DPLDA clock:                " +
                      str(clock_train_dplda / float(batch_count)) + "s")
        log.debug(" Calculate grad f2i time:          " +
                  str(time_train_grad_f2i / float(batch_count)) + "s")
        log.debug(" Calculate grad f2i clock:         " +
                  str(clock_train_grad_f2i / float(batch_count)) + "s")
        if is_master:
            log.debug(" Train (given grad f2i) f2i time:  " +
                      str(time_train_f2i / float(batch_count)) + "s")
            log.debug(" Train (given grad f2i) f2i clock: " +
                      str(clock_train_f2i / float(batch_count)) + "s")

        return L
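The function above interleaves training with explicit point-to-point MPI messaging: every worker extracts embeddings locally, worker 0 gathers them, runs the DPLDA update, and scatters the embedding gradients back so each worker can update its own feature-to-embedding (f2i) part. Below is a stripped-down, hypothetical sketch of that gather/compute/scatter pattern with mpi4py; the gradient computation here is only a placeholder, not the project's DPLDA step.

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
is_master = (rank == 0)

M = np.random.randn(10, 5)       # this worker's local embeddings (placeholder)

if not is_master:
    comm.send(M, dest=0)         # send embeddings to worker 0
    g = comm.recv(source=0)      # receive this worker's gradient slice back
else:
    parts = [M] + [comm.recv(source=i) for i in range(1, size)]
    sizes = [p.shape[0] for p in parts]
    M_all = np.concatenate(parts, axis=0)
    g_all = 2.0 * M_all          # placeholder for the real DPLDA gradient
    offsets = np.cumsum([0] + sizes)
    for i in range(1, size):
        comm.send(g_all[offsets[i]:offsets[i + 1], :], dest=i)
    g = g_all[offsets[0]:offsets[1], :]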