Code Example #1
File: test_bisemantic.py    Project: wpm/bisemantic
def test_load_data_with_null_and_index_column(self):
    actual = data_file("test/resources/data_with_null_values.csv",
                       index="id")
    self.assertIsInstance(actual, pd.DataFrame)
    assert_array_equal(["text1", "text2", "label"], actual.columns)
    self.assertEqual(3, len(actual))
    assert_array_equal([1, 3, 5], actual.index)
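The assertions above imply that data_file drops rows containing null values and uses the named column as the index. A minimal plain-pandas sketch of that observable behaviour (an approximation for illustration, not bisemantic's actual implementation):

import pandas as pd

def load_pairs(path, index=None):
    # Read the CSV, using the named column (e.g. "id") as the index.
    frame = pd.read_csv(path, index_col=index)
    # Drop incomplete rows; this is why only three rows (index 1, 3, 5)
    # survive in the test above.
    return frame.dropna()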
Code Example #2
File: console.py    Project: wpm/bisemantic
def train_or_continue(args, training_operation):
    from bisemantic.data import cross_validation_partitions

    training = data_file(args.training, args.n, args.index_name,
                         args.text_1_name, args.text_2_name, args.label_name,
                         args.invalid_labels, not args.not_comma_delimited)
    if args.validation_fraction is not None:
        training, validation = cross_validation_partitions(
            training, 1 - args.validation_fraction, 1)[0]
    elif args.validation_set is not None:
        validation = data_file(args.validation_set, args.n, args.index_name,
                               args.text_1_name, args.text_2_name,
                               args.label_name, args.invalid_labels,
                               not args.not_comma_delimited)
    else:
        validation = None

    _, training_history = training_operation(training, validation)
    print(training_history.latest_run_summary())
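For reference, a hypothetical argparse setup that would supply the attributes train_or_continue reads (args.training, args.n, args.index_name, and so on). The destination attribute names are taken from the code above; the flag spellings, types, and defaults are assumptions and may not match bisemantic's actual console.py:

import argparse

def hypothetical_training_parser():
    # Hypothetical flags; only the destination attribute names come from
    # the attribute accesses in train_or_continue above.
    parser = argparse.ArgumentParser(description="train a text pair classifier")
    parser.add_argument("training", help="training data file")
    parser.add_argument("--n", type=int, help="limit to the first n samples")
    parser.add_argument("--index-name", dest="index_name")
    parser.add_argument("--text-1-name", dest="text_1_name")
    parser.add_argument("--text-2-name", dest="text_2_name")
    parser.add_argument("--label-name", dest="label_name")
    parser.add_argument("--invalid-labels", dest="invalid_labels", nargs="*")
    parser.add_argument("--not-comma-delimited", action="store_true")
    parser.add_argument("--validation-fraction", dest="validation_fraction", type=float)
    parser.add_argument("--validation-set", dest="validation_set")
    return parser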
Code Example #3
File: console.py    Project: wpm/bisemantic
def score(args):
    from bisemantic.classifier import TextPairClassifier

    test = data_file(args.test, args.n, args.index_name, args.text_1_name,
                     args.text_2_name, args.label_name, args.invalid_labels,
                     not args.not_comma_delimited)
    logger.info("Score predictions for %d pairs" % len(test))
    model = TextPairClassifier.load_from_model_directory(
        args.model_directory_name)
    scores = model.score(test, batch_size=args.batch_size)
    print(", ".join("%s=%0.5f" % s for s in scores))
Code Example #4
File: console.py    Project: wpm/bisemantic
def create_cross_validation_partitions(args):
    from bisemantic.data import cross_validation_partitions
    data = data_file(args.data, args.n, args.index_name, args.text_1_name,
                     args.text_2_name, args.label_name, args.invalid_labels,
                     not args.not_comma_delimited)
    for i, (train_partition, validate_partition) in enumerate(
            cross_validation_partitions(data, args.fraction, args.k)):
        train_name, validate_name = [
            os.path.join(args.output_directory,
                         "%s.%d.%s.csv" % (args.prefix, i + 1, name))
            for name in ["train", "validate"]
        ]
        train_partition.to_csv(train_name)
        validate_partition.to_csv(validate_name)
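Both this example and Code Example #2 call cross_validation_partitions(data, fraction, k) and treat the result as a list of (train, validate) DataFrame pairs, where fraction is the share of the data placed in the training partition. A rough sketch with that call shape (the shuffling strategy is an assumption; the real function in bisemantic.data may partition differently):

def partition_pairs(data, fraction, k):
    # Return k (train, validate) pairs; 'fraction' of the rows go to training.
    n_train = int(fraction * len(data))
    pairs = []
    for _ in range(k):
        shuffled = data.sample(frac=1)  # independent shuffle per partition (assumption)
        pairs.append((shuffled[:n_train], shuffled[n_train:]))
    return pairs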
Code Example #5
File: console.py    Project: wpm/bisemantic
def predict(args):
    from bisemantic.classifier import TextPairClassifier

    test = data_file(args.test, args.n, args.index_name, args.text_1_name,
                     args.text_2_name, args.label_name, args.invalid_labels,
                     not args.not_comma_delimited)
    logger.info("Predict labels for %d pairs" % len(test))
    model = TextPairClassifier.load_from_model_directory(
        args.model_directory_name)
    class_names = TextPairClassifier.class_names_from_model_directory(
        args.model_directory_name)
    predictions = model.predict(test,
                                batch_size=args.batch_size,
                                class_names=class_names)
    print(predictions.to_csv())
Code Example #6
File: test_bisemantic.py    Project: wpm/bisemantic
def test_load_snli_format(self):
    snli = data_file("test/resources/snli.csv",
                     index="pairID",
                     text_1_name="sentence1",
                     text_2_name="sentence2",
                     label_name="gold_label",
                     invalid_labels=["-"],
                     comma_delimited=False)
    assert_array_equal(["text1", "text2", "label"], snli.columns)
    # There are 9 samples because one is dropped for having "-" as its gold_label.
    self.assertEqual(9, len(snli))
    # The pairID index values are derived from JPEG file names.
    self.assertTrue(snli.index.str.match(r"\d{10}\.jpg#\dr\d[nec]").all())
    self.assertEqual({'contradiction', 'neutral', 'entailment'},
                     set(snli.label.values))
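Taken together, the examples show a call shape for data_file roughly like data_file(filename, n, index, text_1_name, text_2_name, label_name, invalid_labels, comma_delimited). The sketch below reproduces the behaviour the tests check: optional row limit, index column, renaming of the text and label columns to text1/text2/label, filtering of invalid labels, dropping of incomplete rows, and tab-delimited input when comma_delimited is False. Parameter names are taken from the calls above, but the defaults and internals are inferred, not bisemantic's actual code:

import pandas as pd

def load_data_file(filename, n=None, index=None,
                   text_1_name="text1", text_2_name="text2", label_name="label",
                   invalid_labels=None, comma_delimited=True):
    # Read a comma- or tab-delimited file, optionally limited to the first n rows.
    frame = pd.read_csv(filename, sep="," if comma_delimited else "\t",
                        index_col=index, nrows=n)
    # Normalize the text and label column names to text1, text2, and label.
    frame = frame.rename(columns={text_1_name: "text1",
                                  text_2_name: "text2",
                                  label_name: "label"})
    # Drop rows whose label is in the invalid set (e.g. "-" in the SNLI data).
    if invalid_labels is not None:
        frame = frame[~frame["label"].isin(invalid_labels)]
    # Keep only the three normalized columns and drop incomplete rows.
    return frame[["text1", "text2", "label"]].dropna()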