def test_load_non_existant_data(self):
    """
    Test that loading a non-existent dataset raises botocore's ClientError.
    """
    # 'test/data/basic.csv' is loaded successfully by the other tests, so it
    # cannot stand in for a missing dataset; use a key that is presumably not
    # present in the (mocked) 'test' bucket.
    dm = DataManager(data_endpoint='test/data/does-not-exist.csv',
                     target_column='species')

    # Should raise an exception when trying to load a dataset that doesn't exist.
    with self.assertRaises(botocore.exceptions.ClientError):
        dm.load()
 def test_load_first_mb_s3(self):
     """
     Ensure the data loader can load just the first n_bytes of a dataset from S3.

     Despite the test name, only 1024 bytes (1 KiB) are requested. The
     resulting X must report a size below 1512 bytes.
     """
     n_bytes = 1024
     max_size = 1512
     dm = DataManager(data_endpoint='test/data/basic.csv',
                      target_column='species')
     dm.load(n_bytes=n_bytes)
     # sys.getsizeof defers to the object's __sizeof__; presumably dm.X is a
     # pandas DataFrame, which reports its memory usage there — TODO confirm
     size = sys.getsizeof(dm.X)
     self.assertLess(size,
                     max_size,
                     msg=f'Size of data is >= {max_size} bytes; it is {size} bytes')
    def test_list_objects(self):
        """Test listing of objects on s3, and that a listed key can be loaded."""
        expected_key = 'data/basic.csv'
        resp = DataManager.list_s3_datasets(bucket='test', dir='data/')
        self.assertIsInstance(resp, list)
        keys = [obj['key'] for obj in resp]
        self.assertIn(expected_key, keys)

        # Sanity check: load the key we just verified is listed, rather than
        # whatever happens to be first in the listing (which may be a dataset
        # without a 'species' column).
        data = DataManager(data_endpoint=f'test/{expected_key}',
                           target_column='species')
        data.load()
        self.assertGreater(data.X.shape[0], 0)
 def test_load_first_mb_http(self):
     """
     Ensure the data loader can load just the first n_bytes of a dataset over HTTP.

     Despite the test name, only 1024 bytes (1 KiB) are requested. The
     resulting X must report a size below 1512 bytes.
     """
     n_bytes = 1024
     max_size = 1512
     dm = DataManager(
         data_endpoint=
         'https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv',
         target_column='species')
     dm.load(n_bytes=n_bytes)
     # sys.getsizeof defers to the object's __sizeof__; presumably dm.X is a
     # pandas DataFrame, which reports its memory usage there — TODO confirm
     size = sys.getsizeof(dm.X)
     self.assertLess(size,
                     max_size,
                     msg=f'Size of data is >= {max_size} bytes; it is {size} bytes')
 def test_load_from_http(self):
     """
     Ensure the data loader can load a dataset via HTTP and exposes its
     feature columns (e.g. 'petal_length') in X.
     """
     dm = DataManager(
         data_endpoint=
         'https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv',
         target_column='species')
     dm.load()
     # assertIn reports both the missing item and the container on failure
     self.assertIn(
         'petal_length',
         dm.X.columns,
         msg=f'Expected "petal_length" to be in X, but found {dm.X.columns}'
     )
Example #6
0
    def test_model_runner_process(self):
        """
        Drive a model through ModelRunner end to end.

        The model's data is loaded, predicting before fitting must fail, and
        after fitting the row counts of the source data, the runner's data and
        the predictions must all agree.
        """
        # Build the model and attach a loaded data manager to it
        model = HemlockRandomForestClassifier()
        model.data_manager = DataManager(data_endpoint='test/iris.csv',
                                         target_column='species')
        model.data_manager.load()

        runner = ModelRunner(model)

        # An unfitted model must refuse to predict
        unfitted_msg = "Model isn't fitted, so it shouldn't be able to predict anything!"
        with self.assertRaises(NotFittedError, msg=unfitted_msg):
            runner.predict()

        # Fit, predict, then compare row counts across all three views of the data
        runner.fit()
        predictions = runner.predict()
        original_size, runner_size, predicted_size = (
            model.data_manager.X.shape[0],
            runner.model.data_manager.X.shape[0],
            predictions.shape[0],
        )
        mismatch_msg = ('Expected: '
                        f' original data ({original_size}) == runner data '
                        f'({runner_size}) == predicted data ({predicted_size})')
        self.assertTrue(original_size == runner_size == predicted_size,
                        mismatch_msg)
    def test_presigned_url(self):
        """Test generating presigned S3 urls for GET and POST, and rejecting unknown actions."""
        # Presigned urls should be generated for both GET and POST requests.
        # subTest reports which action failed and still checks the other one.
        for action in ('GET', 'POST'):
            with self.subTest(action=action):
                url = DataManager.generate_presigned_s3_url(
                    bucket='hemlock-highway-test',
                    key='customer1/data.csv',
                    action=action)
                self.assertIsInstance(url, str)
                self.assertTrue(url.startswith('https://'), msg=url)
                self.assertIn('hemlock-highway-test', url)

        # Raise ValueError on unavailable action
        with self.assertRaises(ValueError):
            DataManager.generate_presigned_s3_url(bucket='test',
                                                  key='something.csv',
                                                  action='FAIL')
 def test_load_from_s3(self):
     """
     Test the basic loading of a dataset from S3.

     A DataManager must not load anything when constructed. After ``load()``
     it must report itself as loaded and hold at least one row in X.
     """
     dm = DataManager(data_endpoint='test/data/basic.csv',
                      target_column='species')
     self.assertFalse(
         dm._loaded,
         msg=
         'DataManager should not load data on initialization! Reports it is loaded!'
     )
     dm.load()
     self.assertTrue(
         dm._loaded,
         msg=
         'After asking to load data, DataManager is reporting it is not loaded!'
     )
     # assertGreater shows the actual row count on failure
     self.assertGreater(
         dm.X.shape[0],
         0,
         msg=
         'DataManager reports it is loaded, but does not have any data in X!'
     )
    def test_pickling(self):
        """
        DataManager should pickle, dropping (with a UserWarning) any data it
        holds. The unpickled copy should keep its configuration.
        """
        dm1 = DataManager(data_endpoint='test/data/basic.csv',
                          target_column='species')
        dm1.load()
        self.assertGreater(dm1.X.shape[0], 0)  # Should be data in X before pickling

        # Pickling a loaded manager should warn that its data is being dropped
        with self.assertWarns(UserWarning):
            out = pickle.dumps(dm1)
        dm2 = pickle.loads(out)

        # Loaded object should be empty of data
        self.assertEqual(
            dm2.X.shape[0],
            0,
            msg=
            f'After pickling and loading, DataManager had data in X: {dm2.X.shape}'
        )

        # Loaded object should have same attributes
        self.assertEqual(
            dm2.data_endpoint,
            dm1.data_endpoint,
            msg=
            'Original DataManager and loaded pickled version does not have same attributes!'
        )
Example #10
0
class HemlockModelBase:
    """
    Base class for Hemlock models.

    Provides S3 persistence helpers (``dump`` / ``load``) and a hook for
    subclasses to describe their configurable parameters.

    NOTE: this class is not declared with ``abc.ABCMeta``, so the abstract
    markers below are not enforced by this class on its own.
    """

    # Each model should have a DataManager to manage the handling of the IO/parsing of data for the model.
    # NOTE(review): this default is ONE class-level instance shared by every
    # model that does not assign its own ``data_manager``. Assign a fresh
    # DataManager per model before loading data into it.
    data_manager = DataManager('', '')

    def __new__(cls, *args, **kwargs):
        # A new S3 client is created on every instantiation and stored on the
        # class, not the instance. With default pickling it is therefore not
        # part of the instance state serialized by ``dump``.
        cls.s3_client = boto3.client('s3',
                                     region_name=PROJECT_CONFIG.AWS_REGION)
        return super().__new__(cls)

    # ``abc.abstractstaticmethod`` is deprecated. Stacking it under
    # ``@staticmethod`` also double-wraps the function, which is not callable
    # before Python 3.10.
    @staticmethod
    @abc.abstractmethod
    def configurable_parameters():
        """
        Return a mapping of parameter names to a sub-mapping describing each
        parameter's type and valid values, e.g.::

            {
                'n_estimators': {'type': int, 'range': ...}
            }

        The exact format of the 'range' value is left to subclasses.
        """
        ...

    @abc.abstractmethod
    def dump(self, bucket: str, key: str) -> bool:
        """
        Dump a model to s3.

        The model is pickled, zlib-compressed and written to ``bucket``/``key``.
        ``create_bucket`` is called for ``bucket`` first.

        :return: True when S3 answers the upload with HTTP 200.
        :raises IOError: if the put_object response has any other status code.
        """
        model_out = zlib.compress(pickle.dumps(self))
        # NOTE(review): create_bucket runs unconditionally. Against real AWS
        # this can fail, e.g. if the bucket already exists or a region
        # LocationConstraint is required. Confirm this is intended for mocked
        # S3 only.
        self.s3_client.create_bucket(Bucket=bucket)
        resp = self.s3_client.put_object(Bucket=bucket,
                                         Key=key,
                                         Body=model_out)
        if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise IOError(
                f'Unable to serialize model to S3 location: {bucket}/{key}\nBoto3 response: {resp}'
            )
        return True

    @classmethod
    @abc.abstractmethod
    def load(cls, bucket: str, key: str):
        """
        Load a model from S3: the inverse of ``dump``.

        The object at ``bucket``/``key`` is read, zlib-decompressed and unpickled.

        SECURITY: ``pickle.loads`` can execute arbitrary code. Only load
        from S3 locations you trust.
        """
        # Instantiating cls runs __new__, which (re)creates the class-level S3 client
        client = cls().s3_client
        body = client.get_object(Bucket=bucket, Key=str(key))['Body'].read()
        return pickle.loads(zlib.decompress(body))
Example #11
0
 def test_load_model(self):
     """
     Dump a model to (mocked) S3, then ask the ModelRunner server's
     '/train-model' endpoint to use it and check it reports success.
     """
     # mock_s3 breaks requests, so let real https:// traffic pass through
     responses.add_passthru('https://')

     iris_url = 'https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv'
     model = HemlockRandomForestClassifier()
     model.data_manager = DataManager(data_endpoint=iris_url,
                                      target_column='species')
     model.dump(bucket='test', key='models/test-model.pkl')

     form = {'model-location': 'test/models/test-model.pkl'}
     response = self.app.post('/train-model', data=form)
     payload = json.loads(response.data)
     self.assertTrue(payload.get('success'))