def e2e_socketed(client=1, security_layer=None):
    remove('socketed.json')

    from track.persistence.socketed import start_track_server
    from multiprocessing import Process
    import socket

    # Bind to port 0 so the OS picks a free port for the server
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    port = s.getsockname()[1]
    s.close()

    # (protocol, hostname, port)
    db = Process(
        target=start_track_server,
        args=('file://socketed.json', 'localhost', port, security_layer))
    db.start()
    time.sleep(1)

    security = ''
    if security_layer is not None:
        security = f'?security_layer={security_layer}'

    try:
        uri = [f'socket://localhost:{port}' + security] * client
        multi_client_launch(uri, client)

    except Exception as e:
        db.terminate()
        raise e

    finally:
        db.terminate()
        remove('socketed.json')
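# The tests in this file rely on a `multi_client_launch` helper that is not shown in
# this excerpt. The sketch below is an assumed implementation, inferred only from how
# it is called here (either a list of URIs, or a single URI plus a client count): it
# spawns one training client per URI and waits for all of them. The actual helper in
# the test suite may differ.
def _multi_client_launch_sketch(uri, count):
    from multiprocessing import Process

    # Accept both a single URI string and a pre-built list of URIs
    if isinstance(uri, str):
        uri = [uri] * count

    clients = [Process(target=end_to_end_train, args=(u,)) for u in uri]
    [c.start() for c in clients]
    [c.join() for c in clients]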
def test_orion_poc(backend='track:file://orion_results.json?objective=epoch_loss', max_trials=2):
    remove('orion_results.json')

    os.environ['ORION_STORAGE'] = backend
    _, uri = os.environ.get('ORION_STORAGE').split(':', maxsplit=1)

    cwd = os.getcwd()
    os.chdir(os.path.dirname(__file__))

    multiple_of_8 = [8 * i for i in range(32 // 8, 512 // 8)]

    orion.core.cli.main([
        '-vv', '--debug', 'hunt', '--config', 'orion.yaml', '-n', 'random',
        # '--metric', 'error_rate',
        '--max-trials', str(max_trials),
        './end_to_end.py',
        f'--batch-size~choices({multiple_of_8})',
        '--backend', uri
    ])

    os.chdir(cwd)
    remove('orion_results.json')
def test_local_parallel(worker_count=20):
    """Check that `_update_count` is atomic and cannot run out of sync.

    The user-level `count` metadata can go out of sync, because its
    read-modify-write (first fetch, then increment) does not happen inside the lock.
    """
    global trial_hash, trial_rev

    # -- Create the objects that are going to be accessed in parallel
    remove('test_parallel.json')
    backend = make_local('file://test_parallel.json')

    project_def = Project(name='test')
    project = backend.new_project(project_def)

    group_def = TrialGroup(name='test_group', project_id=project.uid)
    group = backend.new_trial_group(group_def)

    trial = backend.new_trial(
        Trial(
            parameters={'batch': 256},
            project_id=project.uid,
            group_id=group.uid))

    count = trial.metadata.get('count', 0)
    backend.log_trial_metadata(trial, count=count)

    trial_hash, trial_rev = trial.uid.split('_')
    # -- Setup done

    processes = [Process(target=increment) for _ in range(0, worker_count)]

    print('-- Start')
    [p.start() for p in processes]
    [p.join() for p in processes]

    trial = backend.get_trial(trial)[0]
    # remove('test_parallel.json')

    print(trial.metadata)
    assert trial.metadata.get('_update_count', 0) == worker_count + 1, \
        'Parallel writes should wait for each other'
def test_e2e_server_socket():
    remove('server_test.json')

    if SKIP_SERVER:
        return

    import socket

    # Bind to port 0 so the OS picks a free port for the server
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(('', 0))
    port = s.getsockname()[1]
    s.close()

    print('Starting Server')
    # Pass the callable and its arguments to Process; calling
    # start_track_server() directly here would block the test
    server = Process(
        target=start_track_server,
        args=('file://server_test.json', 'localhost', port))
    server.start()

    print('Starting client')
    try:
        end_to_end_train(f'socket://*****:*****@localhost:{port}')
    finally:
        server.terminate()
        remove('server_test.json')
def test_e2e_file():
    remove('file_test.json')
    end_to_end_train('file://file_test.json')
    remove('file_test.json')
def test_e2e_pickled(count=2):
    remove('file.pkl')
    end_to_end_train('pickled://file.pkl')
    remove('file.pkl')
def test_e2e_pickled_2clients(count=2):
    remove('file.pkl')
    multi_client_launch('pickled://file.pkl', count)
    remove('file.pkl')
@pytest.mark.skipif(is_travis(), reason='Travis is too slow')
def test_e2e_ephemeral_2clients(count=2):
    multi_client_launch('ephemeral:', count)


@pytest.mark.skipif(is_travis(), reason='Travis is too slow')
def test_e2e_mongodb(count=2):
    # Credentials and host are redacted in the source
    end_to_end_train('mongodb://*****:*****@...')


@pytest.mark.skipif(is_travis(), reason='Travis is too slow')
def test_e2e_pickled(count=2):
    remove('file.pkl')
    end_to_end_train('pickled://file.pkl')
    remove('file.pkl')


@pytest.mark.skipif(is_travis(), reason='Travis is too slow')
def test_e2e_ephemeral(count=2):
    end_to_end_train('ephemeral:')


if __name__ == '__main__':
    # test_e2e_mongodb_2clients()
    # test_e2e_pickled_2clients()
    # test_e2e_ephemeral_2clients()
    remove('file.pkl')