Пример #1
0
    def setUpClass(self):
        """Start a local ``studio-ui`` API server for the tests in this class.

        Does nothing when ``has_aws_credentials()`` is falsy. Otherwise picks
        a random port in [5000, 9000], records the paths of the server and
        client YAML configs that live next to this file, launches
        ``studio-ui`` as a subprocess (kept in ``self.serverp``) and sleeps
        25 seconds to give it time to come up.

        NOTE(review): the first parameter is named ``self`` but unittest
        invokes setUpClass on the class; presumably a ``@classmethod``
        decorator sits just outside this view -- confirm.
        """
        if not has_aws_credentials():
            return
        # print() as a function call: the bare ``print "..."`` statement is
        # a SyntaxError on Python 3 (and this form also works on Python 2).
        print("Starting up the API server")
        self.port = randint(5000, 9000)

        test_dir = os.path.dirname(os.path.realpath(__file__))
        self.server_config_file = os.path.join(
            test_dir, 'test_config_http_server.yaml')
        self.client_config_file = os.path.join(
            test_dir, 'test_config_http_client.yaml')

        self.serverp = subprocess.Popen([
            'studio-ui',
            '--port=' + str(self.port),
            '--verbose=debug',
            '--config=' + self.server_config_file,
            '--host=localhost'])

        # Fixed grace period for the server to start accepting requests;
        # there is no readiness probe here.
        time.sleep(25)
Пример #2
0
    def test_two_experiments_apiserver(self):
        """Run two experiments against the API server, shipping workspace files.

        Every regular file in this test's directory is passed as a task file,
        together with ``_file_url`` under key ``'url'`` and, when AWS
        credentials are available, ``_file_s3`` under key ``'s3'``.
        """
        here = os.path.dirname(os.path.realpath(__file__))
        client_config = os.path.join(here, '..', 'tests',
                                     'test_config_http_client.yaml')

        files = {}
        for name in os.listdir(here):
            full_path = os.path.join(here, name)
            if os.path.isfile(full_path):
                files[name] = full_path

        files['url'] = _file_url
        if has_aws_credentials():
            files['s3'] = _file_s3

        self._run_test_files(n_experiments=2, files=files,
                             config=client_config)
Пример #3
0
 def test_on_enviornment(self):
     """Fail unless AWS credentials are detected in this environment."""
     credentials_present = has_aws_credentials()
     self.assertTrue(credentials_present)
Пример #4
0
                                            'test_config_gcloud_storage.yaml')
        self.assertTrue(isinstance(store, GCloudArtifactStore))
        return store

    def get_qualified_location_prefix(self):
        """Return the ``gs://<bucket>/`` prefix for the store under test."""
        bucket = self.get_store().get_bucket()
        prefix = "gs://" + bucket
        return prefix + "/"


@unittest.skipUnless(on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Check that AWS credentials exist when the user says we run on AWS.

    Skipped unless ``on_aws()`` is truthy.
    """

    def test_on_enviornment(self):
        aws_ready = has_aws_credentials()
        self.assertTrue(aws_ready)


@unittest.skipIf(_get_provider() and not has_aws_credentials(),
                 'Skipping due to userinput or AWS Not detected')
class S3ArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    """Run the shared ArtifactStoreTest suite with 'test_config.yaml'.

    NOTE(review): the suite is skipped only when ``_get_provider()`` is
    truthy *and* AWS credentials are missing; confirm that running without
    credentials when ``_get_provider()`` is falsy is intended.
    """

    def get_store(self, config_name=None):
        # ``config_name`` is accepted for interface compatibility but
        # ignored: this suite always uses 'test_config.yaml'.
        artifact_store = ArtifactStoreTest.get_store(self, 'test_config.yaml')
        self.assertTrue(isinstance(artifact_store, TartifactStore))
        return artifact_store

    def get_qualified_location_prefix(self):
        return self.get_store().get_qualified_location("")


if __name__ == "__main__":
    # Entry point when this module is executed directly rather than imported.
    unittest.main()
Пример #5
0
        with stubtest_worker(self,
                             experiment_name=experiment_name,
                             runner_args=[
                                 '--cloud=gcspot', '--force-git',
                                 '--cloud-timeout=120',
                                 '--container=shub://vsoch/hello-world'
                             ],
                             config_name='test_config_http_client.yaml',
                             test_script='',
                             script_args=[],
                             expected_output='RaawwWWWWWRRRR!!',
                             test_workspace=False):
            pass


@unittest.skipIf(not has_aws_credentials(),
                 'boto3 not present, won\'t be able to use AWS API')
class EC2WorkerTest(unittest.TestCase):
    _multiprocess_shared_ = True

    def get_worker_manager(self):
        """Build the worker manager exercised by this test class."""
        manager = EC2WorkerManager()
        return manager

    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_worker(self):
        experiment_name = 'test_ec2_worker_' + str(uuid.uuid4())
        with stubtest_worker(
                self,
                experiment_name=experiment_name,
                runner_args=[
                    '--cloud=ec2', '--force-git', '--gpus=1',
Пример #6
0
                 'GOOGLE_APPLICATION_CREDENTIALS environment ' +
                 'variable not set, won'
                 ' be able to use google cloud')
class GCloudArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    """Run the shared ArtifactStoreTest suite against Google Cloud Storage."""

    def get_store(self, config_name=None):
        # ``config_name`` is ignored; the GCS config file is fixed.
        gcs_store = ArtifactStoreTest.get_store(
            self, 'test_config_gcloud_storage.yaml')
        self.assertTrue(isinstance(gcs_store, GCloudArtifactStore))
        return gcs_store

    def get_qualified_location_prefix(self):
        bucket_name = self.get_store().get_bucket()
        return "gs://" + bucket_name + "/"


@unittest.skipIf(not has_aws_credentials(),
                 "AWS credentials not found, won't be able to use S3")
class S3ArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    """Run the shared ArtifactStoreTest suite against S3.

    Skipped when ``has_aws_credentials()`` is falsy.
    """

    def get_store(self, config_name=None):
        # ``config_name`` is ignored; this suite always uses the S3 config.
        store = ArtifactStoreTest.get_store(self,
                                            'test_config_s3_storage.yaml')
        self.assertTrue(isinstance(store, S3ArtifactStore))
        return store

    def get_qualified_location_prefix(self):
        """Return ``s3://<endpoint host>/<bucket>/`` for the store under test.

        The endpoint is read from a default boto3 S3 client through the
        public ``client.meta.endpoint_url`` attribute instead of the private
        ``client._endpoint``.
        """
        store = self.get_store()
        endpoint = urlparse(boto3.client('s3').meta.endpoint_url)
        return "s3://" + endpoint.netloc + "/" + store.bucket + "/"

Пример #7
0
class CompletionServiceTest(unittest.TestCase):
    """End-to-end tests for CompletionService.

    Each test submits job scripts that live next to this file (see
    ``get_jobfile``) and compares what ``getResults`` returns with the
    expected values. Cloud-backed tests are skipped when the corresponding
    credentials are not available.
    """

    def _run_test(self,
                  args=None,
                  files=None,
                  jobfile=None,
                  expected_results=None,
                  **csargs):
        """Submit one task per element of ``args`` and verify each result.

        :param args: per-task arguments; defaults to ``[0, 1]``.
        :param files: mapping passed to ``submitTaskWithFiles``; defaults to
            a fresh empty dict on every call.
        :param jobfile: job script name in this directory; defaults to
            ``completion_service_testfunc.py``.
        :param expected_results: expected result per task, indexed like
            ``args``; defaults to ``args`` itself.
        :param csargs: keyword arguments forwarded to ``CompletionService``.
            When none are given the call is a no-op.
        """
        if not any(csargs):
            return

        files = {} if files is None else files
        jobfile = self.get_jobfile(jobfile or 'completion_service_testfunc.py')

        args = args or [0, 1]

        expected_results = expected_results or args
        submission_indices = {}
        n_experiments = len(args)
        experimentId = str(uuid.uuid4())

        with CompletionService(experimentId, **csargs) as cs:
            for i in range(n_experiments):
                key = cs.submitTaskWithFiles(jobfile, args[i], files)
                submission_indices[key] = i

            # Results can arrive in any order; map each returned key back to
            # the index of the task that produced it.
            for _ in range(n_experiments):
                key, value = cs.getResults(blocking=True)
                self.assertEqual(
                    value, expected_results[submission_indices[key]])

    def _run_test_files(self, files, n_experiments=2, **csargs):
        """Run file-hashing tasks that all receive ``files``.

        Task ``i`` gets argument ``i`` and is expected to return
        ``(i, hashes)`` where ``hashes`` comes from ``_get_file_hashes``.
        """
        expected_results = [(i, self._get_file_hashes(files))
                            for i in range(n_experiments)]
        args = range(n_experiments)
        self._run_test(args=args,
                       files=files,
                       jobfile='completion_service_testfunc_files.py',
                       expected_results=expected_results,
                       **csargs)

    def _run_test_myfiles(self, n_experiments=2, **csargs):
        """Like ``_run_test_files`` using every regular file in this dir.

        ``_file_url`` is added under key ``'url'``.
        """
        mypath = os.path.dirname(os.path.realpath(__file__))

        files = {
            f: os.path.join(mypath, f)
            for f in os.listdir(mypath)
            if os.path.isfile(os.path.join(mypath, f))
        }

        files['url'] = _file_url

        # TODO peterz enable passing aws credentials to google workers
        # if has_aws_credentials():
        #     files['s3'] = _file_s3

        self._run_test_files(files, n_experiments=n_experiments, **csargs)

    def _get_file_hashes(self, files):
        """Return ``{key: md5 file hash}`` for every entry of ``files``.

        Values containing ``://`` are downloaded to a temporary file which is
        hashed and then removed (also when downloading or hashing fails);
        other values are hashed as local paths.
        """
        retval = {}
        for k, v in six.iteritems(files):
            if '://' in v:
                tmpfilename = os.path.join(tempfile.gettempdir(),
                                           rand_string(10))
                try:
                    download_file(v, tmpfilename)
                    retval[k] = filehash(tmpfilename, hashobj=hashlib.md5())
                finally:
                    if os.path.exists(tmpfilename):
                        os.remove(tmpfilename)
            else:
                retval[k] = filehash(v, hashobj=hashlib.md5())

        return retval

    @unittest.skipIf(not has_aws_credentials(),
                     'AWS credentials needed for this test')
    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_ec2(self):
        self._run_test(config=self.get_config_path(),
                       cloud_timeout=100,
                       cloud='ec2')

    @unittest.skipIf(not has_aws_credentials(),
                     'AWS credentials needed for this test')
    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_ec2spot(self):
        self._run_test_myfiles(
            n_experiments=2,
            config=self.get_config_path(),
            cloud_timeout=100,
            cloud='ec2spot',
        )

    @timeout(LOCAL_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_apiserver(self):
        self._run_test_myfiles(n_experiments=2,
                               config=self.get_config_path(),
                               cloud_timeout=LOCAL_TEST_TIMEOUT)

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(),
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_gcspot(self):
        self._run_test_myfiles(n_experiments=2,
                               config=self.get_config_path(),
                               cloud='gcspot')

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS_DC'
                     not in os.environ.keys(),
                     'Need GOOGLE_APPLICATION_CREDENTIALS_DC env variable to ' +
                     'use google cloud')
    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_datacenter(self):
        """Run against the datacenter config using the _DC credentials."""
        oldcred = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = \
            os.environ['GOOGLE_APPLICATION_CREDENTIALS_DC']

        try:
            queue_name = 'test_queue_' + str(uuid.uuid4())

            self._run_test_myfiles(
                config=self.get_config_path('test_config_datacenter.yaml'),
                queue=queue_name,
                shutdown_del_queue=True)
        finally:
            # Restore the previous credentials even if the test fails, and
            # unset the variable if it was not set before.
            if oldcred is not None:
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = oldcred
            else:
                os.environ.pop('GOOGLE_APPLICATION_CREDENTIALS', None)

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(),
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_two_experiments_gcloud(self):
        self._run_test_myfiles(n_experiments=2,
                               config=self.get_config_path(),
                               cloud='gcloud')

    @timeout(LOCAL_TEST_TIMEOUT, use_signals=False)
    def test_studiolink(self):
        """A second task can load the model dir saved by the first task."""
        experiment_id = str(uuid.uuid4())
        arg1 = random.randint(0, 10000)

        jobfile = self.get_jobfile('completion_service_testfunc_saveload.py')
        with CompletionService(experiment_id,
                               config=self.get_config_path(),
                               cloud_timeout=LOCAL_TEST_TIMEOUT) as cs:
            key1 = cs.submitTask(jobfile, arg1)
            ret_key1, result1 = cs.getResults()
            self.assertEqual(key1, ret_key1)
            self.assertEqual(result1, arg1)

            files = {'model': 'studio://{}/modeldir'.format(key1)}

            key2 = cs.submitTaskWithFiles(jobfile, None, files=files)
            ret_key2, result2 = cs.getResults()
            self.assertEqual(key2, ret_key2)
            self.assertEqual(result2, arg1 + 1)

    @timeout(LOCAL_TEST_TIMEOUT, use_signals=False)
    def test_restart(self):
        """Resubmitting with the same job_id continues from saved state."""
        experiment_id = str(uuid.uuid4())
        arg1 = random.randint(0, 10000)

        jobfile = self.get_jobfile('completion_service_testfunc_saveload.py')
        with CompletionService(experiment_id,
                               config=self.get_config_path(),
                               cloud_timeout=LOCAL_TEST_TIMEOUT) as cs:
            key1 = cs.submitTask(jobfile, arg1, job_id=0)
            ret_key1, result1 = cs.getResults()
            self.assertEqual(key1, ret_key1)
            self.assertEqual(result1, arg1)

            key2 = cs.submitTask(jobfile, None, job_id=0)
            ret_key2, result2 = cs.getResults()
            self.assertEqual(key2, ret_key2)
            self.assertEqual(result2, arg1 + 1)

    def get_config_path(self, config_name='test_config_http_client.yaml'):
        """Return the path of ``config_name`` in the sibling ``tests`` dir.

        Previously ``config_name`` was ignored and the HTTP client config was
        always returned.
        """
        mypath = os.path.dirname(os.path.realpath(__file__))
        return os.path.join(mypath, '..', 'tests', config_name)

    def get_jobfile(self, filename='completion_service_testfunc.py'):
        """Return the path of job script ``filename`` in this directory."""
        mypath = os.path.dirname(os.path.realpath(__file__))
        return os.path.join(mypath, filename)
class LocalWorkerTest(unittest.TestCase):
    """Run ``studio`` experiments on the local worker and check the output.

    Most tests go through ``stubtest_worker`` with a script from this
    directory; the stop / max-duration / lifetime tests drive ``studio run``
    directly as a subprocess.
    """

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_runner_local(self):
        with stubtest_worker(
            self,
            experiment_name='test_runner_local_' + str(uuid.uuid4()),
            config_name='test_config_http_client.yaml',
            test_script='tf_hello_world.py',
            runner_args=[],
            script_args=['arg0'],
            expected_output='[ 2.0 6.0 ]'
        ):
            pass

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_args_conflict(self):
        with stubtest_worker(
            self,
            experiment_name='test_runner_conflict_' + str(uuid.uuid4()),
            config_name='test_config.yaml',
            runner_args=[],
            test_script='conflicting_args.py',
            script_args=['--experiment', 'aaa'],
            expected_output='Experiment key = aaa'
        ):
            pass

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_hyperparam(self):
        """Default hyperparameter value, then one overridden via --hyperparam."""
        with stubtest_worker(
            self,
            experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
            runner_args=['--verbose=debug'],
            config_name='test_config_http_client.yaml',
            test_script='hyperparam_hello_world.py',
            expected_output='0.3'
        ):
            pass

        with stubtest_worker(
            self,
            experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
            runner_args=[
                '--verbose=debug',
                '--hyperparam=learning_rate=0.4'
            ],
            config_name='test_config_http_client.yaml',
            test_script='hyperparam_hello_world.py',
            expected_output='0.4'
        ):
            pass

    @unittest.skip('peterz figure out the failure - happens intermittently ' +
                   'when running in parallel')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_ce(self):
        """Capture a file as artifact 'f', then reuse it in a second run."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile_ce_' +
                               str(uuid.uuid4()) + '.txt')

        random_str1 = str(uuid.uuid4())

        with open(tmpfile, 'w') as f:
            f.write(random_str1)

        try:
            random_str2 = str(uuid.uuid4())
            experiment_name = 'test_local_worker_c' + str(uuid.uuid4())
            print("random_str1 = " + random_str1)
            print("random_str2 = " + random_str2)
            print("experiment_name = " + experiment_name)
            print("tmpfile = " + tmpfile)

            with stubtest_worker(
                self,
                experiment_name=experiment_name,
                runner_args=['--capture=' + tmpfile + ':f',
                             '--verbose=debug'],
                config_name='test_config_http_client.yaml',
                test_script='art_hello_world.py',
                script_args=[random_str2],
                expected_output=random_str1,
                delete_when_done=False
            ) as db:
                pass

            tmppath = db.get_artifact(
                db.get_experiment(experiment_name).artifacts['f']
            )

            with open(tmppath, 'r') as f:
                self.assertEqual(f.read(), random_str2)
            os.remove(tmppath)

            with stubtest_worker(
                self,
                experiment_name='test_local_worker_e' + str(uuid.uuid4()),
                runner_args=['--reuse={}/f:f'.format(experiment_name)],
                config_name='test_config_http_client.yaml',
                test_script='art_hello_world.py',
                script_args=[],
                expected_output=random_str2
            ) as db:

                db.delete_experiment(experiment_name)
        finally:
            # The input file may already be gone (tmppath can point at it).
            if os.path.exists(tmpfile):
                os.remove(tmpfile)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co(self):
        """Capture a local file once (--capture-once) and read it back."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile' +
                               str(uuid.uuid4()) + '.txt')

        random_str = str(uuid.uuid4())
        with open(tmpfile, 'w') as f:
            f.write(random_str)

        try:
            with stubtest_worker(
                self,
                experiment_name='test_local_worker_co' + str(uuid.uuid4()),
                runner_args=['--capture-once=' + tmpfile + ':f'],
                config_name='test_config_http_client.yaml',
                test_script='art_hello_world.py',
                script_args=[],
                expected_output=random_str
            ):
                pass
        finally:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co_url(self):
        """Capture an https:// URL once and check its contents are printed."""
        expected_str = 'Zabil zaryad ya v pushku tugo'
        url = 'https://storage.googleapis.com/studio-ed756.appspot.com/' + \
              'tests/url_artifact.txt'

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_co_url' + str(uuid.uuid4()),
            runner_args=['--capture-once=' + url + ':f'],
            config_name='test_config_http_client.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=expected_str
        ):
            pass

    @unittest.skipIf(
        not has_aws_credentials(),
        'AWS credentials not found, cannot download s3://-like links')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co_s3(self):
        """Capture an s3:// location once and check its contents are printed."""
        expected_str = 'No4 ulica fonar apteka, bessmyslennyj i tusklyj svet'
        s3loc = 's3://studioml-artifacts/tests/download_test/download_test.txt'

        with stubtest_worker(
            self,
            experiment_name='test_local_worker_co_s3' + str(uuid.uuid4()),
            runner_args=['--capture-once=' + s3loc + ':f'],
            config_name='test_config_http_client.yaml',
            test_script='art_hello_world.py',
            script_args=[],
            expected_output=expected_str
        ):
            pass

    @unittest.skipIf(keras is None,
                     'keras is required for this test')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_save_get_model(self):
        """A model saved by save_model.py predicts ``2 * input`` when loaded."""
        experiment_name = 'test_save_get_model' + str(uuid.uuid4())
        with stubtest_worker(
            self,
            experiment_name=experiment_name,
            runner_args=[],
            config_name='test_config_http_client.yaml',
            test_script='save_model.py',
            script_args=[],
            expected_output='',
            delete_when_done=False,
            test_output=False
        ) as db:

            experiment = db.get_experiment(experiment_name)
            saved_model = experiment.get_model(db)

            v = np.random.rand(1, 2)
            prediction = saved_model.predict(v)
            expected = v * 2

            self.assertTrue(np.isclose(prediction, expected).all())

            db.delete_experiment(experiment)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_stop_experiment(self):
        """Stop a running experiment through the db and wait for the run."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.getLogger('test_stop_experiment')
        logger.setLevel(10)  # numeric level 10 == DEBUG in stdlib logging

        config_name = os.path.join(my_path, 'test_config_http_client.yaml')
        key = 'test_stop_experiment' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            try:
                db.delete_experiment(key)
            except Exception:
                pass

            p = subprocess.Popen(['studio', 'run',
                                  '--config=' + config_name,
                                  '--experiment=' + key,
                                  '--force-git',
                                  '--verbose=debug',
                                  'stop_experiment.py'],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT,
                                 cwd=my_path)

            # Wait till the experiment spins up. Lookup errors while it is
            # not registered yet are expected; catch Exception rather than
            # BaseException so KeyboardInterrupt/SystemExit still propagate.
            experiment = None
            while experiment is None or experiment.status == 'waiting':
                time.sleep(1)
                try:
                    experiment = db.get_experiment(key)
                except Exception:
                    pass

            logger.info('Stopping experiment')
            db.stop_experiment(key)
            pout, _ = p.communicate()

            if pout:
                logger.debug("studio run output: \n" + pout.decode())

            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_experiment_maxduration(self):
        """A run started with --max-duration=10s terminates on its own."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.getLogger('test_experiment_maxduration')
        logger.setLevel(10)  # numeric level 10 == DEBUG in stdlib logging

        config_name = os.path.join(my_path, 'test_config_http_client.yaml')
        key = 'test_experiment_maxduration' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            try:
                db.delete_experiment(key)
            except Exception:
                pass

            p = subprocess.Popen(['studio', 'run',
                                  '--config=' + config_name,
                                  '--experiment=' + key,
                                  '--force-git',
                                  '--verbose=debug',
                                  '--max-duration=10s',
                                  'stop_experiment.py'],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT,
                                 cwd=my_path)

            pout, _ = p.communicate()
            if pout:
                logger.debug("studio run output: \n" + pout.decode())

            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_experiment_lifetime(self):
        """A run started with a negative --lifetime terminates on its own."""
        my_path = os.path.dirname(os.path.realpath(__file__))

        logger = logs.getLogger('test_experiment_lifetime')
        logger.setLevel(10)  # numeric level 10 == DEBUG in stdlib logging

        config_name = os.path.join(my_path, 'test_config_http_client.yaml')
        key = 'test_experiment_lifetime' + str(uuid.uuid4())

        with model.get_db_provider(model.get_config(config_name)) as db:
            try:
                db.delete_experiment(key)
            except Exception:
                pass

            p = subprocess.Popen(['studio', 'run',
                                  '--config=' + config_name,
                                  '--experiment=' + key,
                                  '--force-git',
                                  '--verbose=debug',
                                  '--lifetime=-10m',
                                  'stop_experiment.py'],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT,
                                 cwd=my_path)

            pout, _ = p.communicate()

            if pout:
                logger.debug("studio run output: \n" + pout.decode())

            db.delete_experiment(key)
Пример #9
0
                             ],
                             config_name='test_config_http_client.yaml',
                             test_script='',
                             script_args=[],
                             expected_output='RaawwWWWWWRRRR!!',
                             test_workspace=False):
            pass


@unittest.skipUnless(on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Check that AWS credentials exist when the user says we run on AWS.

    Skipped unless ``on_aws()`` is truthy.
    """

    def test_on_enviornment(self):
        aws_ready = has_aws_credentials()
        self.assertTrue(aws_ready)


@unittest.skipIf((not on_aws()) or not has_aws_credentials(),
                 'Skipping due to userinput or AWS Not detected')
class EC2WorkerTest(unittest.TestCase):
    _multiprocess_shared_ = True

    def get_worker_manager(self):
        """Build the worker manager exercised by this test class."""
        manager = EC2WorkerManager()
        return manager

    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_worker(self):
        experiment_name = 'test_ec2_worker_' + str(uuid.uuid4())
        with stubtest_worker(
                self,
                experiment_name=experiment_name,
                runner_args=[
                    '--cloud=ec2', '--force-git', '--gpus=1',
Пример #10
0
    def tearDownClass(self):
        """Stop the API server subprocess stored in ``self.serverp``.

        Does nothing when ``has_aws_credentials()`` is falsy. After killing
        the server the process is waited on so it is reaped instead of being
        left behind as a zombie.

        NOTE(review): assumes ``self.serverp`` is a ``subprocess.Popen`` set
        up by setUpClass under the same credential guard -- confirm.
        """
        if not has_aws_credentials():
            return

        print("Shutting down the API server")
        self.serverp.kill()
        self.serverp.wait()
Пример #11
0
class CompletionServiceTest(unittest.TestCase):
    """CompletionService round-trip tests on local, EC2 and Google Cloud."""

    def _config_path(self, config_name='test_config.yaml'):
        """Return the path of ``config_name`` in the sibling ``tests`` dir."""
        mypath = os.path.dirname(os.path.realpath(__file__))
        return os.path.join(mypath, '..', 'tests', config_name)

    def test_two_experiments_with_cs_args(self, **kwargs):
        """Submit two tasks and check each returns ``[i]`` for input ``[i]``.

        ``kwargs`` are forwarded to ``CompletionService``. Because the name
        starts with ``test_``, unittest also calls this directly with no
        kwargs, in which case it returns immediately.
        """
        if not any(kwargs):
            return
        mypath = os.path.dirname(os.path.realpath(__file__))
        experimentId = str(uuid.uuid4())
        n_experiments = 2
        results = {}
        expected_results = {}
        with CompletionService(experimentId, **kwargs) as cs:
            for i in range(n_experiments):
                key = cs.submitTask(
                    os.path.join(mypath, 'completion_service_func.py'), [i])
                expected_results[key] = [i]

            for _ in range(n_experiments):
                result = cs.getResults(blocking=True)
                results[result[0]] = result[1]

        self.assertEqual(results, expected_results)

    def test_two_experiments(self):
        self.test_two_experiments_with_cs_args(config=self._config_path(),
                                               cloud_timeout=60)

    @unittest.skipIf(not has_aws_credentials(),
                     'AWS credentials needed for this test')
    def test_two_experiments_ec2(self):
        self.test_two_experiments_with_cs_args(config=self._config_path(),
                                               cloud_timeout=100,
                                               cloud='ec2')

    @unittest.skipIf(not has_aws_credentials(),
                     'AWS credentials needed for this test')
    def test_two_experiments_ec2spot(self):
        self.test_two_experiments_with_cs_args(config=self._config_path(),
                                               cloud_timeout=100,
                                               cloud='ec2spot')

    def test_two_experiments_apiserver(self):
        self.test_two_experiments_with_cs_args(
            config=self._config_path('test_config_http_client.yaml'))

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(),
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    def test_two_experiments_gcloud(self):
        self.test_two_experiments_with_cs_args(config=self._config_path(),
                                               cloud='gcloud')

    @unittest.skipIf(not has_aws_credentials(),
                     'AWS credentials needed for this test')
    def test_many_experiments_ec2(self):
        """Submit 100 tasks with extra files to 30 ec2spot workers."""
        experimentId = str(uuid.uuid4())
        mypath = os.path.dirname(os.path.realpath(__file__))
        config_path = self._config_path()
        jobfile = os.path.join(mypath, 'completion_service_func.py')

        n_experiments = 100
        num_workers = 30

        print("Executing {} tasks with {} workers".format(
            n_experiments, num_workers))

        results = {}
        expected_results = {}

        logger = logging.getLogger('test_1k_experiments_ec2')
        logger.setLevel(10)  # logging.DEBUG

        with CompletionService(experimentId,
                               config=config_path,
                               cloud='ec2spot',
                               num_workers=num_workers) as cs:

            def submit_task(i):
                # Ship a file that exists in every checkout (the job script
                # itself) rather than a path on one developer's machine.
                key = cs.submitTaskWithFiles(
                    jobfile, [i], {'a': jobfile, 'p': jobfile})
                logger.info('Submitted task ' + str(i))
                expected_results[key] = [i]

            for i in range(n_experiments):
                submit_task(i)

            for i in range(n_experiments):
                print("Trying to get a result " + str(i))
                result = cs.getResults(blocking=True)
                logger.info('Received result ' + str(result))
                results[result[0]] = result[1]

        self.assertEqual(results, expected_results)

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(),
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    def test_two_experiments_gcloud_nonspot(self):
        # NOTE(review): identical to test_two_experiments_gcloud (both use
        # cloud='gcloud'); one of them may have been meant to use 'gcspot'.
        self.test_two_experiments_with_cs_args(config=self._config_path(),
                                               cloud='gcloud')
Пример #12
0
            response = fb._get("test/hello")
            self.assertTrue(response is None)

            fb._set("test/hello", "bla")


@unittest.skipUnless(on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Check that AWS credentials exist when the user says we run on AWS.

    Skipped unless ``on_aws()`` is truthy.
    """

    def test_on_enviornment(self):
        aws_ready = has_aws_credentials()
        self.assertTrue(aws_ready)


@unittest.skipUnless(
    on_aws() and has_aws_credentials(),
    'Skipping due to userinput or AWS Not detected')
class S3ProviderTest(unittest.TestCase, KeyValueProviderTest):
    """Run the KeyValueProviderTest suite with the S3 provider config.

    Runs only when ``on_aws()`` and ``has_aws_credentials()`` are both
    truthy.
    """

    _multiprocess_shared_ = True

    def get_default_config_name(self):
        config_name = 'test_config_s3.yaml'
        return config_name


@unittest.skipUnless(
    'GOOGLE_APPLICATION_CREDENTIALS' in os.environ,
    'google app credentials not found, cannot run test')
class GSProviderTest(unittest.TestCase, KeyValueProviderTest):
    """Run the KeyValueProviderTest suite with the Google Storage config."""

    def get_default_config_name(self):
        config_name = 'test_config_gs.yaml'
        return config_name
Пример #13
0
        self.assertEquals(data, msg)


@unittest.skipIf(
    'GOOGLE_APPLICATION_CREDENTIALS' not in
    os.environ.keys(),
    'GOOGLE_APPLICATION_CREDENTIALS environment ' +
    "variable not set, won't be able to use google " +
    'PubSub')
class PubSubQueueTest(DistributedQueueTest, unittest.TestCase):
    """Run the DistributedQueueTest suite against a Google Pub/Sub queue.

    Skipped when GOOGLE_APPLICATION_CREDENTIALS is not set.
    """

    _multiprocess_can_split_ = True

    def get_queue(self, name=None):
        """Return a PubsubQueue named ``name``, or a unique name if falsy."""
        return PubsubQueue(
            'pubsub_queue_test_' + str(uuid.uuid4()) if not name else name)


@unittest.skipUnless(
    has_aws_credentials(),
    "AWS credentials is not present, cannot use SQSQueue")
class SQSQueueTest(DistributedQueueTest, unittest.TestCase):
    """Run the shared DistributedQueueTest checks against ``SQSQueue``.

    Skipped unless ``has_aws_credentials()`` returns a truthy value.
    """

    _multiprocess_can_split_ = True

    def get_queue(self, name=None):
        """Return an ``SQSQueue`` for *name*, or a randomly named one if falsy."""
        queue_name = name or 'sqs_queue_test_' + str(uuid.uuid4())
        return SQSQueue(queue_name)


# Run this module's tests with unittest's default runner when the file is
# executed directly rather than collected by a test runner.
if __name__ == '__main__':
    unittest.main()
Пример #14
0
class CompletionServiceTest(unittest.TestCase):
    def _run_test(self,
                  args=None,
                  files=None,
                  jobfile=None,
                  expected_results=None,
                  **csargs):
        """Submit one task per entry of *args* and check every result.

        :param args: per-task arguments; replaced by ``[0, 1]`` when falsy.
        :param files: mapping passed unchanged to ``submitTaskWithFiles`` for
            every task; defaults to a new empty dict.
        :param jobfile: job script path relative to this file's directory;
            defaults to ``completion_service_testfunc.py``.
        :param expected_results: expected value per task, indexed like
            *args*; defaults to *args* itself.
        :param csargs: keyword arguments forwarded to ``CompletionService``.
            When none are given the helper returns immediately and nothing
            is run or checked.
        """
        if not any(csargs):
            return

        # Fresh dict per call instead of a shared mutable default argument.
        files = {} if files is None else files

        mypath = os.path.dirname(os.path.realpath(__file__))
        jobfile = os.path.join(mypath, jobfile
                               or 'completion_service_testfunc.py')

        args = args or [0, 1]
        expected_results = expected_results or args
        experiment_id = str(uuid.uuid4())
        # Maps the key returned at submission to the task's index in *args*,
        # because results are not assumed to come back in submission order.
        submission_indices = {}

        with CompletionService(experiment_id, **csargs) as cs:
            for i, arg in enumerate(args):
                key = cs.submitTaskWithFiles(jobfile, arg, files)
                submission_indices[key] = i

            for _ in range(len(args)):
                result = cs.getResults(blocking=True)
                self.assertEqual(
                    result[1], expected_results[submission_indices[result[0]]])

    def _run_test_files(self, files, n_experiments=2, **csargs):
        """Run *n_experiments* file-based tasks and verify the reported hashes.

        Task ``i`` is expected to return ``(i, hashes)``, where ``hashes`` is
        what ``self._get_file_hashes(files)`` computes locally. Extra keyword
        arguments are forwarded to ``_run_test`` (and on to the service).
        """
        # Hashing may download remote files, so compute it once instead of
        # once per experiment; every task receives the same *files*.
        file_hashes = self._get_file_hashes(files)
        expected_results = [(i, file_hashes) for i in range(n_experiments)]
        self._run_test(args=range(n_experiments),
                       files=files,
                       jobfile='completion_service_testfunc_files.py',
                       expected_results=expected_results,
                       **csargs)

    def _get_file_hashes(self, files):
        """Return a dict mapping each key of *files* to its file's MD5 hash.

        The hash is whatever ``filehash(..., hashobj=hashlib.md5())`` returns.
        Values containing ``://`` are treated as remote locations: they are
        downloaded to a temporary file, hashed, and the temporary file is
        removed (also when the download or hashing fails). Other values are
        hashed in place as local paths.
        """
        retval = {}
        for key, location in six.iteritems(files):
            if '://' not in location:
                retval[key] = filehash(location, hashobj=hashlib.md5())
                continue

            tmpfilename = os.path.join(tempfile.gettempdir(),
                                       rand_string(10))
            try:
                download_file(location, tmpfilename)
                retval[key] = filehash(tmpfilename, hashobj=hashlib.md5())
            finally:
                # The download may fail before the file is created.
                if os.path.exists(tmpfilename):
                    os.remove(tmpfilename)

        return retval

    @unittest.skipUnless(has_aws_credentials(),
                         'AWS credentials needed for this test')
    def test_two_experiments_ec2(self):
        """Run the default two-task job with ``cloud='ec2'``.

        Uses ``tests/test_config_http_client.yaml`` and a cloud timeout of 100.
        """
        config_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '..', 'tests', 'test_config_http_client.yaml')
        self._run_test(config=config_path, cloud_timeout=100, cloud='ec2')

    @unittest.skipUnless(has_aws_credentials(),
                         'AWS credentials needed for this test')
    def test_two_experiments_ec2spot(self):
        """Run the file-hashing job twice with ``cloud='ec2spot'``.

        Ships every regular file in this test directory plus the ``url`` and
        ``s3`` remote files, and checks each task reports the expected hashes.
        """
        here = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(here, '..', 'tests',
                                   'test_config_http_client.yaml')

        files = {}
        for entry in os.listdir(here):
            full_path = os.path.join(here, entry)
            if os.path.isfile(full_path):
                files[entry] = full_path
        files['url'] = _file_url
        files['s3'] = _file_s3

        self._run_test_files(files=files, n_experiments=2,
                             config=config_path, cloud_timeout=100,
                             cloud='ec2spot')

    def test_two_experiments_apiserver(self):
        """Run the file-hashing job twice while holding the local queue lock.

        Ships every regular file in this test directory plus the ``url``
        remote file, and the ``s3`` file only when AWS credentials exist.
        """
        # NOTE(review): despite the name, this uses the datacenter config,
        # not an http client config — confirm that is intended.
        here = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(here, '..', 'tests',
                                   'test_config_datacenter.yaml')

        candidates = ((entry, os.path.join(here, entry))
                      for entry in os.listdir(here))
        files = {entry: path for entry, path in candidates
                 if os.path.isfile(path)}
        files['url'] = _file_url
        if has_aws_credentials():
            files['s3'] = _file_s3

        with get_local_queue_lock():
            self._run_test_files(n_experiments=2, files=files,
                                 config=config_path)

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ,
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    def test_two_experiments_gcspot(self):
        """Run the file-hashing job twice with ``cloud='gcspot'``.

        Ships every regular file in this test directory plus the ``url``
        remote file. Skipped when GOOGLE_APPLICATION_CREDENTIALS is not set.
        """
        mypath = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(mypath, '..', 'tests',
                                   'test_config_http_client.yaml')

        files = {
            f: os.path.join(mypath, f)
            for f in os.listdir(mypath)
            if os.path.isfile(os.path.join(mypath, f))
        }
        files['url'] = _file_url

        self._run_test_files(files=files,
                             n_experiments=2,
                             config=config_path,
                             cloud='gcspot')

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS_DC'
                     not in os.environ,
                     'Need GOOGLE_APPLICATION_CREDENTIALS_DC env variable '
                     'to use google cloud')
    def test_two_experiments_datacenter(self):
        """Run the file-hashing job on a fresh queue with datacenter credentials.

        GOOGLE_APPLICATION_CREDENTIALS is temporarily set to the value of
        GOOGLE_APPLICATION_CREDENTIALS_DC. It is restored afterwards, or
        removed if it was not previously set, even when the run fails, so the
        override does not leak into other tests.
        """
        oldcred = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = \
            os.environ['GOOGLE_APPLICATION_CREDENTIALS_DC']
        try:
            mypath = os.path.dirname(os.path.realpath(__file__))
            # Unique queue name per run; shutdown_del_queue=True is passed
            # through to the service, presumably to delete it afterwards.
            queue_name = 'test_queue_' + str(uuid.uuid4())
            config_path = os.path.join(mypath, '..', 'tests',
                                       'test_config_datacenter.yaml')

            files = {
                f: os.path.join(mypath, f)
                for f in os.listdir(mypath)
                if os.path.isfile(os.path.join(mypath, f))
            }
            # files['url'] = _file_url
            files['s3'] = _file_s3

            self._run_test_files(files=files,
                                 config=config_path,
                                 queue=queue_name,
                                 shutdown_del_queue=True)
        finally:
            if oldcred is None:
                os.environ.pop('GOOGLE_APPLICATION_CREDENTIALS', None)
            else:
                os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = oldcred

    @unittest.skipIf('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ,
                     'Need GOOGLE_APPLICATION_CREDENTIALS env variable to ' +
                     'use google cloud')
    def test_two_experiments_gcloud(self):
        """Run the file-hashing job twice with ``cloud='gcloud'``.

        Ships every regular file in this test directory plus the ``url``
        remote file. Skipped when GOOGLE_APPLICATION_CREDENTIALS is not set.
        """
        mypath = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(mypath, '..', 'tests',
                                   'test_config_http_client.yaml')

        files = {
            f: os.path.join(mypath, f)
            for f in os.listdir(mypath)
            if os.path.isfile(os.path.join(mypath, f))
        }
        files['url'] = _file_url

        self._run_test_files(files=files,
                             n_experiments=2,
                             config=config_path,
                             cloud='gcloud')