コード例 #1
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_save_and_load_weights(model_save_path):
    arch = resnet18(2, use_pretrained=False)
    model = cnn.CNN("resnet18", classes=["a", "b"], sample_duration=5.0)
    model.save_weights(model_save_path)
    model1 = cnn.CNN(arch, classes=["a", "b"], sample_duration=5.0)
    model1.load_weights(model_save_path)
    assert np.array_equal(
        model.network.state_dict()["conv1.weight"].numpy(),
        model1.network.state_dict()["conv1.weight"].numpy(),
    )
コード例 #2
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_single_target_prediction(test_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    model.single_target = True
    scores, preds, _ = model.predict(test_df, binary_preds="single_target")

    assert len(scores) == 2
    assert len(preds) == 2
コード例 #3
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_predict_splitting_short_file(short_file_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        scores, _, _ = model.predict(short_file_df)
        assert len(scores) == 0
        assert "prediction_dataset" in str(w[0].message)
コード例 #4
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_split_resnet_feat_clf(train_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=2)
    cnn.separate_resnet_feat_clf(model)
    assert "feature" in model.optimizer_params
    model.optimizer_params["feature"]["lr"] = 0.1
    model.train(train_df, epochs=0, save_path="tests/models")
    shutil.rmtree("tests/models/")
コード例 #5
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_predict_without_splitting(test_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    scores, preds, _ = model.predict(test_df,
                                     split_files_into_clips=False,
                                     binary_preds="multi_target",
                                     threshold=0)
    assert len(scores) == len(test_df)
    assert len(preds) == len(test_df)
コード例 #6
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_multi_target_prediction(train_df, test_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    scores, preds, _ = model.predict(test_df,
                                     binary_preds="multi_target",
                                     threshold=0.1)

    assert len(scores) == 2
    assert len(preds) == 2
コード例 #7
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_predict_wrong_input_error(test_df):
    """cannot pass a preprocessor or dataset to predict. only file paths as list or df"""
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    pre = SpectrogramPreprocessor(2.0)
    with pytest.raises(AssertionError):
        model.predict(pre)
    with pytest.raises(AssertionError):
        ds = AudioFileDataset(test_df, pre)
        model.predict(ds)
コード例 #8
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_prediction_overlap(test_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    model.single_target = True
    scores, preds, _ = model.predict(test_df,
                                     binary_preds="single_target",
                                     overlap_fraction=0.5)

    assert len(scores) == 3
    assert len(preds) == 3
コード例 #9
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_prediction_warns_different_classes(train_df):
    model = cnn.CNN("resnet18", classes=["a", "b"], sample_duration=5.0)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # raises warning bc test_df columns != model.classes
        model.predict(train_df)
        all_warnings = ""
        for wi in w:
            all_warnings += str(wi.message)
        assert "classes" in all_warnings
コード例 #10
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_train_multi_target(train_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    model.train(
        train_df,
        train_df,
        save_path="tests/models",
        epochs=1,
        batch_size=2,
        save_interval=10,
        num_workers=0,
    )
    shutil.rmtree("tests/models/")
コード例 #11
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_save_and_load_model(model_save_path):
    arch = alexnet(2, use_pretrained=False)
    classes = [0, 1]

    cnn.CNN(arch, classes, 1.0).save(model_save_path)
    m = cnn.load_model(model_save_path)
    assert m.classes == classes
    assert type(m) == cnn.CNN

    cnn.InceptionV3(classes, 1.0, use_pretrained=False).save(model_save_path)
    m = cnn.load_model(model_save_path)
    assert m.classes == classes
    assert type(m) == cnn.InceptionV3
コード例 #12
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_train_one_class(train_df):
    model = cnn.CNN("resnet18", classes=[0], sample_duration=5.0)
    model.single_target = True
    model.train(
        train_df[[0]],
        train_df[[0]],
        save_path="tests/models",
        epochs=1,
        batch_size=2,
        save_interval=10,
        num_workers=0,
    )
    shutil.rmtree("tests/models/")
コード例 #13
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_train_predict_architecture(train_df):
    """test passing a specific architecture to PytorchModel"""
    arch = alexnet(2, use_pretrained=False)
    model = cnn.CNN(arch, [0, 1], sample_duration=2)
    model.train(
        train_df,
        train_df,
        save_path="tests/models/",
        epochs=1,
        batch_size=2,
        save_interval=10,
        num_workers=0,
    )
    model.predict(train_df, num_workers=0)
    shutil.rmtree("tests/models/")
コード例 #14
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_init_with_str():
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
コード例 #15
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_train_no_validation(train_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=2)
    model.train(train_df, save_path="tests/models")
    shutil.rmtree("tests/models/")
コード例 #16
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_predict_missing_file_is_unsafe_sample(missing_file_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=5.0)
    scores, _, unsafe_samples = model.predict(missing_file_df, threshold=0.1)

    assert len(scores) == 0
    assert len(unsafe_samples) == 1
コード例 #17
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_eval(train_df):
    model = cnn.CNN("resnet18", classes=[0, 1], sample_duration=2)
    scores, _, _ = model.predict(train_df, split_files_into_clips=False)
    model.eval(train_df.values, scores.values)
コード例 #18
0
ファイル: test_cnn.py プロジェクト: kitzeslab/opensoundscape
def test_prediction_returns_consistent_values(train_df):
    model = cnn.CNN("resnet18", classes=["a", "b"], sample_duration=5.0)
    a, _, _ = model.predict(train_df)
    b, _, _ = model.predict(train_df)
    assert np.allclose(a.values, b.values, 1e-6)