示例#1
0
@unittest.skipIf((not on_gcp())
                 or 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ,
                 'Skipping due to userinput or GCP Not detected')
class GCloudArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    """Run the shared ArtifactStoreTest suite against Google Cloud Storage.

    Skipped unless on_gcp() is truthy and GOOGLE_APPLICATION_CREDENTIALS is
    set in the process environment.
    """

    def get_store(self, config_name=None):
        """Return the artifact store built from the GCloud test config.

        ``config_name`` is ignored: this suite always loads
        'test_config_gcloud_storage.yaml'.
        """
        store = ArtifactStoreTest.get_store(self,
                                            'test_config_gcloud_storage.yaml')
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only says "False is not true".
        self.assertIsInstance(store, GCloudArtifactStore)
        return store

    def get_qualified_location_prefix(self):
        """Return the 'gs://<bucket>/' prefix of the configured store."""
        store = self.get_store()
        return "gs://" + store.get_bucket() + "/"


@unittest.skipIf(not on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Guard test: if on_aws() reports AWS, AWS credentials must be present."""

    def test_on_enviornment(self):
        # Reached only when the skipIf condition above is false, i.e. when
        # on_aws() returned a truthy value.
        credentials_found = has_aws_credentials()
        self.assertTrue(credentials_found)


@unittest.skipIf(_get_provider() and not has_aws_credentials(),
                 'Skipping due to userinput or AWS Not detected')
class S3ArtifactStoreTest(ArtifactStoreTest, unittest.TestCase):
    def get_store(self, config_name=None):
        """Return the artifact store built from 'test_config.yaml'.

        ``config_name`` is ignored: this suite always loads 'test_config.yaml'.
        """
        store = ArtifactStoreTest.get_store(self, 'test_config.yaml')
        # NOTE(review): this checks TartifactStore rather than an S3-specific
        # store class (the GCloud suite checks its concrete
        # GCloudArtifactStore type) -- confirm this is intended.
        self.assertIsInstance(store, TartifactStore)
        return store

    def get_qualified_location_prefix(self):
        store = self.get_store()
示例#2
0
            fb._set(key_path, random_str)
            self.assertTrue(fb._get(key_path) is None)
            remove_all_keys()

    def test_get_set_firebase_bad(self):
        """Smoke test: a provider pointed at a wrong database URL fails soft.

        Access through the provider built from 'test_bad_config.yaml' should
        be reported (the read returns None) but must not crash the system:
        neither the read nor the following write may raise.
        """
        with self.get_provider('test_bad_config.yaml') as fb:
            response = fb._get("test/hello")
            self.assertIsNone(response)

            # No assertion on purpose: the write only has to not raise.
            fb._set("test/hello", "bla")


@unittest.skipIf(not on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Check that AWS credentials exist whenever on_aws() is true.

    The whole class is skipped when on_aws() is falsy.
    """

    def test_on_enviornment(self):
        self.assertTrue(
            has_aws_credentials()
        )


@unittest.skipIf(
    not isinstance(KeyValueProviderTest().get_provider(), S3Provider),
    'Skipping due to provider is not S3Provider',
)
class S3ProviderTest(unittest.TestCase, KeyValueProviderTest):
    """Run the KeyValueProviderTest suite against the S3 key-value provider.

    The skip condition is evaluated at import time: it asks a throwaway
    KeyValueProviderTest() for its provider and skips this class unless that
    provider is an S3Provider.
    """
    # NOTE(review): unittest.TestCase precedes the mixin in the bases, so any
    # method both define (e.g. setUp) resolves to TestCase's version --
    # confirm this ordering is intended.

    # Presumably consumed by nose's multiprocess plugin to share class-level
    # fixtures across worker processes -- confirm against the runner in use.
    _multiprocess_shared_ = True

    def get_default_config_name(self):
        """Name of the config file this suite loads by default."""
        default_config = 'test_config.yaml'
        return default_config
示例#3
0
class LocalWorkerTest(unittest.TestCase):
    """End-to-end tests that run scripts through a local studio worker.

    Most tests delegate to ``stubtest_worker`` (defined elsewhere), which
    presumably runs ``test_script`` with the given runner/script arguments
    and checks its output against ``expected_output``. The stop /
    max-duration / lifetime tests instead launch ``studio run`` as a
    subprocess and talk to the experiment DB directly.
    """

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_runner_local(self):
        with stubtest_worker(self,
                             experiment_name='test_runner_local_' +
                             str(uuid.uuid4()),
                             config_name='test_config_http_client.yaml',
                             test_script='tf_hello_world.py',
                             runner_args=[],
                             script_args=['arg0'],
                             expected_output='[ 2.0 6.0 ]'):
            pass

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_args_conflict(self):
        # The script's own --experiment argument must reach the script
        # instead of being consumed by the runner.
        with stubtest_worker(self,
                             experiment_name='test_runner_conflict_' +
                             str(uuid.uuid4()),
                             config_name='test_config.yaml',
                             runner_args=[],
                             test_script='conflicting_args.py',
                             script_args=['--experiment', 'aaa'],
                             expected_output='Experiment key = aaa'):
            pass

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_hyperparam(self):
        # First run without --hyperparam (the script is expected to print
        # 0.3), then override learning_rate=0.4 on the command line.
        with stubtest_worker(self,
                             experiment_name='test_local_hyperparam' +
                             str(uuid.uuid4()),
                             runner_args=['--verbose=debug'],
                             config_name='test_config_http_client.yaml',
                             test_script='hyperparam_hello_world.py',
                             expected_output='0.3'):
            pass

        with stubtest_worker(
                self,
                experiment_name='test_local_hyperparam' + str(uuid.uuid4()),
                runner_args=[
                    '--verbose=debug', '--hyperparam=learning_rate=0.4'
                ],
                config_name='test_config_http_client.yaml',
                test_script='hyperparam_hello_world.py',
                expected_output='0.4'):
            pass

    @unittest.skip('peterz figure out the failure - happens intermittently ' +
                   'when running in parallel')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_ce(self):
        """--capture a local file, then --reuse it from a second experiment."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile_ce_' + str(uuid.uuid4()) + '.txt')

        random_str1 = str(uuid.uuid4())

        with open(tmpfile, 'w') as f:
            f.write(random_str1)

        random_str2 = str(uuid.uuid4())
        experiment_name = 'test_local_worker_c' + str(uuid.uuid4())
        print("random_str1 = " + random_str1)
        print("random_str2 = " + random_str2)
        print("experiment_name = " + experiment_name)
        print("tmpfile = " + tmpfile)

        with stubtest_worker(
                self,
                experiment_name=experiment_name,
                runner_args=['--capture=' + tmpfile + ':f', '--verbose=debug'],
                config_name='test_config_http_client.yaml',
                test_script='art_hello_world.py',
                script_args=[random_str2],
                expected_output=random_str1,
                delete_when_done=False) as db:
            pass

        # NOTE(review): ``db`` is used below after the ``with`` block above
        # has exited; if stubtest_worker closes the provider on exit these
        # calls may fail intermittently -- possibly related to the skip
        # reason. Confirm before re-enabling.
        tmppath = db.get_artifact(
            db.get_experiment(experiment_name).artifacts['f'])

        with open(tmppath, 'r') as f:
            self.assertEqual(f.read(), random_str2)
        os.remove(tmppath)

        with stubtest_worker(
                self,
                experiment_name='test_local_worker_e' + str(uuid.uuid4()),
                runner_args=['--reuse={}/f:f'.format(experiment_name)],
                config_name='test_config_http_client.yaml',
                test_script='art_hello_world.py',
                script_args=[],
                expected_output=random_str2) as db:

            db.delete_experiment(experiment_name)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co(self):
        """--capture-once a local temp file into artifact 'f'."""
        tmpfile = os.path.join(tempfile.gettempdir(),
                               'tmpfile' + str(uuid.uuid4()) + '.txt')

        random_str = str(uuid.uuid4())
        with open(tmpfile, 'w') as f:
            f.write(random_str)

        try:
            with stubtest_worker(self,
                                 experiment_name='test_local_worker_co' +
                                 str(uuid.uuid4()),
                                 runner_args=['--capture-once=' + tmpfile +
                                              ':f'],
                                 config_name='test_config_http_client.yaml',
                                 test_script='art_hello_world.py',
                                 script_args=[],
                                 expected_output=random_str):
                pass
        finally:
            # The file lives in the shared temp dir; don't leak one per run.
            # Guarded in case the worker already moved or removed it.
            if os.path.exists(tmpfile):
                os.remove(tmpfile)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co_url(self):
        """--capture-once a public HTTPS URL into artifact 'f'."""
        expected_str = 'Zabil zaryad ya v pushku tugo'
        url = 'https://storage.googleapis.com/studio-ed756.appspot.com/' + \
              'tests/url_artifact.txt'

        with stubtest_worker(self,
                             experiment_name='test_local_worker_co_url' +
                             str(uuid.uuid4()),
                             runner_args=['--capture-once=' + url + ':f'],
                             config_name='test_config_http_client.yaml',
                             test_script='art_hello_world.py',
                             script_args=[],
                             expected_output=expected_str):
            pass

    @unittest.skipIf((not on_aws()) or not has_aws_credentials(),
                     'Skipping due to userinput or AWS Not detected')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_local_worker_co_s3(self):
        """--capture-once an s3:// object into artifact 'f'."""
        expected_str = 'No4 ulica fonar apteka, bessmyslennyj i tusklyj svet'
        s3loc = 's3://studioml-artifacts/tests/download_test/download_test.txt'

        with stubtest_worker(self,
                             experiment_name='test_local_worker_co_s3' +
                             str(uuid.uuid4()),
                             runner_args=['--capture-once=' + s3loc + ':f'],
                             config_name='test_config_http_client.yaml',
                             test_script='art_hello_world.py',
                             script_args=[],
                             expected_output=expected_str):
            pass

    @unittest.skipIf(keras is None, 'keras is required for this test')
    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_save_get_model(self):
        """Load the model saved by save_model.py and check it predicts 2*v."""
        experiment_name = 'test_save_get_model' + str(uuid.uuid4())
        with stubtest_worker(self,
                             experiment_name=experiment_name,
                             runner_args=[],
                             config_name='test_config_http_client.yaml',
                             test_script='save_model.py',
                             script_args=[],
                             expected_output='',
                             delete_when_done=False,
                             test_output=False) as db:

            experiment = db.get_experiment(experiment_name)
            saved_model = experiment.get_model(db)

            v = np.random.rand(1, 2)
            prediction = saved_model.predict(v)
            expected = v * 2

            self.assertTrue(np.isclose(prediction, expected).all())

            db.delete_experiment(experiment)

    # --- helpers for the tests that drive `studio run` as a subprocess ---

    @staticmethod
    def _studio_run_setup(test_name):
        """Return (test dir, debug logger, config path, unique key).

        The key is ``test_name`` followed by a fresh uuid4.
        """
        my_path = os.path.dirname(os.path.realpath(__file__))
        logger = logs.getLogger(test_name)
        # 10 == logging.DEBUG, assuming logs.getLogger returns a stdlib logger.
        logger.setLevel(10)
        config_name = os.path.join(my_path, 'test_config_http_client.yaml')
        key = test_name + str(uuid.uuid4())
        return my_path, logger, config_name, key

    @staticmethod
    def _delete_experiment_quietly(db, key):
        """Best-effort removal of any leftover experiment with ``key``."""
        try:
            db.delete_experiment(key)
        except Exception:
            # Deliberate: the key is freshly generated, so there is normally
            # nothing to delete and the provider may raise for missing keys.
            pass

    @staticmethod
    def _start_studio_run(config_name, key, cwd, extra_args=()):
        """Launch ``studio run ... stop_experiment.py`` in ``cwd``.

        ``extra_args`` go between the common flags and the script name.
        stderr is merged into the stdout pipe so one read collects all output.
        """
        argv = ['studio', 'run', '--config=' + config_name,
                '--experiment=' + key, '--force-git', '--verbose=debug']
        argv.extend(extra_args)
        argv.append('stop_experiment.py')
        return subprocess.Popen(argv,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                cwd=cwd)

    @staticmethod
    def _collect_output(p, logger):
        """Wait for ``p`` to exit and log its combined output at DEBUG."""
        pout, _ = p.communicate()
        if pout:
            logger.debug("studio run output: \n" + pout.decode())

    @staticmethod
    def _kill_if_running(p):
        """Kill and reap ``p`` if it is still alive (error-path cleanup)."""
        if p.poll() is None:
            p.kill()
            p.communicate()

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_stop_experiment(self):
        """Stop a running experiment through the DB and let studio run exit."""
        my_path, logger, config_name, key = \
            self._studio_run_setup('test_stop_experiment')

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_experiment_quietly(db, key)

            p = self._start_studio_run(config_name, key, my_path)
            try:
                # Wait until the experiment exists and has left 'waiting'.
                # The loop is bounded only by the @timeout decorator.
                experiment = None
                while experiment is None or experiment.status == 'waiting':
                    time.sleep(1)
                    try:
                        experiment = db.get_experiment(key)
                    except Exception:
                        # Not visible yet; keep polling. Catching Exception
                        # (not BaseException) lets KeyboardInterrupt and
                        # SystemExit through.
                        pass

                logger.info('Stopping experiment')
                db.stop_experiment(key)
                self._collect_output(p, logger)
            finally:
                # If anything above raised, don't leave studio run behind.
                self._kill_if_running(p)

            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_experiment_maxduration(self):
        """Run with --max-duration=10s and wait for studio run to exit."""
        my_path, logger, config_name, key = \
            self._studio_run_setup('test_experiment_maxduration')

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_experiment_quietly(db, key)

            p = self._start_studio_run(config_name, key, my_path,
                                       ['--max-duration=10s'])
            try:
                self._collect_output(p, logger)
            finally:
                self._kill_if_running(p)

            db.delete_experiment(key)

    @timeout(TEST_TIMEOUT, use_signals=False)
    def test_experiment_lifetime(self):
        """Run with an already-expired --lifetime=-10m; studio run must exit."""
        my_path, logger, config_name, key = \
            self._studio_run_setup('test_experiment_lifetime')

        with model.get_db_provider(model.get_config(config_name)) as db:
            self._delete_experiment_quietly(db, key)

            p = self._start_studio_run(config_name, key, my_path,
                                       ['--lifetime=-10m'])
            try:
                self._collect_output(p, logger)
            finally:
                self._kill_if_running(p)

            db.delete_experiment(key)


# Defined at module level: unittest and pytest only collect TestCase classes
# that are module attributes, so nesting this inside LocalWorkerTest meant it
# was never run.
@unittest.skipIf(not on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """If on_aws() reports AWS, AWS credentials must be present."""

    def test_on_enviornment(self):
        self.assertTrue(has_aws_credentials())
示例#4
0
        with stubtest_worker(self,
                             experiment_name=experiment_name,
                             runner_args=[
                                 '--cloud=gcspot', '--force-git',
                                 '--cloud-timeout=120',
                                 '--container=shub://vsoch/hello-world'
                             ],
                             config_name='test_config_http_client.yaml',
                             test_script='',
                             script_args=[],
                             expected_output='RaawwWWWWWRRRR!!',
                             test_workspace=False):
            pass


@unittest.skipIf(
    not on_aws(),
    'User indicated not on aws',
)
class UserIndicatedOnAWSTest(unittest.TestCase):
    """Consistency check between on_aws() and has_aws_credentials()."""

    def test_on_enviornment(self):
        # Skipped entirely (see decorator) unless on_aws() is truthy; in that
        # case credentials must be available.
        have_creds = has_aws_credentials()
        self.assertTrue(have_creds)


@unittest.skipIf((not on_aws()) or not has_aws_credentials(),
                 'Skipping due to userinput or AWS Not detected')
class EC2WorkerTest(unittest.TestCase):
    _multiprocess_shared_ = True

    def get_worker_manager(self):
        """Build the worker manager under test: a fresh EC2WorkerManager."""
        manager = EC2WorkerManager()
        return manager

    @timeout(CLOUD_TEST_TIMEOUT, use_signals=False)
    def test_worker(self):
示例#5
0
            fb._set(key_path, random_str)
            self.assertTrue(fb._get(key_path) is None)
            remove_all_keys()

    def test_get_set_firebase_bad(self):
        """Smoke test: a provider pointed at a wrong database URL fails soft.

        Access through the provider built from 'test_bad_config.yaml' should
        be reported (the read returns None) but must not crash the system:
        neither the read nor the following write may raise.
        """
        with self.get_provider('test_bad_config.yaml') as fb:
            response = fb._get("test/hello")
            self.assertIsNone(response)

            # No assertion on purpose: the write only has to not raise.
            fb._set("test/hello", "bla")


@unittest.skipIf(not on_aws(), 'User indicated not on aws')
class UserIndicatedOnAWSTest(unittest.TestCase):
    """When the user indicates AWS (on_aws()), credentials must be found.

    Skipped as a whole when on_aws() is falsy.
    """

    def test_on_enviornment(self):
        found = has_aws_credentials()
        self.assertTrue(found)


@unittest.skipIf(
    (not on_aws()) or not has_aws_credentials(),
    'Skipping due to userinput or AWS Not detected')
class S3ProviderTest(unittest.TestCase, KeyValueProviderTest):
    _multiprocess_shared_ = True

    def get_default_config_name(self):
        """Return the config file name used by the S3-backed provider tests."""
        config_file = 'test_config_s3.yaml'
        return config_file