Esempio n. 1
0
def test_upload():
    """End-to-end test: save a publishable ConvTasNet and upload it to Zenodo.

    The upload half only runs when an ACCESS_TOKEN is present in the
    environment (it is not exposed on forks); the local fixture directory
    is cleaned up either way.
    """
    # Make dirs
    os.makedirs('tmp/publish_dir', exist_ok=True)
    populate_wham_dir('tmp/wham')

    # Dataset and NN — deliberately tiny hyperparameters to keep the test fast
    train_set = WhamDataset('tmp/wham', task='sep_clean')
    model = ConvTasNet(n_src=2, n_repeats=2, n_blocks=2, bn_chan=16,
                       hid_chan=4, skip_chan=8, n_filters=32)

    # Save publishable
    model_conf = model.serialize()
    model_conf.update(train_set.get_infos())
    save_publishable('tmp/publish_dir', model_conf, metrics={}, train_conf={})

    # Upload. Guard on the token instead of a dead `if False:` branch,
    # matching the token-guarded upload test elsewhere in this file.
    token = os.getenv("ACCESS_TOKEN")
    if token:  # ACCESS_TOKEN is not available on forks.
        zen, current = upload_publishable(
            'tmp/publish_dir',
            uploader="Manuel Pariente",
            affiliation="INRIA",
            use_sandbox=True,
            unit_test=True,  # Remove this argument and monkeypatch `input()`
        )

        # Assert metadata is correct
        meta = current.json()['metadata']
        assert meta['creators'][0]['name'] == "Manuel Pariente"
        assert meta['creators'][0]['affiliation'] == "INRIA"
        assert 'asteroid-models' in [d['identifier'] for d in meta['communities']]

        # Clean up the remote deposition
        zen.remove_deposition(current.json()['id'])
    # Always remove the local fixture, even when the upload was skipped.
    shutil.rmtree('tmp/wham')
Esempio n. 2
0
def test_save_and_load_convtasnet(fb):
    """Serializing a ConvTasNet and reloading it must reproduce its output.

    Builds a small model with the given filterbank `fb`, round-trips it
    through `serialize`/`from_pretrained`, and checks both copies separate
    the same random input identically.
    """
    config = dict(
        n_src=2,
        n_repeats=2,
        n_blocks=2,
        bn_chan=16,
        hid_chan=4,
        skip_chan=8,
        n_filters=32,
        fb_name=fb,
    )
    model1 = ConvTasNet(**config)
    test_input = torch.randn(1, 800)
    model_conf = model1.serialize()

    reconstructed_model = ConvTasNet.from_pretrained(model_conf)
    assert_allclose(
        model1.separate(test_input), reconstructed_model(test_input)
    )
Esempio n. 3
0
def test_upload():
    """End-to-end test: save a publishable ConvTasNet and upload it to Zenodo.

    The upload half only runs when an ACCESS_TOKEN is present in the
    environment (it is not exposed on forks); the local fixture directory
    is cleaned up either way.
    """
    # Make dirs
    os.makedirs("tmp/publish_dir", exist_ok=True)
    populate_wham_dir("tmp/wham")

    # Dataset and NN — deliberately tiny hyperparameters to keep the test fast
    train_set = WhamDataset("tmp/wham", task="sep_clean")
    model = ConvTasNet(n_src=2,
                       n_repeats=2,
                       n_blocks=2,
                       bn_chan=16,
                       hid_chan=4,
                       skip_chan=8,
                       n_filters=32)

    # Save publishable
    model_conf = model.serialize()
    model_conf.update(train_set.get_infos())
    save_publishable("tmp/publish_dir", model_conf, metrics={}, train_conf={})

    # Upload
    token = os.getenv("ACCESS_TOKEN")
    if token:  # ACCESS_TOKEN is not available on forks.
        zen, current = upload_publishable(
            "tmp/publish_dir",
            uploader="Manuel Pariente",
            affiliation="INRIA",
            use_sandbox=True,
            unit_test=True,  # Remove this argument and monkeypatch `input()`
            git_username="******",
        )

        # Assert metadata is correct
        meta = current.json()["metadata"]
        assert meta["creators"][0]["name"] == "Manuel Pariente"
        assert meta["creators"][0]["affiliation"] == "INRIA"
        assert "asteroid-models" in [
            d["identifier"] for d in meta["communities"]
        ]

        # Clean up the remote deposition
        zen.remove_deposition(current.json()["id"])
    # Bug fix: this cleanup used to sit inside the `if token:` branch,
    # leaking tmp/wham on forks. Always remove the local fixture.
    shutil.rmtree("tmp/wham")