def test_client_set_trial_throw(file='client_throw'):
    """Selecting a trial uid that does not exist must raise ``TrialDoesNotExist``.

    The original version only caught the exception, so the test also passed
    when nothing was raised at all; the ``else`` branch makes that a failure.
    """
    try:
        with Remove(file):
            client = TrackClient(f'file://{file}')
            client.set_trial(uid='does_not_exist')
    except TrialDoesNotExist:
        pass
    else:
        # Explicit raise (not ``assert False``) so the check survives ``python -O``.
        raise AssertionError(
            "set_trial(uid='does_not_exist') should raise TrialDoesNotExist"
        )
def test_client_capture_output(file='client_output'):
    """Check that ``capture_output(50)`` keeps only the most recent stdout writes.

    100 lines are printed. The assertions expect the captured buffer to hold
    alternating entries (the text, then ``'\\n'``) for the last 25 lines,
    i.e. ``testing_output_75`` .. ``testing_output_99``.
    """
    with Remove(file):
        client = TrackClient(f'file://{file}')
        client.set_project(name='project_name')

        trial_logger = client.new_trial()
        trial_logger.capture_output(50)

        printed = 100
        for n in range(printed):
            print(f'testing_output_{n}')

        captured = trial_logger.stdout.raw()
        kept = 25
        first_kept = printed - kept
        for k in range(kept):
            # Each print() is expected to yield two entries: the text, then the newline.
            assert captured[2 * k] == f'testing_output_{first_kept + k}'
            assert captured[2 * k + 1] == '\n'
def test_client_orion_integration(file='client_orion_integration'):
    """Check that a ``TrackClient`` resumes the trial named by Orion env vars.

    A first client creates a project/group/trial with ``batch_size=256`` and
    saves it. A second client is then built with ``ORION_PROJECT``,
    ``ORION_EXPERIMENT`` and ``ORION_TRIAL_ID`` set, and must expose the
    parameters of that same trial. The process environment is restored
    afterwards.
    """
    def scoped_init():
        # Create and persist a trial; return its uid for the env-var lookup.
        clt = TrackClient(f'file://{file}')
        clt.set_project(name='project_name')
        clt.set_group(name='group_name')
        clt.new_trial()
        clt.log_arguments(batch_size=256)
        clt.save()
        return clt.trial.uid

    old_environ = os.environ.copy()
    try:
        with Remove(file):
            trial_id = scoped_init()

            os.environ['ORION_PROJECT'] = 'project_name'
            os.environ['ORION_EXPERIMENT'] = 'group_name'
            os.environ['ORION_TRIAL_ID'] = trial_id

            client = TrackClient(f'file://{file}')
            assert client.trial.parameters['batch_size'] == 256
    finally:
        # Restore in place. Rebinding ``os.environ`` to the plain dict from
        # ``copy()`` would detach it from the real process environment: the
        # ORION_* variables would stay set for subprocesses, and later writes
        # to ``os.environ`` would no longer reach ``putenv``.
        os.environ.clear()
        os.environ.update(old_environ)
def test_client_log_api():
    """Exercise argument/metric logging through the logger returned by ``new_trial``.

    Ends with ``save()`` and ``report()`` on the client.
    """
    db_file = 'client_test.json'
    with Remove(db_file):
        client = TrackClient(f'file://{db_file}')
        client.set_project(name='test_client')

        trial_logger = client.new_trial()
        trial_logger.log_arguments(batch_size=256)
        trial_logger.log_metrics(step=1, epoch_loss=1)
        trial_logger.log_metrics(accuracy=0.98)

        client.save()
        client.report()
def _initialize_client(self, uri):
    """Configure this object from a track URI.

    Args:
        uri: track connection string; ``parse_uri(uri)["query"]`` must
            contain an ``objective`` entry.

    Raises:
        AssertionError: if the URI query defines no ``objective``.
    """
    self.uri = uri
    self.options = parse_uri(uri)["query"]
    self.objective = self.options.get("objective")
    # Explicit raise instead of ``assert`` so validation survives ``python -O``.
    # It runs before any client is created, so a misconfigured URI fails fast.
    # AssertionError (and its message) is kept for compatibility with callers.
    if self.objective is None:
        raise AssertionError("An objective should be defined!")

    self.client = TrackClient(uri)
    self.backend = self.client.protocol
    self.project = None
    self.group = None
def __init__(self, uri):
    """Connect to a track backend described by ``uri``.

    Args:
        uri: track connection string; ``parse_uri(uri)["query"]`` must
            contain an ``objective`` entry.

    Raises:
        ImportError: if ``HAS_TRACK`` is false (track could not be imported).
        AssertionError: if the URI query defines no ``objective``.
    """
    if not HAS_TRACK:
        # We ignored the import error above in case we did not need track
        # but now that we do we can rethrow it
        raise ImportError("Track is not installed!")

    self.uri = uri
    self.options = parse_uri(uri)["query"]
    self.objective = self.options.get("objective")
    # Explicit raise instead of ``assert`` so validation survives ``python -O``.
    # It runs before any client is created, so a misconfigured URI fails fast.
    # AssertionError (and its message) is kept for compatibility with callers.
    if self.objective is None:
        raise AssertionError("An objective should be defined!")

    self.client = TrackClient(uri)
    self.backend = self.client.protocol
    self.project = None
    self.group = None
    self.lies = {}
def test_trial(file='test.json'):
    """Check that two trials with different arguments get different hashes."""
    with Remove(file):
        # Use the same ``file://`` URI form as every other test here. The
        # previous ``file:`` form (no ``//``) was the only outlier and may not
        # write to the path that ``Remove(file)`` cleans up.
        client = TrackClient(f'file://{file}')
        client.set_project(name='ConvnetTest', description='Trail test example')
        client.set_group(name='test_group')

        logger1 = client.new_trial()
        client.get_arguments({'a': 1})
        uid1 = logger1.trial.hash

        # force=True requests a second, distinct trial in the same group.
        logger2 = client.new_trial(force=True)
        client.get_arguments({'a': 2})
        uid2 = logger2.trial.hash

        assert uid1 != uid2, 'Trials with different parameters must have different hash'
def scoped_init():
    """Create and save a trial logged with ``batch_size=256``; return its uid.

    NOTE(review): ``file`` is not a parameter. It is resolved from an
    enclosing or module scope, so confirm such a name exists wherever this
    function is defined.
    """
    tracker = TrackClient(f'file://{file}')
    tracker.set_project(name='project_name')
    tracker.set_group(name='group_name')
    tracker.new_trial()
    tracker.log_arguments(batch_size=256)
    tracker.save()
    created = tracker.trial
    return created.uid
def test_client_no_group(file='client_2'):
    """Log arguments and metrics on a trial without ever calling ``set_group``."""
    with Remove(file):
        tracker = TrackClient(f'file://{file}')
        tracker.set_project(name='test_client')
        trial_logger = tracker.new_trial()

        tracker.log_arguments(batch_size=256)
        tracker.log_metrics(step=1, epoch_loss=1)
        tracker.log_metrics(accuracy=0.98)

        tracker.save()
        tracker.report()

        logged = trial_logger.trial.metrics
        print(logged)