def test_sync_etc_multiple_messages(mocker, run_manager):
    """Two NUL-terminated save_policy messages sent in one write are both handled.

    The run manager's ``update_user_file_policy`` is replaced with a mock, and
    the test expects it to be invoked once per message (two calls total).
    """
    message = {"save_policy": {"glob": "*.foo", "policy": "end"}}
    policy_spy = mocker.MagicMock()
    run_manager.update_user_file_policy = policy_spy

    encoded = json.dumps(message).encode("utf8")
    # Each message is terminated by a NUL byte; both go out in a single sendall.
    framed = b"".join([encoded, b"\0", encoded, b"\0"])
    wandb.run.socket.connection.sendall(framed)

    run_manager.test_shutdown()
    assert len(policy_spy.mock_calls) == 2
def test_throttle_file_poller(mocker, run_manager):
    """Writing many files into the run dir should raise the emitter timeout.

    The emitter starts with ``timeout == 1``; after 100 files are created and
    the run manager is shut down, the test expects ``timeout == 2``.
    """
    file_emitter = run_manager.emitter
    assert file_emitter.timeout == 1

    for idx in range(100):
        target = os.path.join(wandb.run.dir, "file_%i.txt" % idx)
        with open(target, "w") as handle:
            handle.write(str(idx))

    run_manager.test_shutdown()
    assert file_emitter.timeout == 2
def test_custom_file_policy(mocker, run_manager):
    """Files matched by ``wandb.save("ckpt*")`` get a throttled handler.

    After shutdown, ``ckpt_0.txt`` is expected to be tracked by
    FileEventHandlerThrottledOverwriteMinWait, while ``wandb-metadata.json``
    keeps a FileEventHandlerOverwriteDeferred handler.
    """
    for n in range(5):
        checkpoint = os.path.join(wandb.run.dir, "ckpt_%i.txt" % n)
        with open(checkpoint, "w") as fh:
            fh.write(str(n))

    wandb.save("ckpt*")
    run_manager.test_shutdown()

    handlers = run_manager._file_event_handlers
    assert isinstance(
        handlers["ckpt_0.txt"], FileEventHandlerThrottledOverwriteMinWait)
    assert isinstance(
        handlers["wandb-metadata.json"], FileEventHandlerOverwriteDeferred)
def test_file_pusher_archives_multiple(mocker, run_manager, mock_server):
    """Test that 10 saved files are batched into a single archive upload.

    Ten checkpoint files are written and passed to ``wandb.save``. After
    shutdown, the first GraphQL request carrying ``files`` is expected to be
    a ``Model`` query for run ``testing`` that lists only the batch archive.
    """
    for i in range(10):
        fname = "ckpt_{}.txt".format(i)
        with open(fname, "w") as f:
            f.write("w&b" * 100)
        wandb.save(fname)
    run_manager.test_shutdown()

    file_requests = [
        r for r in mock_server.ctx['graphql'] if 'files' in r['variables']
    ]
    # Fail with a descriptive message rather than an IndexError when no
    # file-upload request was recorded.
    assert file_requests, "expected a GraphQL request with 'files' variables"
    req = file_requests[0]
    assert 'query Model' in req['query']
    assert req['variables']['name'] == 'testing'
    assert req['variables']['files'] == ['___batch_archive_1.tgz']
def test_custom_file_policy_symlink(mocker, run_manager):
    """A saved file that is rewritten afterwards reaches the throttled handler.

    ``ckpt_0.txt`` is saved, then rewritten with larger content before
    ``ckpt_1.txt`` is saved. The test expects ``ckpt_0.txt`` to be tracked by
    FileEventHandlerThrottledOverwriteMinWait and that handler class's
    (patched) ``on_modified`` to have been called.

    NOTE(review): nothing here creates a symlink directly; presumably
    ``wandb.save`` links files into the run dir -- confirm the test name.
    """
    on_modified_spy = mocker.MagicMock()
    mocker.patch(
        'wandb.run_manager.FileEventHandlerThrottledOverwriteMinWait.on_modified',
        on_modified_spy)

    def _write_text(name, text):
        # Overwrite *name* in the current directory with *text*.
        with open(name, "w") as fh:
            fh.write(text)

    _write_text("ckpt_0.txt", "joy")
    _write_text("ckpt_1.txt", "joy" * 100)
    wandb.save("ckpt_0.txt")
    # Rewrite the already-saved file, presumably so the watcher sees a
    # modification on it.
    _write_text("ckpt_0.txt", "joy" * 100)
    wandb.save("ckpt_1.txt")
    run_manager.test_shutdown()

    assert isinstance(
        run_manager._file_event_handlers["ckpt_0.txt"],
        FileEventHandlerThrottledOverwriteMinWait)
    assert on_modified_spy.called
def test_file_pusher_doesnt_archive_if_few(mocker, run_manager, mock_server):
    """Test that a couple of saved files (2) are uploaded individually, unbatched.

    BATCH_MIN_FILES is raised to 10 so the two checkpoints, plus whatever
    extra files every run uploads, stay below the batching minimum. Every
    filename in every GraphQL upload request must then be a non-archive file.
    """
    # Mock to increase the minimum, since some extra files are included with
    # all uploads and would push the count past the default minimum (6 per
    # the original note -- NOTE(review): confirm against FilePusher).
    from wandb.file_pusher import FilePusher
    mocker.patch.object(FilePusher, 'BATCH_THRESHOLD_SECS', 0.3)
    mocker.patch.object(FilePusher, 'BATCH_MIN_FILES', 10)
    for i in range(2):
        fname = "ckpt_{}.txt".format(i)
        with open(fname, "w") as f:
            f.write("w&b" * 100)
        wandb.save(fname)
    run_manager.test_shutdown()

    # Inspect every file of every upload request, not just the first one,
    # so an archive listed later in a request is not missed.
    filenames = [
        filename
        for r in mock_server.requests['graphql']
        if 'files' in r['variables']
        for filename in r['variables']['files']
    ]
    # Guard against a vacuous pass: all() over no uploads is trivially True.
    assert filenames, "expected at least one file upload request"
    # assert there is no batching
    assert all('.tgz' not in filename for filename in filenames)
def test_remove_auto_resume(mocker, run_manager):
    """An auto-resume file under ``wandb.wandb_dir()`` is gone after shutdown.

    The file named RESUME_FNAME is seeded with ``{}``; after
    ``run_manager.test_shutdown()`` it is expected to no longer exist.
    """
    auto_resume_file = os.path.join(wandb.wandb_dir(), RESUME_FNAME)
    with open(auto_resume_file, "w") as fh:
        fh.write("{}")

    run_manager.test_shutdown()

    still_present = os.path.exists(auto_resume_file)
    assert not still_present