def scoped_init():
    """Create and persist a trial under a fixed project/group; return its uid.

    The backend URI is built from a ``file`` name that is resolved from an
    enclosing or module scope not visible here (TODO confirm where it is bound).
    The trial is logged with ``batch_size=256`` before being saved.
    """
    client = TrackClient(f'file://{file}')

    client.set_project(name='project_name')
    client.set_group(name='group_name')

    client.new_trial()
    client.log_arguments(batch_size=256)
    client.save()

    trial_uid = client.trial.uid
    return trial_uid
def test_trial():
    """Trials created with different arguments must end up with different hashes."""
    client = TrackClient('file:test.json')
    client.set_project(name='ConvnetTest', description='Trail test example')
    client.set_group(name='test_group')

    def _trial_hash(args, **new_trial_kwargs):
        # Open a trial, feed it `args`, then read back the trial's hash.
        trial_logger = client.new_trial(**new_trial_kwargs)
        client.get_arguments(args)
        return trial_logger.trial.hash

    first_hash = _trial_hash({'a': 1})
    # force=True is passed for the second trial (as in the original test).
    second_hash = _trial_hash({'a': 2}, force=True)

    assert first_hash != second_hash, 'Trials with different parameters must have different hash'
def test_client_log_api():
    """Log arguments and metrics through the logger object returned by ``new_trial``."""
    with Remove('client_test.json'):
        client = TrackClient('file://client_test.json')
        client.set_project(name='test_client')

        trial_logger = client.new_trial()
        trial_logger.log_arguments(batch_size=256)

        # First call passes a `step` keyword alongside the metric; the second does not.
        metric_batches = (
            {'step': 1, 'epoch_loss': 1},
            {'accuracy': 0.98},
        )
        for metric_kwargs in metric_batches:
            trial_logger.log_metrics(**metric_kwargs)

        client.save()
        client.report()
def test_client_no_group(file='client_2'):
    """Exercise the client-level logging helpers for a project with no group set.

    Args:
        file: name used for the ``file://`` backend; removed via ``Remove``.
    """
    with Remove(file):
        client = TrackClient(f'file://{file}')
        client.set_project(name='test_client')
        trial_logger = client.new_trial()

        client.log_arguments(batch_size=256)
        for metric_kwargs in ({'step': 1, 'epoch_loss': 1}, {'accuracy': 0.98}):
            client.log_metrics(**metric_kwargs)

        client.save()
        client.report()

        recorded_metrics = trial_logger.trial.metrics
        print(recorded_metrics)
def test_client_capture_output(file='client_output'):
    """Check that ``capture_output(50)`` keeps only the most recent stdout writes.

    Each ``print`` call writes the text and then ``'\\n'`` as separate writes,
    so the assertions below expect the captured buffer to hold the last 25
    printed lines, alternating text and newline entries.

    Args:
        file: name used for the ``file://`` backend; removed via ``Remove``.
    """
    import sys

    # capture_output installs a capturing wrapper as the process-wide stdout
    # (and, presumably, stderr). Nothing in this test undoes that, so save the
    # streams and restore them afterwards. Otherwise every later test would
    # keep printing into this trial's buffer.
    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    try:
        with Remove(file):
            client = TrackClient(f'file://{file}')
            client.set_project(name='project_name')
            logger = client.new_trial()
            logger.capture_output(50)

            for i in range(100):
                print(f'testing_output_{i}')

            out = logger.stdout.raw()
            for i in range(25):
                assert out[i * 2] == f'testing_output_{100 - 25 + i}'
                assert out[i * 2 + 1] == '\n'
    finally:
        sys.stdout, sys.stderr = saved_stdout, saved_stderr