def test_upload():
    """Check that a tiny ConvTasNet can be packaged with ``save_publishable``.

    Generated WHAM files under ``tmp/wham`` provide the dataset infos that are
    merged into the model config, which is then written to
    ``tmp/publish_dir``. When the ``ACCESS_TOKEN`` environment variable is
    set, the package is also uploaded with ``use_sandbox=True``. The returned
    metadata is checked, and the created deposition is removed again.

    NOTE(review): a second ``test_upload`` is defined later in this SOURCE; if
    both live in the same module, pytest only collects the later one.
    """
    # Make dirs
    os.makedirs('tmp/publish_dir', exist_ok=True)
    populate_wham_dir('tmp/wham')
    try:
        # Dataset and NN
        train_set = WhamDataset('tmp/wham', task='sep_clean')
        model = ConvTasNet(n_src=2, n_repeats=2, n_blocks=2, bn_chan=16,
                           hid_chan=4, skip_chan=8, n_filters=32)
        # Save publishable
        model_conf = model.serialize()
        model_conf.update(train_set.get_infos())
        save_publishable('tmp/publish_dir', model_conf, metrics={},
                         train_conf={})
        # Upload only when an access token is available; otherwise the
        # network part of the test is skipped.
        token = os.getenv('ACCESS_TOKEN')
        if token:
            zen, current = upload_publishable(
                'tmp/publish_dir',
                uploader="Manuel Pariente",
                affiliation="INRIA",
                use_sandbox=True,
                unit_test=True,  # Remove this argument and monkeypatch `input()`
            )
            try:
                # Assert metadata is correct
                meta = current.json()['metadata']
                assert meta['creators'][0]['name'] == "Manuel Pariente"
                assert meta['creators'][0]['affiliation'] == "INRIA"
                assert 'asteroid-models' in [
                    d['identifier'] for d in meta['communities']
                ]
            finally:
                # Remove the remote deposition even if an assertion failed,
                # so failed runs do not leave uploads behind.
                zen.remove_deposition(current.json()['id'])
    finally:
        # Remove the generated WHAM files whether or not the test passed.
        shutil.rmtree('tmp/wham')
def test_save_and_load_convtasnet(fb):
    """Check that a ConvTasNet rebuilt from its serialized config matches the original.

    Args:
        fb: Filterbank name forwarded as ``fb_name``. Presumably supplied by a
            pytest parametrization that is not visible here; TODO confirm.
    """
    model1 = ConvTasNet(n_src=2, n_repeats=2, n_blocks=2, bn_chan=16,
                        hid_chan=4, skip_chan=8, n_filters=32, fb_name=fb)
    test_input = torch.randn(1, 800)
    model_conf = model1.serialize()
    reconstructed_model = ConvTasNet.from_pretrained(model_conf)
    # Run both models through the same call path. That way the test checks
    # the serialize / from_pretrained round-trip itself, rather than any
    # difference between ``separate`` and ``forward``.
    with torch.no_grad():
        assert_allclose(model1(test_input), reconstructed_model(test_input))
def test_upload():
    """Check packaging with ``save_publishable`` and, when possible, uploading.

    Generated WHAM files under ``tmp/wham`` provide the dataset infos that are
    merged into a tiny ConvTasNet's config, which is then written to
    ``tmp/publish_dir``. When the ``ACCESS_TOKEN`` environment variable is
    set, the package is uploaded with ``use_sandbox=True``. The returned
    metadata is checked, and the created deposition is removed again.
    """
    # Make dirs
    os.makedirs("tmp/publish_dir", exist_ok=True)
    populate_wham_dir("tmp/wham")
    try:
        # Dataset and NN
        train_set = WhamDataset("tmp/wham", task="sep_clean")
        model = ConvTasNet(
            n_src=2,
            n_repeats=2,
            n_blocks=2,
            bn_chan=16,
            hid_chan=4,
            skip_chan=8,
            n_filters=32,
        )
        # Save publishable
        model_conf = model.serialize()
        model_conf.update(train_set.get_infos())
        save_publishable("tmp/publish_dir", model_conf, metrics={}, train_conf={})

        # Upload
        token = os.getenv("ACCESS_TOKEN")
        if token:  # ACCESS_TOKEN is not available on forks.
            zen, current = upload_publishable(
                "tmp/publish_dir",
                uploader="Manuel Pariente",
                affiliation="INRIA",
                use_sandbox=True,
                unit_test=True,  # Remove this argument and monkeypatch `input()`
                git_username="******",
            )
            try:
                # Assert metadata is correct
                meta = current.json()["metadata"]
                assert meta["creators"][0]["name"] == "Manuel Pariente"
                assert meta["creators"][0]["affiliation"] == "INRIA"
                assert "asteroid-models" in [
                    d["identifier"] for d in meta["communities"]
                ]
            finally:
                # Remove the remote deposition even if an assertion failed,
                # so failed runs do not leave uploads behind.
                zen.remove_deposition(current.json()["id"])
    finally:
        # Remove the generated WHAM files whether or not the test passed.
        shutil.rmtree("tmp/wham")