示例#1
0
 def test_cvscore_post(self):
     """One ``fit_and_save`` call answers 201 and leaves a single CVResult with one score."""
     examples_file, label_file = _create_dataset()
     examples_upload = SimpleUploadedFile(examples_file.name,
                                          examples_file.read())
     labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
     dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                examples=examples_upload,
                                                labels=labels_upload)

     # The search object is only needed for its uuid and server URL;
     # it is never fitted in this test.
     param_grid = {
         'criterion': ['gini', 'entropy'],
         'max_depth': range(1, 21),
         'max_features': ['auto', 'log2', 'sqrt', None]
     }
     search = ATGridSearchCV(ensemble.RandomForestClassifier(),
                             param_grid,
                             dataset=dataset.pk,
                             webserver_url=self.live_server_url)

     fixed_params = {
         'criterion': 'gini',
         'max_depth': 3,
         'max_features': 'log2'
     }
     features = numpy.genfromtxt(dataset.examples, delimiter=',')
     targets = numpy.genfromtxt(dataset.labels, delimiter=',')
     response = fit_and_save(
         ensemble.RandomForestClassifier(**fixed_params),
         X=features,
         y=targets,
         parameters=fixed_params,
         uuid=search._uuid,
         url=search.webserver_url)

     self.assertEqual(response.status_code, 201)
     self.assertEqual(CVResult.objects.count(), 1)
     stored = CVResult.objects.get(id=response.json()['id'])
     self.assertEqual(stored.scores.all().count(), 1)
示例#2
0
 def test_cvscore_post_multimetric(self):
     """With two scorers, one ``fit_and_save`` call stores one CVResult holding two scores."""
     # NOTE(review): jaccard_similarity_score was deprecated in
     # scikit-learn 0.21 and removed in 0.23 (replaced by jaccard_score);
     # confirm the pinned scikit-learn version before changing this import.
     from sklearn.metrics import make_scorer, accuracy_score, jaccard_similarity_score
     examples_file, label_file = _create_dataset()
     ds, _ = DataSet.objects.get_or_create(
         name='TEST',
         examples=SimpleUploadedFile(examples_file.name,
                                     examples_file.read()),
         labels=SimpleUploadedFile(label_file.name, label_file.read()))

     # Single definition of the metrics; each consumer gets its own copy
     # of the mapping so neither can affect the other.
     scoring = {
         'Accuracy': make_scorer(accuracy_score),
         'Jaccard': make_scorer(jaccard_similarity_score)
     }

     # Pass an estimator *instance*, as every other test here does; the
     # original passed the DecisionTreeClassifier class itself.
     gs_1 = ATGridSearchCV(tree.DecisionTreeClassifier(), {
         'criterion': ['gini', 'entropy'],
         'max_depth': range(1, 21),
         'max_features': ['auto', 'log2', 'sqrt', None]
     },
                           dataset=ds.pk,
                           webserver_url=self.live_server_url,
                           scoring=dict(scoring))
     params = {'criterion': 'gini', 'max_depth': 3, 'max_features': 'log2'}
     res = fit_and_save(tree.DecisionTreeClassifier(**params),
                        X=numpy.genfromtxt(ds.examples, delimiter=','),
                        y=numpy.genfromtxt(ds.labels, delimiter=','),
                        parameters=params,
                        uuid=gs_1._uuid,
                        url=gs_1.webserver_url,
                        scoring=dict(scoring))
     self.assertEqual(res.status_code, 201)
     self.assertEqual(CVResult.objects.count(), 1)
     # One stored score is expected per entry in ``scoring``.
     self.assertEqual(
         CVResult.objects.get(id=res.json()['id']).scores.all().count(), 2)
 def test_grid_search_model_creation(self):
     """A GridSearch row can be created for a dataset and keeps the values it was given."""
     reg = linear_model.LinearRegression()
     examples_file, label_file = _create_dataset()
     ds, _ = DataSet.objects.get_or_create(
         name='TEST',
         examples=SimpleUploadedFile(examples_file.name,
                                     examples_file.read()),
         labels=SimpleUploadedFile(label_file.name, label_file.read()))
     classifier_name = reg.__class__.__name__
     gs, _ = GridSearch.objects.get_or_create(classifier=classifier_name,
                                              dataset=ds)
     # The original test asserted nothing; check that the returned row
     # carries the classifier name and dataset it was created with.
     self.assertEqual(classifier_name, gs.classifier)
     self.assertEqual(ds, gs.dataset)
示例#4
0
    def test_dataset_grids_get(self):
        """``dataset_grids`` lists one grid after a manual insert and two after a fitted search."""
        estimator = linear_model.LinearRegression()
        examples_file, label_file = _create_dataset()
        examples_upload = SimpleUploadedFile(examples_file.name,
                                             examples_file.read())
        labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
        dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                   examples=examples_upload,
                                                   labels=labels_upload)

        # First grid: inserted directly through the ORM.
        GridSearch.objects.get_or_create(
            classifier=type(estimator).__name__, dataset=dataset)

        http = DjangoClient()
        grids_url = reverse('dataset_grids', kwargs={'name': 'TEST'})

        first = http.get(grids_url)
        self.assertEqual(200, first.status_code)
        self.assertEqual(1, len(first.data))

        # Second grid: produced by running a search against the live server.
        param_grid = {
            'criterion': ['gini', 'entropy'],
            'max_depth': range(1, 21),
            'max_features': ['auto', 'log2', 'sqrt', None]
        }
        search = ATGridSearchCV(ensemble.RandomForestClassifier(),
                                param_grid,
                                dataset=dataset.pk,
                                webserver_url=self.live_server_url)
        wait(search.fit())

        second = http.get(grids_url)
        self.assertEqual(200, second.status_code)
        self.assertEqual(2, len(second.data))
示例#5
0
    def test_atgridsearch_post(self):
        """Posting a valid classifier name, dataset and args to ``gridsearch_create`` returns 201."""
        examples_file, label_file = _create_dataset()
        examples_upload = SimpleUploadedFile(examples_file.name,
                                             examples_file.read())
        labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
        dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                   examples=examples_upload,
                                                   labels=labels_upload)

        # Same keys, in the same order, as the original two-step build.
        payload = {
            'clf': tree.DecisionTreeClassifier.__name__,
            'dataset': dataset.name,
            'args': {
                'criterion': 'gini, entropy',
                'max_features': {
                    'start': 5,
                    'end': 10,
                    'skip': 1
                },
                'presort': ['True', 'False']
            }
        }

        http = DjangoClient()
        response = http.post(reverse('gridsearch_create'),
                             json.dumps(payload),
                             content_type="application/json")
        # The response data doubles as the failure message.
        self.assertEqual(201, response.status_code, response.data)
示例#6
0
 def test_dataset_post_missing_labels(self):
     """Uploading only the examples file is rejected with 400 "Missing dataset files"."""
     examples_file, _ = _create_dataset()
     payload = {'dataset': 'TEST', 'file[0]': examples_file}
     response = DjangoClient().post(reverse('datasets'), data=payload)
     self.assertEqual(400, response.status_code)
     self.assertEqual(b'"Missing dataset files"', response.content)
示例#7
0
 def test_dataset_post_success(self):
     """Uploading examples and labels together returns 201 and stores exactly one DataSet."""
     examples_file, label_file = _create_dataset()
     payload = {
         'dataset': 'TEST',
         'file[0]': examples_file,
         'file[1]': label_file
     }
     response = DjangoClient().post(reverse('datasets'), data=payload)
     self.assertEqual(201, response.status_code)
     self.assertEqual(3, len(response.data))
     self.assertEqual(1, DataSet.objects.count())
示例#8
0
    def test_dataset_get(self):
        """``datasets`` lists the single DataSet that was created beforehand."""
        examples_file, label_file = _create_dataset()
        examples_upload = SimpleUploadedFile(examples_file.name,
                                             examples_file.read())
        labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
        DataSet.objects.get_or_create(name='TEST',
                                      examples=examples_upload,
                                      labels=labels_upload)

        response = DjangoClient().get(reverse('datasets'))
        self.assertEqual(200, response.status_code)
        self.assertEqual(1, len(response.data))
示例#9
0
 def test_dataset_post_labels_bad_name(self):
     """A labels upload renamed to 'EXAMPLES.csv' is rejected with 400 "Bad name of labels file"."""
     examples_file, label_file = _create_dataset()
     # Rename the labels handle so the upload carries the wrong file name.
     label_file.name = 'EXAMPLES.csv'
     payload = {
         'dataset': 'TEST',
         'file[0]': examples_file,
         'file[1]': label_file
     }
     response = DjangoClient().post(reverse('datasets'), data=payload)
     self.assertEqual(400, response.status_code)
     self.assertEqual(b'"Bad name of labels file"', response.content)
示例#10
0
 def test_dataset_post_exceed_files(self):
     """Uploading three files for one dataset is rejected with 400 "Too many files"."""
     examples_file, label_file = _create_dataset()
     # The examples handle is deliberately reused as the surplus third file.
     payload = {
         'dataset': 'TEST',
         'file[0]': examples_file,
         'file[1]': label_file,
         'file[2]': examples_file
     }
     response = DjangoClient().post(reverse('datasets'), data=payload)
     self.assertEqual(400, response.status_code)
     self.assertEqual(b'"Too many files"', response.content)
示例#11
0
 def test_dataset_model_single_file(self):
     """A created DataSet stores its files under datasets/TEST/ and their contents match iris."""
     examples_file, label_file = _create_dataset()
     examples_upload = SimpleUploadedFile(examples_file.name,
                                          examples_file.read())
     labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
     dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                examples=examples_upload,
                                                labels=labels_upload)

     # Storage paths assigned to the uploaded files.
     self.assertEqual('datasets/TEST/examples.csv', dataset.examples.name)
     self.assertEqual('datasets/TEST/labels.csv', dataset.labels.name)

     # Contents read back from storage must equal the iris data and targets.
     stored_examples = numpy.genfromtxt(dataset.examples, delimiter=',')
     stored_labels = numpy.genfromtxt(dataset.labels, delimiter=',')
     reference = load_iris()
     examples_match = numpy.array_equal(stored_examples, reference.data)
     self.assertTrue(examples_match)
     labels_match = numpy.array_equal(stored_labels, reference.target)
     self.assertTrue(labels_match)
示例#12
0
 def test_dataset_post_duplicate_name(self):
     """Posting the same dataset name twice gives 201 first, then 400 "Name already exists"."""
     examples_file, label_file = _create_dataset()
     client = DjangoClient()
     url = reverse('datasets')
     payload = {
         'dataset': 'TEST',
         'file[0]': examples_file,
         'file[1]': label_file
     }

     response = client.post(url, data=payload)
     self.assertEqual(201, response.status_code)

     # The test client reads each file handle to its end while encoding
     # the multipart body, so without a rewind the second request would
     # upload empty files rather than a true duplicate.
     # Assumes the handles from _create_dataset are seekable - TODO confirm.
     for handle in (examples_file, label_file):
         handle.seek(0)

     response = client.post(url, data=payload)
     self.assertEqual(400, response.status_code)
     self.assertEqual(b'"Name already exists"', response.content)
 def test_ATGridSearchCV_with_dataset(self):
     """Fitting against a stored dataset records roughly one result per grid point (within 5)."""
     examples, labels = _create_dataset()
     examples_upload = SimpleUploadedFile(examples.name, examples.read())
     labels_upload = SimpleUploadedFile(labels.name, labels.read())
     dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                examples=examples_upload,
                                                labels=labels_upload)

     param_grid = {
         'criterion': ['gini', 'entropy'],
         'max_depth': range(1, 21),
         'max_features': ['auto', 'log2', 'sqrt', None]
     }
     # 2 criteria x 20 depths x 4 max_features settings.
     expected_results = 2 * 20 * 4

     search = ATGridSearchCV(tree.DecisionTreeClassifier(),
                             param_grid,
                             dataset=dataset.pk,
                             webserver_url=self.live_server_url)
     wait(search.fit())

     stored_results = GridSearch.objects.get(
         uuid=search._uuid).results.count()
     self.assertAlmostEqual(expected_results, stored_results, delta=5)
示例#14
0
 def test_dataset_grid_results(self):
     """``grid_results`` returns as many entries as the fitted GridSearch has stored results."""
     examples, labels = _create_dataset()
     examples_upload = SimpleUploadedFile(examples.name, examples.read())
     labels_upload = SimpleUploadedFile(labels.name, labels.read())
     dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                examples=examples_upload,
                                                labels=labels_upload)

     param_grid = {
         'criterion': ['gini', 'entropy'],
         'max_depth': range(1, 21),
         'max_features': ['auto', 'log2', 'sqrt', None]
     }
     search = ATGridSearchCV(tree.DecisionTreeClassifier(),
                             param_grid,
                             dataset=dataset.pk,
                             webserver_url=self.live_server_url)
     wait(search.fit())

     results_url = reverse('grid_results', kwargs={'uuid': search._uuid})
     response = DjangoClient().get(results_url)
     self.assertEqual(200, response.status_code)

     # The stored count is queried after the request, as in the original.
     stored_count = GridSearch.objects.get(
         uuid=search._uuid).results.all().count()
     self.assertEqual(stored_count, len(response.data))
示例#15
0
    def test_atgridsearch_post_no_clf(self):
        """Posting the classifier name 'Tree' is rejected with 400 "No sklearn classifier named Tree"."""
        examples_file, label_file = _create_dataset()
        examples_upload = SimpleUploadedFile(examples_file.name,
                                             examples_file.read())
        labels_upload = SimpleUploadedFile(label_file.name, label_file.read())
        dataset, _ = DataSet.objects.get_or_create(name='TEST',
                                                   examples=examples_upload,
                                                   labels=labels_upload)

        # Same keys, in the same order, as the original two-step build.
        payload = {
            'clf': 'Tree',
            'dataset': dataset.name,
            'args': {
                'criterion': 'gini, entropy',
                'max_features': {
                    'start': 5,
                    'end': 10,
                    'skip': 1
                }
            }
        }

        http = DjangoClient()
        response = http.post(reverse('gridsearch_create'),
                             json.dumps(payload),
                             content_type="application/json")
        self.assertEqual(400, response.status_code)
        self.assertEqual(b'"No sklearn classifier named Tree"',
                         response.content)