Ejemplo n.º 1
0
def infer(text, time_str="20-07-15_05-08-14"):
    """Classify a single piece of text with a saved PyTorch CNN checkpoint.

    Args:
        text: Raw input text; turned into a model batch by
            ``data_utils.gen_infer_data``.
        time_str: Run identifier passed to ``conf_utils.init_test_config`` to
            locate the config and checkpoint. Defaults to the run id that was
            previously hard-coded here.

    Returns:
        The entry of the target vocabulary at the index predicted for the
        first item of the inference batch (also printed).
    """
    config = conf_utils.init_test_config(time_str)
    model_class = config["model_class"]

    model = cnn.Model(config)
    if use_gpu:
        model = model.cuda()

    # Map the checkpoint to CPU so it loads without a GPU; load_state_dict
    # copies the values into the model's existing parameters.
    model.load_state_dict(
        torch.load(config["model_path"] + "cnn.pt",
                   map_location=torch.device('cpu'))["model"])
    # Put layers such as dropout into inference mode; without this the
    # prediction for the same text can change between calls.
    model.eval()

    infer_input_batch, target_vocab = data_utils.gen_infer_data(config, text)
    if args.model == "cnn" and model_class == "conv2d":
        # Add a singleton axis at position 1 (presumably the channel axis the
        # conv2d variant expects — confirm against cnn.Model).
        infer_input_v = Variable(
            torch.LongTensor(np.expand_dims(infer_input_batch, 1)))
    else:
        infer_input_v = Variable(torch.LongTensor(infer_input_batch))

    if use_gpu:
        infer_input_v = infer_input_v.cuda()

    # Forward pass; no gradients are needed for inference.
    with torch.no_grad():
        out = model(infer_input_v)
    _, pred = torch.max(out, 1)

    target = target_vocab[pred[0].item()]
    print(f"\n\n>> {text}: {target}")
    return target
def test(time_str):
    """Evaluate the newest TensorFlow checkpoint on the test split.

    Restores the latest checkpoint found under ``config["model_path"]``,
    predicts every full ``batch_size`` batch of the test data, maps true and
    predicted class indices to names from the target vocabulary, then prints
    and saves a classification report and a confusion matrix.
    """
    config = conf_utils.init_test_config(time_str)
    batch_size = config["batch_size"]

    test_input, test_target = data_utils.gen_test_data(config)
    target_vocab = data_utils.get_vocab(config["target_vocab_path"])

    print(">> build model...")
    model = cnn.Model(config)
    _, pred_op, _ = model.cnn()

    with tf.Session() as sess:
        saver = tf.train.Saver()
        checkpoint_path = tf.train.latest_checkpoint(config["model_path"])
        print(f">> last checkpoint: {checkpoint_path}")
        saver.restore(sess, checkpoint_path)

        batches = batch_utils.make_batch(test_input, test_target, batch_size,
                                         False)
        true_ids = []
        predicted_ids = []
        # Floor division: only full batches are drawn from the generator.
        full_batches = len(test_input) // batch_size
        for _ in range(full_batches):
            input_batch, target_batch = next(batches)

            feed = {
                model.input_holder: input_batch,
                model.target_holder: target_batch,
            }
            batch_prediction = sess.run(pred_op, feed_dict=feed)

            true_ids.extend(np.argmax(target_batch, 1).tolist())
            predicted_ids.extend(batch_prediction.tolist())

        true_labels = [target_vocab[idx] for idx in true_ids]
        predicted_labels = [target_vocab[idx] for idx in predicted_ids]

        report = metrics.classification_report(true_labels, predicted_labels)
        print(f"\n>> REPORT:\n{report}")
        output_utils.save_metrics(config, "report.txt", report)

        confusion = metrics.confusion_matrix(true_labels, predicted_labels)
        print(f"\n>> Confusion Matrix:\n{confusion}")
        output_utils.save_metrics(config, "confusion_matrix.txt",
                                  str(confusion))
Ejemplo n.º 3
0
def test(config_path):
    """Evaluate a saved PyTorch model on the test split and save metrics.

    Loads ``cnn.pt`` from ``config["model_path"]``, predicts each full
    ``batch_size`` batch of the test data, then prints and saves a
    classification report and a confusion matrix over class indices.

    Args:
        config_path: Passed to ``conf_utils.init_test_config``.
    """
    config = conf_utils.init_test_config(config_path)

    batch_size = config["batch_size"]
    model_class = config["model_class"]
    if args.model == "cnn":
        print(f"\n>> model class is {model_class}")

    test_data, test_target = data_utils.gen_test_data(config)

    model = cnn.Model(config)
    if use_gpu:
        model = model.cuda()

    # Map the checkpoint to CPU so it loads without a GPU; load_state_dict
    # copies the values into the model's existing parameters.
    model.load_state_dict(
        torch.load(config["model_path"] + "cnn.pt",
                   map_location=torch.device('cpu'))["model"])
    # Put layers such as dropout into inference mode for evaluation.
    model.eval()

    test_batch_gen = batch_utils.make_batch(test_data, test_target, batch_size)

    pred_list = []
    target_list = []
    # NOTE(review): floor division skips a trailing partial batch, if
    # make_batch yields one — confirm this is intended.
    for _ in range(len(test_data) // batch_size):
        test_input_batch, test_target_batch = next(test_batch_gen)

        if args.model == "cnn" and model_class == "conv2d":
            # Add a singleton axis at position 1 (presumably the channel axis
            # the conv2d variant expects).
            test_input_batch_v = Variable(
                torch.LongTensor(np.expand_dims(test_input_batch, 1)))
        else:
            test_input_batch_v = Variable(torch.LongTensor(test_input_batch))

        if use_gpu:
            test_input_batch_v = test_input_batch_v.cuda()

        # Forward pass; no gradients are needed for evaluation.
        with torch.no_grad():
            out = model(test_input_batch_v)
        _, pred = torch.max(out, 1)
        # Always bring predictions back to plain integers, on CPU and GPU
        # alike. Previously CPU runs appended 0-d tensors to pred_list.
        pred_list.extend(pred.cpu().numpy().tolist())
        # Targets are taken as the argmax of each row (one-hot style rows).
        target_list.extend([np.argmax(target) for target in test_target_batch])

    report = metrics.classification_report(target_list, pred_list)
    print(f"\n>> REPORT:\n{report}")
    output_utils.save_metrics(config, "report.txt", report)

    cm = metrics.confusion_matrix(target_list, pred_list)
    print(f"\n>> Confusion Matrix:\n{cm}")
    output_utils.save_metrics(config, "confusion_matrix.txt", str(cm))
def get_server_sess(time_str):
    """Build the TF graph and restore the newest checkpoint for serving.

    Args:
        time_str: Passed to ``conf_utils.init_test_config``.

    Returns:
        Tuple ``(sess, pred, target_vocab, input_vocab, model)``. ``sess`` is
        an ``InteractiveSession`` that this function leaves open; closing it
        is the caller's responsibility.
    """
    config = conf_utils.init_test_config(time_str)
    input_vocab = data_utils.get_vocab(config["input_vocab_path"])
    target_vocab = data_utils.get_vocab(config["target_vocab_path"])

    print(">> build model...")
    model = cnn.Model(config)
    _, prediction_op, _ = model.cnn()

    # The session is created before the Saver, as in the original flow.
    session = tf.InteractiveSession()
    restorer = tf.train.Saver()
    checkpoint_path = tf.train.latest_checkpoint(config["model_path"])
    print(f">> last checkpoint: {checkpoint_path}")
    restorer.restore(session, checkpoint_path)

    return session, prediction_op, target_vocab, input_vocab, model
Ejemplo n.º 5
0
def test(time_str):
    """Evaluate the saved TF-IDF features + classifier on the test split.

    Loads the vocabulary (``vocab.json``) and the fitted classifier
    (``svm.m``) with joblib from ``config["model_path"]``. It then predicts the
    test data and prints and saves a classification report and a confusion
    matrix.
    """
    config = conf_utils.init_test_config(time_str)

    test_input, test_target = data_utils.gen_test_data(config)

    saved_vocab = joblib.load(config["model_path"] + "vocab.json")
    featurizer = bayes.Model(saved_vocab)
    features = featurizer.get_tfidf(test_input, "test")

    # The file is named svm.m; the loaded object only needs a predict().
    classifier = joblib.load(config["model_path"] + "svm.m")
    predictions = classifier.predict(features)

    report = metrics.classification_report(test_target, predictions)
    print(f"\n>> REPORT:\n{report}")
    output_utils.save_metrics(config, "report.txt", report)

    confusion = metrics.confusion_matrix(test_target, predictions)
    print(f"\n>> Confusion Matrix:\n{confusion}")
    output_utils.save_metrics(config, "confusion_matrix.txt", str(confusion))
Ejemplo n.º 6
0
def train_v2(config_path, batch_size=64, epochs=20):
    """Build, train and evaluate the model selected by ``args.model``.

    The built model exposes ``summary``/``fit``/``evaluate`` (presumably a
    Keras model — confirm against ``cnn2.Model.build``).

    Args:
        config_path: Passed to ``conf_utils.init_train_config``.
        batch_size: Batch size for training and for test evaluation. The
            default of 64 is the value previously hard-coded.
        epochs: Number of training epochs (default 20, as before).

    Returns:
        The value returned by ``model.evaluate`` on the test split.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    config = conf_utils.init_train_config(config_path)
    train_input, train_target, validate_input, validate_target = (
        data_utils.gen_train_data(config))

    print(">> build model...")
    if args.model == "cnn":
        model = cnn2.Model(config).build()
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError(f"error model: {args.model}")

    model.summary()

    model.fit(train_input,
              train_target,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(validate_input, validate_target))

    # Load the test config for the run identified by config["time_now"].
    # Use a separate name so the training config is not shadowed.
    test_config = conf_utils.init_test_config(config["time_now"])
    test_input, test_target = data_utils.gen_test_data(test_config)
    return model.evaluate(test_input, test_target, batch_size=batch_size)