示例#1
0
def split_train_test_dev(X, Y, dev_size, train_fraction=0.9):
    """Split paired samples into train, test and dev partitions.

    The first ``int(len(X) * train_fraction)`` samples form the training set,
    the last ``dev_size`` samples form the dev set, and whatever lies between
    them forms the test set.

    Unlike slicing with ``-dev_size`` directly, this handles ``dev_size == 0``
    correctly: ``X[-0:]`` is the *whole* sequence, whereas here the dev set is
    empty and the test set runs to the end.

    Args:
        X: Inputs supporting ``len`` and slicing (e.g. a NumPy array, split
            along its first axis).
        Y: Targets aligned sample-for-sample with ``X``.
        dev_size: Number of trailing samples reserved for the dev set.
        train_fraction: Fraction of samples, taken from the front, used for
            training.

    Returns:
        ``((train_X, train_Y), (test_X, test_Y), (dev_X, dev_Y))``.

    Note:
        If ``dev_size`` exceeds ``len(X) - int(len(X) * train_fraction)``, the
        dev set overlaps the training set and the test set is empty.
    """
    n_samples = len(X)
    t_size = int(n_samples * train_fraction)
    # Compute the dev start index explicitly instead of slicing with
    # -dev_size, which would select everything when dev_size is 0.
    dev_start = max(n_samples - dev_size, 0)
    train = (X[:t_size], Y[:t_size])
    test = (X[t_size:dev_start], Y[t_size:dev_start])
    dev = (X[dev_start:], Y[dev_start:])
    return train, test, dev


def main():
    """Train a model using lines of text contained in a file and evaluate it.

    Reads GloVe vectors from ``EMBEDDING_FILE`` and training text from
    ``TEXT_FILE``. Loads the model from ``MODEL_FILE``, or builds a new
    ``BidirectionalGruWithGru`` model if that file is missing. Loads weights
    from ``WEIGHTS_FILE`` if present, then trains and saves the model and
    weights. Finally it evaluates on the dev split and on a hand-written
    example.
    """

    # Read GloVe vectors; the word list and reverse index are not needed here.
    _, word_to_index, _, word_to_vec_map = read_glove_vecs(EMBEDDING_FILE)
    # Create the word embedding matrix.
    embedding_matrix = create_emb_matrix(word_to_index, word_to_vec_map)
    print('shape of embedding_matrix:', embedding_matrix.shape)

    # Load training text from a file.
    utterances = load_text_data(TEXT_FILE)
    print(utterances[0])

    # Create an instance of Punctuator and build the training data.
    punctuator = Punctuator(word_to_index, None)
    X, Y = punctuator.create_training_data(utterances, False)

    # If a model already exists, load it; otherwise build a fresh one.
    if os.path.isfile(MODEL_FILE):
        punctuator.load_model(MODEL_FILE)
    else:
        model = BidirectionalGruWithGru.create_model(
            input_shape=(X.shape[1], ),
            embedding_matrix=embedding_matrix,
            vocab_len=len(word_to_index),
            n_d1=128,
            n_d2=128,
            n_c=len(punctuator.labels))
        # Assumes a Keras model: summary() prints the table itself and returns
        # None, so wrapping it in print() would emit a stray "None" line.
        model.summary()
        punctuator.__model__ = model

    # If the model has already been trained, start from the saved weights.
    if os.path.isfile(WEIGHTS_FILE):
        punctuator.load_weights(WEIGHTS_FILE)

    # Shuffle the training data.
    # NOTE(review): assumes `shuffle` permutes X and Y in place with a shared
    # permutation. If it is sklearn.utils.shuffle, it returns shuffled copies
    # and this call has no effect (use `X, Y = shuffle(X, Y)`) -- confirm.
    shuffle(X, Y)

    # Totals of Y summed over its first two axes (presumably per-label counts).
    denom_Y = Y.swapaxes(0, 1).sum((0, 1))
    print('Summary of Y:', denom_Y)

    print('shape of X:', X.shape)
    print(X[0:10])
    print('shape of Y:', Y.shape)
    print(Y[0:10])

    # Define the optimizer and compile the model.
    opt = Adam(lr=0.007, beta_1=0.9, beta_2=0.999, decay=0.01)
    punctuator.compile(opt,
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])

    # Split into training (first 90%), test (the middle) and dev (the last
    # DEV_SIZE samples) sets.
    (train_X, train_Y), (test_X, test_Y), (dev_X, dev_Y) = (
        split_train_test_dev(X, Y, DEV_SIZE))

    print(train_Y.swapaxes(0, 1).sum((0, 1)))
    print(test_Y.swapaxes(0, 1).sum((0, 1)))

    # Train the model and persist it.
    punctuator.fit([train_X], train_Y, batch_size=BATCH, epochs=EPOCH)
    punctuator.save_model(MODEL_FILE)
    punctuator.save_weights(WEIGHTS_FILE)

    # Evaluate the model on the dev set (or the test set).
    for example, expected in zip(dev_X, dev_Y):
        prediction = punctuator.predict(example)
        punctuator.check_result(prediction, expected)

    # Manually evaluate the model on an example.
    examples = [
        "good morning chairman who I saw and members of the committee it's my pleasure to be here today I'm Elizabeth Ackles director of the office of rate payer advocates and I appreciate the chance to present on oris key activities from 2017 I have a short presentation and I'm going to move through it really quickly because you've had a long morning already and be happy to answer any questions that you have"
    ]
    for example in examples:
        words = example.split()
        x = punctuator.create_live_data(words)
        print(x)
        for s in x:
            print(s)
            prediction = punctuator.predict(s)
            result = punctuator.add_punctuation(prediction, words)
            print(result)
示例#2
0
def split_train_test_dev(X, Y, dev_size, train_fraction=0.9):
    """Split paired samples into train, test and dev partitions.

    The first ``int(len(X) * train_fraction)`` samples form the training set,
    the last ``dev_size`` samples form the dev set, and whatever lies between
    them forms the test set.

    Unlike slicing with ``-dev_size`` directly, this handles ``dev_size == 0``
    correctly: ``X[-0:]`` is the *whole* sequence, whereas here the dev set is
    empty and the test set runs to the end.

    Args:
        X: Inputs supporting ``len`` and slicing (e.g. a NumPy array, split
            along its first axis).
        Y: Targets aligned sample-for-sample with ``X``.
        dev_size: Number of trailing samples reserved for the dev set.
        train_fraction: Fraction of samples, taken from the front, used for
            training.

    Returns:
        ``((train_X, train_Y), (test_X, test_Y), (dev_X, dev_Y))``.

    Note:
        If ``dev_size`` exceeds ``len(X) - int(len(X) * train_fraction)``, the
        dev set overlaps the training set and the test set is empty.
    """
    n_samples = len(X)
    t_size = int(n_samples * train_fraction)
    # Compute the dev start index explicitly instead of slicing with
    # -dev_size, which would select everything when dev_size is 0.
    dev_start = max(n_samples - dev_size, 0)
    train = (X[:t_size], Y[:t_size])
    test = (X[t_size:dev_start], Y[t_size:dev_start])
    dev = (X[dev_start:], Y[dev_start:])
    return train, test, dev


def main():
    """Train a model over repeated rounds of resampled text and evaluate it.

    This variant does not use GloVe embeddings: the model is built with no
    embedding matrix and ``vocab_len=0``. Its input shape is taken from
    dimensions 1 and 2 of ``X``. A fresh model is always built (loading
    ``MODEL_FILE`` is disabled), and weights are loaded from ``WEIGHTS_FILE``
    if present. Each of 100 rounds then shuffles the utterances, builds
    training data from up to 300000 of them, and trains. Finally the model and
    weights are saved and evaluated on the last round's dev split and on
    hand-written examples.
    """

    # GloVe embeddings (read_glove_vecs / create_emb_matrix) are disabled in
    # this variant, so the model is built without a pre-trained matrix.
    embedding_matrix = None

    # Load training text from a file.
    utterances = load_text_data(TEXT_FILE)
    punctuator = Punctuator(None, None)
    # Build data from a handful of utterances only to discover the input shape.
    X, Y = punctuator.create_training_data(utterances[:3], False)
    print(X.shape)
    print(X.shape[1])
    print(Y.shape)

    # Loading a saved model is deliberately disabled, so a fresh model is
    # always built. Set this to True to reuse MODEL_FILE when it exists.
    load_saved_model = False
    if load_saved_model and os.path.isfile(MODEL_FILE):
        punctuator.load_model(MODEL_FILE)
    else:
        model = BidirectionalGruWithGru.create_model(
            input_shape=(X.shape[1], X.shape[2]),
            embedding_matrix=embedding_matrix,
            vocab_len=0,
            n_d1=128,
            n_d2=128,
            n_c=len(punctuator.labels))
        # Assumes a Keras model: summary() prints the table itself and returns
        # None, so wrapping it in print() would emit a stray "None" line.
        model.summary()
        punctuator.__model__ = model

    # If the model has already been trained, start from the saved weights.
    if os.path.isfile(WEIGHTS_FILE):
        punctuator.load_weights(WEIGHTS_FILE)

    for _ in range(100):
        shuffle(utterances)
        print(utterances[0])

        # Build this round's training data from up to 300000 utterances.
        X, Y = punctuator.create_training_data(utterances[:300000], False)

        # Shuffle the training data.
        # NOTE(review): assumes `shuffle` permutes X and Y in place with a
        # shared permutation. If it is sklearn.utils.shuffle, it returns
        # shuffled copies and this call has no effect -- confirm.
        shuffle(X, Y)

        # Totals of Y summed over its first two axes (presumably per-label
        # counts).
        denom_Y = Y.swapaxes(0, 1).sum((0, 1))
        print('Summary of Y:', denom_Y)

        print('shape of X:', X.shape)
        print(X[0:10])
        print('shape of Y:', Y.shape)
        print(Y[0:10])

        # Define the optimizer and compile the model.
        # NOTE(review): a new Adam optimizer is created and the model is
        # recompiled every round. This presumably resets the optimizer state,
        # including the lr decay schedule, each time. Move it above the loop
        # if training should continue smoothly across rounds.
        opt = Adam(lr=0.007, beta_1=0.9, beta_2=0.999, decay=0.01)
        punctuator.compile(opt,
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

        # Split into training (first 90%), test (the middle) and dev (the
        # last DEV_SIZE samples) sets.
        (train_X, train_Y), (test_X, test_Y), (dev_X, dev_Y) = (
            split_train_test_dev(X, Y, DEV_SIZE))

        print(train_Y.swapaxes(0, 1).sum((0, 1)))
        print(test_Y.swapaxes(0, 1).sum((0, 1)))

        # Train the model.
        punctuator.fit([train_X], train_Y, batch_size=BATCH, epochs=EPOCH)

    punctuator.save_model(MODEL_FILE)
    punctuator.save_weights(WEIGHTS_FILE)

    # Evaluate the model on the last round's dev set (or the test set).
    for example, expected in zip(dev_X, dev_Y):
        prediction = punctuator.predict(example)
        punctuator.check_result(prediction, expected)

    # Manually evaluate the model on some examples.
    examples = [
        "good morning chairman who I saw and members of the committee it's my pleasure to be here today I'm Elizabeth Ackles director of the office of rate payer advocates and I appreciate the chance to present on oris key activities from 2017 I have a short presentation and I'm going to move through it really quickly because you've had a long morning already and be happy to answer any questions that you have",
        "this was a measure that first was introduced back in 1979 known as the International bill of rights for women it is the first and only international instrument that comprehensively addresses women's rights within political cultural economic social and family life",
        "I'm Elizabeth Neumann from the San Francisco Department on the status of women Sita is not just about naming equal rights for women and girls it provides a framework to identify and address inequality",
        "we have monitored the demographics of commissioners and board members in San Francisco to assess the equality of political opportunities and after a decade of reports women are now half of appointees but white men are still over-represented and Asian and Latina men and women are underrepresented",
        "when the city and county faced a 300 million dollar budget deficit in 2003 a gender analysis of budget cuts by city departments identified the disproportionate effect on women and particularly women of color in the proposed layoffs and reduction of services"
    ]
    for example in examples:
        words = example.split()
        x = punctuator.create_live_data(words)
        print(x)
        for s in x:
            print(s)
            prediction = punctuator.predict(s)
            result = punctuator.add_punctuation(prediction, words)
            print(result)