Ejemplo n.º 1
0
 def check(model, model_filename, test, lessthan, msg_prefix=""):
     """Score ``model`` on ``test["data"]`` and assert it meets ``test``'s targets.

     One prediction is made per line of the test data and the predictions
     are scored with ``util.test``. The checks are:

     * the number of predictions equals ``test["n"]``;
     * the two scores returned by ``util.test`` (presumably precision@1 and
       recall@1) are at least ``test["p1"]`` and ``test["r1"]``;
     * ``get_path_size(model_filename)`` is at most ``test["size"]`` when
       ``lessthan`` is true, and exactly ``test["size"]`` otherwise.

     ``msg_prefix`` is prepended to every failure message. ``self``,
     ``read_labels``, ``get_path_size`` and ``util`` come from the enclosing
     scope (``self`` is presumably a ``unittest.TestCase``).
     """
     lines, labels = read_labels(test["data"])
     predictions = []
     for line in lines:
         # model.predict returns a pair; only the predicted labels are kept.
         pred_label, _ = model.predict(line)
         predictions.append(pred_label)
     p1_local_out, r1_local_out = util.test(predictions, labels)
     # Dedicated assert methods (rather than assertTrue(a OP b)) pass/fail on
     # the same condition but, with unittest's default longMessage, also
     # report both operands next to the custom message.
     self.assertEqual(
         len(predictions), test["n"], msg_prefix + "N: Want: " +
         str(test["n"]) + " Is: " + str(len(predictions))
     )
     self.assertGreaterEqual(
         p1_local_out, test["p1"], msg_prefix + "p1: Want: " +
         str(test["p1"]) + " Is: " + str(p1_local_out)
     )
     self.assertGreaterEqual(
         r1_local_out, test["r1"], msg_prefix + "r1: Want: " +
         str(test["r1"]) + " Is: " + str(r1_local_out)
     )
     path_size = get_path_size(model_filename)
     size_msg = str(test["size"]) + " Is: " + str(path_size)
     if lessthan:
         self.assertLessEqual(
             path_size, test["size"],
             msg_prefix + "Size: Want at most: " + size_msg
         )
     else:
         self.assertEqual(
             path_size, test["size"],
             msg_prefix + "Size: Want: " + size_msg
         )
Ejemplo n.º 2
0
 def check(data):
     """Check that ``model.test`` agrees with a manual evaluation.

     ``data`` is split by thirds: ``data[:2 * third]`` goes to a temporary
     training file and ``data[third:]`` to a temporary validation file, each
     line stripped and prefixed with ``__label__``. A model is trained with
     ``train_supervised(input=..., **kwargs)`` (``kwargs`` comes from the
     enclosing scope). Every non-empty validation line that ``model.get_line``
     reports labels for is predicted and scored with ``util.test``. The
     resulting ``(N, p, r)`` must equal what ``model.test`` returns for the
     same file.

     Both temporary files are removed afterwards. They are created with
     ``delete=False``, so nothing else would remove them.
     """
     import os  # local import: only needed for temp-file cleanup

     tmp_paths = []

     def write_labeled(rows):
         # Write rows as UTF-8 "__label__<row>" lines to a fresh temp file and
         # return its path. delete=False keeps the file after close() so it can
         # be reopened by name; the finally-block below removes it.
         with tempfile.NamedTemporaryFile(delete=False) as tmp:
             tmp_paths.append(tmp.name)
             for row in rows:
                 tmp.write(("__label__" + row.strip() + "\n").encode("UTF-8"))
         return tmp.name

     third = len(data) // 3
     # NOTE(review): the two slices overlap on the middle third. That looks
     # harmless here, since only agreement with model.test is checked, but
     # confirm it is intended.
     train_data = data[:2 * third]
     valid_data = data[third:]
     try:
         train_path = write_labeled(train_data)
         valid_path = write_labeled(valid_data)
         model = train_supervised(input=train_path, **kwargs)
         true_labels = []
         all_words = []
         # Read bytes and decode explicitly: the file was written as UTF-8,
         # so text mode with the locale's default encoding could misread it.
         with open(valid_path, 'rb') as fid:
             for raw in fid:
                 line = raw.decode("UTF-8").strip()
                 if not line:
                     continue
                 words, labels = model.get_line(line)
                 if len(labels) == 0:
                     continue
                 all_words.append(" ".join(words))
                 true_labels.append(labels)
         predictions, _ = model.predict(all_words)
         p, r = util.test(predictions, true_labels)
         N = len(predictions)
         Nt, pt, rt = model.test(valid_path)
         self.assertEqual(N, Nt)
         self.assertEqual(p, pt)
         self.assertEqual(r, rt)
     finally:
         # Best-effort cleanup: a failed removal must not mask the real
         # test outcome (e.g. an assertion error raised above).
         for path in tmp_paths:
             try:
                 os.remove(path)
             except OSError:
                 pass
Ejemplo n.º 3
0
 def check(model, model_filename, test, lessthan, msg_prefix=""):
     """Score ``model`` on ``test["data"]`` and assert it meets ``test``'s targets.

     Checks, in order: the prediction count equals ``test["n"]``; the two
     scores from ``util.test`` (presumably precision@1 / recall@1) reach
     ``test["p1"]`` / ``test["r1"]``; and ``get_path_size(model_filename)`` is
     at most (``lessthan`` true) or exactly (``lessthan`` false)
     ``test["size"]``. ``msg_prefix`` is prepended to every failure message.
     ``self`` (presumably a ``unittest.TestCase``), ``read_labels``,
     ``get_path_size`` and ``util`` come from the enclosing scope.
     """
     lines, labels = read_labels(test["data"])
     predictions = []
     for line in lines:
         # model.predict returns a pair; only the predicted labels are kept.
         pred_label, _ = model.predict(line)
         predictions.append(pred_label)
     p1_local_out, r1_local_out = util.test(predictions, labels)
     # Dedicated assert methods fail under the same conditions as
     # assertTrue(a OP b) but also report both operands (default longMessage).
     self.assertEqual(
         len(predictions), test["n"], msg_prefix + "N: Want: " +
         str(test["n"]) + " Is: " + str(len(predictions)))
     self.assertGreaterEqual(
         p1_local_out, test["p1"], msg_prefix + "p1: Want: " +
         str(test["p1"]) + " Is: " + str(p1_local_out))
     self.assertGreaterEqual(
         r1_local_out, test["r1"], msg_prefix + "r1: Want: " +
         str(test["r1"]) + " Is: " + str(r1_local_out))
     path_size = get_path_size(model_filename)
     size_msg = str(test["size"]) + " Is: " + str(path_size)
     if lessthan:
         self.assertLessEqual(path_size, test["size"],
                              msg_prefix + "Size: Want at most: " + size_msg)
     else:
         self.assertEqual(path_size, test["size"],
                          msg_prefix + "Size: Want: " + size_msg)
Ejemplo n.º 4
0
    labels = []
    with open(filename) as f:
        for line in f:
            line_labels = []
            tokens = line.split()
            for token in tokens:
                if token.startswith(prefix):
                    line_labels.append(token)
            labels.append(line_labels)
    return labels


if __name__ == "__main__":
    # Both data files are resolved under $DATADIR; when it is unset the join
    # with '' leaves bare relative names (i.e. the current directory).
    data_dir = os.getenv("DATADIR", '')
    train_path = os.path.join(data_dir, 'cooking.train')
    valid_path = os.path.join(data_dir, 'cooking.valid')

    # train_supervised uses the same arguments and defaults as the fastText cli
    hyperparams = dict(epoch=25, lr=1.0, wordNgrams=2, verbose=2, minCount=1)
    model = train_supervised(input=train_path, **hyperparams)

    # Evaluate the top-1 prediction for each validation line.
    top_k = 1
    predictions, _ = get_predictions(valid_path, model, k=top_k)
    gold_labels = get_labels_from_file(valid_path)
    precision, recall = test(predictions, gold_labels, k=top_k)

    print("N\t" + str(len(gold_labels)))
    print("P@{}\t{:.3f}".format(top_k, precision))
    print("R@{}\t{:.3f}".format(top_k, recall))

    # The trained model is saved next to the training file, with ".bin" appended.
    model.save_model(train_path + '.bin')