def test_create_original_split_fact_value_given(self):
    """Split with the "original" distribution limited to one fact value.

    Posts a splitting request with ``test_size`` 20 and ``str_val``
    "FUBAR", then compares the "FUBAR" counts of the test and train
    indices against the source index count via ``is_between_limits``,
    using ratios 0.2 and 0.8 respectively.
    """
    request_body = {
        "description": "Original index splitting",
        "indices": [{"name": self.test_index_name}],
        "train_index": INDEX_SPLITTING_TRAIN_INDEX,
        "test_index": INDEX_SPLITTING_TEST_INDEX,
        "distribution": "original",
        "test_size": 20,
        "fact": self.FACT,
        "str_val": "FUBAR",
    }
    response = self.client.post(self.url, data=request_body, format="json")
    print_output(
        'test_create_original_split_fact_value_given:response.data',
        response.data)

    # The fetched object is not inspected; only the lookup by the returned
    # id is performed (presumably it raises if nothing was stored --
    # TODO confirm).
    IndexSplitter.objects.get(id=response.data['id'])

    # NOTE(review): fixed pause, presumably to let the split finish and the
    # new indices become searchable -- confirm.
    sleep(5)

    # Aggregation order matters for the printed list: source, test, train.
    index_names = (
        self.test_index_name,
        INDEX_SPLITTING_TEST_INDEX,
        INDEX_SPLITTING_TRAIN_INDEX,
    )
    source_counts, test_counts, train_counts = [
        ElasticAggregator(indices=index_name).get_fact_values_distribution(
            self.FACT)
        for index_name in index_names
    ]
    print_output(
        'original_dist, test_dist, train_dist',
        [source_counts, test_counts, train_counts])

    for fact_value, source_count in source_counts.items():
        # Only the fact value named in the request is checked.
        if fact_value != "FUBAR":
            continue
        within_test_share = self.is_between_limits(
            test_counts[fact_value], source_count, 0.2)
        self.assertTrue(within_test_share)
        within_train_share = self.is_between_limits(
            train_counts[fact_value], source_count, 0.8)
        self.assertTrue(within_train_share)
def test_create_equal_split(self):
    """Split with the "equal" distribution and verify per-label counts.

    Posts a splitting request with ``distribution`` "equal" and a
    ``test_size`` of 20, then checks every label of the source index:

    * a label with more than ``test_size`` documents must have exactly
      ``test_size`` of them in the test index and the remainder in the
      train index;
    * any other label must be entirely in the test index and absent from
      the train index.
    """
    # Single source of truth for the size sent in the request and the
    # size the assertions below expect.
    test_size = 20
    payload = {
        "description": "Original index splitting",
        "indices": [{"name": self.test_index_name}],
        "train_index": INDEX_SPLITTING_TRAIN_INDEX,
        "test_index": INDEX_SPLITTING_TEST_INDEX,
        "distribution": "equal",
        "test_size": test_size,
        "fact": self.FACT,
    }
    # The payload holds a nested list of dicts, so it is sent as JSON
    # explicitly, as the other splitter test posting to this URL does.
    response = self.client.post(self.url, data=payload, format="json")
    print_output('test_create_equal_split:response.data', response.data)

    splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

    # NOTE(review): this assertion compares a constant with itself, so it
    # can never fail and does not verify that the task completed. It
    # should presumably compare the status of the task attached to
    # ``splitter_obj`` -- confirm the attribute name on the IndexSplitter
    # model and replace both lines below.
    self.assertEqual(Task.STATUS_COMPLETED, Task.STATUS_COMPLETED)
    print_output("Task status", Task.STATUS_COMPLETED)

    # NOTE(review): fixed pause, presumably to let the split finish and the
    # new indices become searchable -- confirm.
    sleep(5)

    original_distribution = ElasticAggregator(
        indices=self.test_index_name).get_fact_values_distribution(self.FACT)
    test_distribution = ElasticAggregator(
        indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
            self.FACT)
    train_distribution = ElasticAggregator(
        indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
            self.FACT)
    print_output(
        'original_dist, test_dist, train_dist',
        [original_distribution, test_distribution, train_distribution])

    for label, quant in original_distribution.items():
        if quant > test_size:
            self.assertEqual(test_distribution[label], test_size)
            self.assertEqual(train_distribution[label], quant - test_size)
        else:
            # Too few documents to fill the quota: all of them go to the
            # test index and none remain for training.
            self.assertEqual(test_distribution[label], quant)
            self.assertNotIn(label, train_distribution)