Code example #1
import json
import os

import glog
import numpy as np
import tensorflow as tf  # the TF1-style tf.gfile API is used below

# parse_args_and_hparams, get_encoder_client, _preprocess_data, train_model
# and the _TRAIN/_TEST keys are defined elsewhere in this module.


def _main():
    parsed_args, hparams = parse_args_and_hparams()

    if hparams.task.lower() not in ["clinc", "hwu", "banking"]:
        raise ValueError(f"{hparams.task} is not a valid task")

    encoder_client = get_encoder_client(hparams.encoder_type,
                                        cache_dir=hparams.cache_dir)

    categories, encodings, labels = _preprocess_data(encoder_client, hparams,
                                                     parsed_args.data_dir)

    accs = []
    eval_acc_histories = []
    if hparams.eval_each_epoch:
        # Note: the test split doubles as the per-epoch validation set.
        validation_data = (encodings[_TEST], labels[_TEST])
        verbose = 1
    else:
        validation_data = None
        verbose = 0

    for seed in range(hparams.seeds):
        glog.info(f"### Seed {seed} ###")
        model, eval_acc_history = train_model(encodings[_TRAIN],
                                              labels[_TRAIN],
                                              categories,
                                              hparams,
                                              validation_data=validation_data,
                                              verbose=verbose)

        _, acc = model.evaluate(encodings[_TEST], labels[_TEST], verbose=0)
        glog.info(f"Seed accuracy: {acc:.3f}")
        accs.append(acc)
        eval_acc_histories.append(eval_acc_history)

    average_acc = np.mean(accs)
    variance = np.var(accs)  # np.std(accs) would give the standard deviation
    glog.info(f"Average results:\n"
              f"Accuracy: {average_acc:.3f}\n"
              f"Variance: {variance:.3f}")

    results = {
        "Average results": {
            "Accuracy": float(average_acc),
            "Variance": float(variance)
        }
    }
    if hparams.eval_each_epoch:
        results["Results per epoch"] = [[float(x) for x in y]
                                        for y in eval_acc_histories]

    if not tf.gfile.Exists(parsed_args.output_dir):
        tf.gfile.MakeDirs(parsed_args.output_dir)
    with tf.gfile.Open(os.path.join(parsed_args.output_dir, "results.json"),
                       "w") as f:
        json.dump(results, f, indent=2)
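
Code examples #1 and #5 call a parse_args_and_hparams() helper that is not shown. Below is a minimal sketch of its likely shape, modeled on the argparse setup in code example #4: the --data_dir and --output_dir flags are inferred from how parsed_args is used above, and everything else here is an assumption.

def parse_args_and_hparams():
    """Hypothetical helper: parse CLI flags and load hyperparameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", required=True,
                        help="Directory containing the task datasets")
    parser.add_argument("--output_dir", required=True,
                        help="Directory where results.json is written")
    parser.add_argument("--params", default="intent_detection.config.default")
    parser.add_argument("--params_overrides", default="")
    parsed_args = parser.parse_args()

    hparams = _object_from_name(parsed_args.params)  # as in code example #4
    hparams.parse(parsed_args.params_overrides)
    return parsed_args, hparams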
Code example #2
    def test_training_validation(self):
        training_examples = np.array([[1, 2, 3], [3, 2, 1],
                                      [4, 5, 6], [6, 5, 4]])
        training_labels = np.array([0, 1, 0, 1])
        label_set = {0, 1}

        model, acc_hist = classifier.train_model(
            train_encodings=training_examples,
            train_labels=training_labels,
            categories=label_set,
            hparams=TrainModelTest.test_hparams,
            validation_data=(
                np.array([[10, 20, 30], [30, 20, 10]]), np.array([0, 1])
            )
        )

        self.assertEqual(len(acc_hist), 2)
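
The assertion implies that the shared fixture trains for exactly two epochs, so the per-epoch validation history has length 2. A plausible shape for TrainModelTest.test_hparams, assuming the TF1 HParams class (consistent with the tf.gfile and hparams.parse() calls in the other examples); the field names and values here are assumptions, and the real config may define more:

from tensorflow.contrib.training import HParams

# Hypothetical fixture: two epochs, so acc_hist has length 2 when validation
# data is supplied.
test_hparams = HParams(epochs=2, batch_size=2, learning_rate=0.01)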
Code example #3
    def test_training_no_validation(self):
        training_examples = np.array([[1, 2, 3], [3, 2, 1],
                                      [4, 5, 6], [6, 5, 4]])
        training_labels = np.array([0, 1, 0, 1])
        label_set = {0, 1}

        model, acc_hist = classifier.train_model(
            train_encodings=training_examples,
            train_labels=training_labels,
            categories=label_set,
            hparams=TrainModelTest.test_hparams
        )

        self.assertIsNone(acc_hist)
        self.assertIsInstance(model, tf.keras.models.Sequential)

        pred = model(np.array([[10.5, 20, 30]]))
        self.assertEqual(pred.shape, (1, 2))
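
To map the (1, 2) output row back to a category, the same argmax pattern as the interactive loop in code example #4 applies; a sketch, using model.predict since it returns a NumPy array in both TF1 and TF2 Keras:

probs = model.predict(np.array([[10.5, 20, 30]])).flatten()
predicted_label = int(np.argmax(probs))  # index into the two-label set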
Code example #4
import argparse

import numpy as np

# _object_from_name, get_encoder_client, _preprocess_data and train_model are
# defined elsewhere in this module.


def _main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file',
                        required=True,
                        help="The absolute path to the training data")
    parser.add_argument(
        "--params",
        default="intent_detection.config.default",
        help="Path to config file containing the model hyperparameters")
    parser.add_argument(
        "--params_overrides",
        help="A comma-separated list of param=value pairs specifying overrides"
        " to the model hyperparameters.",
        default="")

    parsed_args = parser.parse_args()
    hparams = _object_from_name(parsed_args.params)
    hparams.parse(parsed_args.params_overrides)

    encoder_client = get_encoder_client(hparams.encoder_type,
                                        cache_dir=hparams.cache_dir)
    categories, encodings, labels = _preprocess_data(encoder_client,
                                                     parsed_args.train_file)

    print(f"Your labels are here, make sure they are correct: {categories}")

    model, _ = train_model(encodings, labels, categories, hparams)

    print("Now you will be able to speak to the model. Press Ctrl + C to quit")
    while True:
        query = input("Your query: ")
        query_encoding = encoder_client.encode_sentences([query])
        output = model.predict(query_encoding).flatten()
        prediction = np.argmax(output)
        print(f"Prediction: {categories[prediction]}, "
              f"score: {output[prediction]}")
Code example #5
import json
import os
import resource
from datetime import datetime

import glog
import numpy as np
import tensorflow as tf  # the TF1-style tf.gfile API is used below


def _main():
    # over_ride_sample_no, over_ride_dataset and algo are expected to be set
    # as module-level globals before _main() is called.
    parsed_args, hparams = parse_args_and_hparams()

    hparams.data_regime = over_ride_sample_no

    if hparams.task.lower() not in [
            "clinc", "hwu", "banking", "wallet", "alliance", "bank_split"
    ]:
        raise ValueError(f"{hparams.task} is not a valid task")
    # The configured task and encoder are overridden here, so the validation
    # above only guards the value coming from the config file.
    hparams.task = over_ride_dataset
    hparams.encoder_type = algo

    encoder_client = get_encoder_client(hparams.encoder_type,
                                        cache_dir=hparams.cache_dir)

    categories, encodings, labels, train, test = _preprocess_data(
        encoder_client, hparams, parsed_args.data_dir)

    accs = []
    eval_acc_histories = []
    if hparams.eval_each_epoch:
        validation_data = (encodings[_TEST], labels[_TEST])
        verbose = 1
    else:
        validation_data = None
        verbose = 0

    for seed in range(hparams.seeds):
        glog.info(f"### Seed {seed} ###")

        if algo == 'rf_tfidf':
            acc = rf_tfidf(train, test)
            return over_ride_sample_no, acc
        elif algo == 'sbert_cosine':
            acc = sbert_cosine(train, test)
            return over_ride_sample_no, acc
        else:
            model, eval_acc_history = train_model(
                encodings[_TRAIN],
                labels[_TRAIN],
                categories,
                hparams,
                validation_data=validation_data,
                verbose=verbose)
            _, acc = model.evaluate(encodings[_TEST], labels[_TEST], verbose=0)

        glog.info(f"Seed accuracy: {acc:.3f}")
        accs.append(acc)
        eval_acc_histories.append(eval_acc_history)

    average_acc = np.mean(accs)
    variance = np.var(accs)  # np.std(accs) would give the standard deviation
    glog.info(f"Average results:\n"
              f"Accuracy {over_ride_sample_no}: {average_acc:.3f}\n"
              f"Variance: {variance:.3f}")

    results = {
        "Average results": {
            "Accuracy": float(average_acc),
            "Variance": float(variance)
        }
    }
    if hparams.eval_each_epoch:
        results["Results per epoch"] = [[float(x) for x in y]
                                        for y in eval_acc_histories]

    if not tf.gfile.Exists(parsed_args.output_dir):
        tf.gfile.MakeDirs(parsed_args.output_dir)
    with tf.gfile.Open(os.path.join(parsed_args.output_dir, "results.json"),
                       "w") as f:
        json.dump(results, f, indent=2)

    # ru_maxrss is in kilobytes on Linux (bytes on macOS), so on Linux this
    # reports peak memory in MB.
    memory_pred = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
    print('MEMORY_PRED:', memory_pred)

    start_time = datetime.now()
    query_encoding = encoder_client.encode_sentences(['test'])
    output = model.predict(query_encoding)
    prediction = np.argmax(output)
    end_time = datetime.now()
    delta = end_time - start_time
    print('TIME DELTA:', delta)
    return (over_ride_sample_no, acc, memory_pred, delta, hparams.task)
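
Since this _main() reads algo, over_ride_dataset and over_ride_sample_no as module-level globals and returns a result tuple, it is presumably driven by an outer sweep. A hypothetical driver, with the module name and value lists as assumptions:

import benchmark_main  # hypothetical module containing the _main() above

rows = []
for algo_setting in ["rf_tfidf", "sbert_cosine", "convert"]:
    for sample_no in [10, 30]:
        benchmark_main.algo = algo_setting
        benchmark_main.over_ride_sample_no = sample_no
        benchmark_main.over_ride_dataset = "banking"
        rows.append(benchmark_main._main())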