def process(self, query, *args, **kwargs):
  """Split ``query`` into sub-queries that can be read in parallel.

  If the query cannot be parallelized (``QuerySplitterError``), fall back
  to a single split containing the original query.

  Args:
    query: the Datastore query to split; its ``project`` and ``namespace``
      select the client used for splitting.

  Returns:
    A list of query splits (possibly just ``[query]``).
  """
  ds_client = helper.get_client(query.project, query.namespace)
  try:
    # Raises QuerySplitterError up front when splitting is impossible,
    # so we never bother estimating a split count for such queries.
    query_splitter.validate_split(query)

    if self._num_splits:
      target_splits = self._num_splits
    else:
      # 0 means "auto": derive a split count from the data.
      target_splits = self.get_estimated_num_splits(ds_client, query)

    _LOGGER.info("Splitting the query into %d splits", target_splits)
    splits = query_splitter.get_splits(ds_client, query, target_splits)
  except query_splitter.QuerySplitterError:
    _LOGGER.info(
        "Unable to parallelize the given query: %s", query, exc_info=True)
    splits = [query]
  return splits
def test_get_splits_query_with_num_splits_of_one(self):
  """get_splits must reject a request for a single split.

  Splitting into one piece is a no-op, so the splitter raises an error
  mentioning ``num_splits``.
  """
  query = self.create_query()
  # assertRaisesRegexp was deprecated since Python 3.2 and removed in
  # Python 3.12; assertRaisesRegex is the supported spelling.
  with self.assertRaisesRegex(self.split_error, r'num_splits'):
    query_splitter.get_splits(None, query, 1)
def check_get_splits(self, query, num_splits, num_entities,
                     unused_batch_size):
  """A helper method to test the query_splitter get_splits method.

  Args:
    query: the query to be split
    num_splits: number of splits
    num_entities: number of scatter entities returned to the splitter.
    unused_batch_size: ignored in v1new since query results are entirely
      handled by the Datastore client.
  """
  # Test for random long ids, string ids, and a mix of both.
  # id_or_name True/False selects one key flavor from the helper;
  # None exercises a mixed population of both flavors.
  for id_or_name in [True, False, None]:
    if id_or_name is None:
      client_entities = helper.create_client_entities(num_entities, False)
      client_entities.extend(
          helper.create_client_entities(num_entities, True))
      # NOTE: doubles num_entities for the remainder of this iteration;
      # None is the last loop value, so no later iteration sees it.
      num_entities *= 2
    else:
      client_entities = helper.create_client_entities(
          num_entities, id_or_name)

    # Stub out the client-level scatter query so get_splits sees our
    # synthetic entities instead of talking to Datastore.
    mock_client = mock.MagicMock()
    mock_client_query = mock.MagicMock()
    mock_client_query.fetch.return_value = client_entities
    with mock.patch.object(types.Query, '_to_client_query',
                           return_value=mock_client_query):
      split_queries = query_splitter.get_splits(
          mock_client, query, num_splits)

    # The splitter should issue exactly one scatter fetch.
    mock_client_query.fetch.assert_called_once()

    # if request num_splits is greater than num_entities, the best it can
    # do is one entity per split.
    expected_num_splits = min(num_splits, num_entities + 1)
    self.assertEqual(len(split_queries), expected_num_splits)

    # Verify no gaps in key ranges. Filters should look like:
    # query1: (__key__ < key1)
    # query2: (__key__ >= key1), (__key__ < key2)
    # ...
    # queryN: (__key__ >=keyN-1)
    prev_client_key = None
    last_query_seen = False
    for split_query in split_queries:
      # Once a split with no upper bound appears, it must be the last one.
      self.assertFalse(last_query_seen)
      lt_key = None
      gte_key = None
      for _filter in split_query.filters:
        # Every split filter must be on the key property.
        self.assertEqual(query_splitter.KEY_PROPERTY_NAME, _filter[0])
        if _filter[1] == '<':
          lt_key = _filter[2]
        elif _filter[1] == '>=':
          gte_key = _filter[2]

      # Case where the scatter query has no results.
      if lt_key is None and gte_key is None:
        self.assertEqual(1, len(split_queries))
        break

      if prev_client_key is None:
        # First split: open on the left, bounded on the right.
        self.assertIsNone(gte_key)
        self.assertIsNotNone(lt_key)
        prev_client_key = lt_key
      else:
        # Later splits: lower bound must equal the previous upper bound
        # (no gaps, no overlaps between adjacent key ranges).
        self.assertEqual(prev_client_key, gte_key)
        prev_client_key = lt_key
        if lt_key is None:
          last_query_seen = True