示例#1
0
def main():
    """Train a CNN sentiment classifier on Amazon JP reviews and report metrics.

    Command-line usage: pass 't' as the first argument to initialize the
    embedding layer from pretrained fastText vectors; any other value (or no
    argument) trains the embedding layer from scratch.
    """
    args = sys.argv

    # Hyperparameters.
    batch_size = 128
    epochs = 100
    maxlen = 300                # longer sequences are truncated at the end
    model_path = 'models/cnn_model.h5'
    num_words = 40000           # vocabulary size cap
    num_label = 2               # binary sentiment

    x, y = load_dataset('data/amazon_reviews_multilingual_JP_v1_00.tsv')

    x = preprocess_dataset(x)
    x_train, x_test, y_train, y_test = train_test_split(x,
                                                        y,
                                                        test_size=0.2,
                                                        random_state=42)

    # Build the vocabulary on the training split only, then map both splits
    # to padded integer sequences of length `maxlen`.
    vocab = build_vocabulary(x_train, num_words)
    x_train = vocab.texts_to_sequences(x_train)
    x_test = vocab.texts_to_sequences(x_test)
    x_train = pad_sequences(x_train, maxlen=maxlen, truncating='post')
    x_test = pad_sequences(x_test, maxlen=maxlen, truncating='post')

    # BUG FIX: the original read args[0], which is the script name, so the
    # pretrained-embedding branch could never trigger. The flag is the first
    # real command-line argument; default to '' when none was given.
    emb_flg = args[1] if len(args) > 1 else ''
    if emb_flg == 't':
        wv = load_fasttext('../chap08/models/cc.ja.300.vec.gz')
        wv = filter_embeddings(wv, vocab.word_index, num_words)
    else:
        wv = None

    model = CNNModel(num_words, num_label, embeddings=wv).build()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])

    # Stop early when validation loss stalls; keep only the best checkpoint.
    callbacks = [
        EarlyStopping(patience=3),
        ModelCheckpoint(model_path, save_best_only=True)
    ]

    model.fit(x=x_train,
              y=y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.2,
              callbacks=callbacks,
              shuffle=True)

    # Reload the best checkpoint (not the last epoch) before evaluating.
    model = load_model(model_path)
    api = InferenceAPI(model, vocab, preprocess_dataset)
    y_pred = api.predict_from_sequences(x_test)

    print('precision: {:.4f}'.format(
        precision_score(y_test, y_pred, average='binary')))
    print('recall   : {:.4f}'.format(
        recall_score(y_test, y_pred, average='binary')))
    print('f1   : {:.4f}'.format(f1_score(y_test, y_pred, average='binary')))
def train(pos_data_path='../dataset/weibo60000/pos60000_utf8.txt_updated',
          neg_data_path='../dataset/weibo60000/neg60000_utf8.txt_updated',
          pos_bounds=(41025, 52746),
          neg_bounds=(41165, 52926)):
    """Train a BERT-embedding CNN sentiment model on the Weibo corpus.

    Generalized from hard-coded constants: paths and split boundaries are now
    keyword parameters whose defaults reproduce the original behavior exactly.

    Args:
        pos_data_path: file of positive examples.
        neg_data_path: file of negative examples.
        pos_bounds: (train_end, val_end) slice boundaries for the positives.
        neg_bounds: (train_end, val_end) slice boundaries for the negatives.
    """
    def _split(xs, ys, bounds):
        # Partition one polarity into (train, val, test) pairs at the bounds.
        a, b = bounds
        return (xs[:a], ys[:a]), (xs[a:b], ys[a:b]), (xs[b:], ys[b:])

    pos_x, pos_y = read_pos_data(pos_data_path)
    print(len(pos_x))
    print(len(pos_y))

    neg_x, neg_y = read_neg_data(neg_data_path)
    print(len(neg_x))
    print(len(neg_y))

    train_pos, val_pos, test_pos = _split(pos_x, pos_y, pos_bounds)
    train_neg, val_neg, test_neg = _split(neg_x, neg_y, neg_bounds)

    # Merge the positive and negative halves of each split.
    train_x, train_y = concate_data(train_pos[0], train_pos[1],
                                    train_neg[0], train_neg[1])
    val_x, val_y = concate_data(val_pos[0], val_pos[1],
                                val_neg[0], val_neg[1])
    test_x, test_y = concate_data(test_pos[0], test_pos[1],
                                  test_neg[0], test_neg[1])

    print('The number of train-set:', len(train_x))
    print('The number of val-set:', len(val_x))
    print('The number of test-set:', len(test_x))

    # Contextual embeddings from a Chinese BERT checkpoint; inputs are
    # truncated/padded to 100 tokens by the embedding layer.
    embedding = BERTEmbedding('../dataset/chinese_L-12_H-768_A-12',
                              sequence_length=100)
    print('embedding_size', embedding.embedding_size)

    model = CNNModel(embedding)
    model.fit(train_x,
              train_y,
              val_x,
              val_y,
              batch_size=128,
              epochs=20,
              fit_kwargs={'callbacks': [tf_board_callback]})
    model.evaluate(test_x, test_y)
    model.save('./model/cnn_bert_model')
示例#3
0
        agent='dqn',
        states=env.states(),
        actions=env.actions(),
        batch_size=1,
        learning_rate=1e-3,
        memory=10000,
        exploration=0.2,
    )

    state = env.reset()
    # Train for 200 episodes
    for _ in range(20):
        print(f'[EPISODE {_}] started...')
        state = env.reset()
        terminal = False
        while True:
            print(f'state: {state}')
            print(f'state: {state.shape}')
            actions = agent.act(states=state)
            state, terminal, reward = env.execute(actions=actions)
            agent.observe(terminal=terminal, reward=reward)
            if terminal:
                idxs = env.query_indicies
                print(f'query indicies: {idxs}')
                dm.label_samples(idxs, y_oracle[idxs])
                y_oracle = np.delete(y_oracle, idxs, axis=0)
                print(dm.train)
                cnn_model.fit(*dm.train.get_xy())
                cnn_model.print_evaluation(*dm.test.get_xy())
                break