コード例 #1
0
ファイル: datastoreio.py プロジェクト: yaoshi1994/beam
    def process(self, query, *args, **kwargs):
      """Split ``query`` into sub-queries that can be read in parallel.

      A Datastore client is obtained for the query's project and namespace.
      The split count is ``self._num_splits``, unless that is 0, in which case
      it is estimated with ``self.get_estimated_num_splits``.

      Args:
        query: the query to split; its ``project`` and ``namespace`` attributes
          are used to obtain the client.
        *args: unused.
        **kwargs: unused.

      Returns:
        The list of split queries produced by ``query_splitter.get_splits``,
        or ``[query]`` when validation, estimation or splitting raises
        ``query_splitter.QuerySplitterError``.
      """
      datastore_client = helper.get_client(query.project, query.namespace)
      try:
        # Validate before anything else so that estimating the split count is
        # skipped entirely when the query cannot be split.
        query_splitter.validate_split(query)

        num_splits = (
            self.get_estimated_num_splits(datastore_client, query)
            if self._num_splits == 0 else self._num_splits)

        _LOGGER.info("Splitting the query into %d splits", num_splits)
        return query_splitter.get_splits(datastore_client, query, num_splits)
      except query_splitter.QuerySplitterError:
        # Not splittable: fall back to the original query as the only split.
        _LOGGER.info("Unable to parallelize the given query: %s", query,
                     exc_info=True)
        return [query]
コード例 #2
0
 def test_get_splits_query_with_num_splits_of_one(self):
     """Requesting a single split must raise ``self.split_error``.

     The error message is expected to mention ``num_splits``. This uses
     ``assertRaisesRegex``. The older ``assertRaisesRegexp`` alias was
     deprecated in Python 3.2 and removed in Python 3.12.
     """
     query = self.create_query()
     # None is passed as the client: the error is presumably raised before
     # any client access is needed.
     with self.assertRaisesRegex(self.split_error, r'num_splits'):
         query_splitter.get_splits(None, query, 1)
コード例 #3
0
    def check_get_splits(self, query, num_splits, num_entities,
                         unused_batch_size):
        """A helper method to test the query_splitter get_splits method.

        Runs get_splits three times against a mocked client. The mocked
        scatter query returns entities keyed by random long ids, by string
        ids, and by a mix of both. Each run checks the number of splits and
        checks that the splits' key ranges are contiguous, with the first
        split unbounded below and the last unbounded above.

        Args:
          query: the query to be split
          num_splits: number of splits
          num_entities: number of scatter entities returned to the splitter
            per key type (the mixed run returns twice as many).
          unused_batch_size: ignored in v1new since query results are entirely
            handled by the Datastore client.
        """
        # Test for random long ids, string ids, and a mix of both.
        for id_or_name in [True, False, None]:
            # Track the count per iteration instead of mutating the
            # num_entities argument, so no iteration depends on loop order.
            if id_or_name is None:
                client_entities = helper.create_client_entities(
                    num_entities, False)
                client_entities.extend(
                    helper.create_client_entities(num_entities, True))
                total_entities = num_entities * 2
            else:
                client_entities = helper.create_client_entities(
                    num_entities, id_or_name)
                total_entities = num_entities

            mock_client = mock.MagicMock()
            mock_client_query = mock.MagicMock()
            mock_client_query.fetch.return_value = client_entities
            with mock.patch.object(types.Query,
                                   '_to_client_query',
                                   return_value=mock_client_query):
                split_queries = query_splitter.get_splits(
                    mock_client, query, num_splits)

            mock_client_query.fetch.assert_called_once()
            # If the requested num_splits is greater than the number of
            # entities, the best the splitter can do is one entity per split.
            expected_num_splits = min(num_splits, total_entities + 1)
            self.assertEqual(len(split_queries), expected_num_splits)

            # Verify no gaps in key ranges. Filters should look like:
            # query1: (__key__ < key1)
            # query2: (__key__ >= key1), (__key__ < key2)
            # ...
            # queryN: (__key__ >= keyN-1)
            prev_client_key = None
            last_query_seen = False
            for split_query in split_queries:
                # No split may follow the one that is unbounded above.
                self.assertFalse(last_query_seen)
                lt_key = None
                gte_key = None
                for _filter in split_query.filters:
                    self.assertEqual(query_splitter.KEY_PROPERTY_NAME,
                                     _filter[0])
                    if _filter[1] == '<':
                        lt_key = _filter[2]
                    elif _filter[1] == '>=':
                        gte_key = _filter[2]

                # Case where the scatter query has no results.
                if lt_key is None and gte_key is None:
                    self.assertEqual(1, len(split_queries))
                    break

                if prev_client_key is None:
                    # First split: bounded above only.
                    self.assertIsNone(gte_key)
                    self.assertIsNotNone(lt_key)
                else:
                    # Each later split must start exactly where the previous
                    # one ended.
                    self.assertEqual(prev_client_key, gte_key)
                    if lt_key is None:
                        last_query_seen = True
                prev_client_key = lt_key
            else:
                # The loop finished without the no-results break. The final
                # split must therefore be unbounded above; otherwise keys
                # after the last split point would be in no split at all.
                self.assertTrue(last_query_seen)