# Example 1
    def test_display_data(self, *args):
        """display_data must surface the sql/table/transaction settings."""
        def _make(**kwargs):
            # Every variant targets the same project/instance; only the
            # read configuration differs.
            return ReadFromSpanner(project_id=TEST_PROJECT_ID,
                                   instance_id=TEST_INSTANCE_ID,
                                   database_id=_generate_database_name(),
                                   **kwargs).display_data()

        dd_sql = _make(sql="Select * from users")
        dd_table = _make(table="users", columns=['id', 'name'])
        dd_transaction = _make(table="users",
                               columns=['id', 'name'],
                               transaction={
                                   "transaction_id": "test123",
                                   "session_id": "test456"
                               })

        self.assertIn("sql", dd_sql)
        self.assertIn("table", dd_table)
        self.assertIn("table", dd_transaction)
        self.assertIn("transaction", dd_transaction)
# Example 2
  def test_read_with_query_batch(self, mock_batch_snapshot_class,
                                 mock_client_class):
    """Query reads work via the sql kwarg, read_operations, and a
    pipeline-supplied ReadOperation source."""
    mock_snapshot = mock.MagicMock()

    # Three partitions, each yielding a slice of the fake result set.
    mock_snapshot.generate_query_batches.return_value = [
        {
            'query': {'sql': 'SELECT * FROM users'},
            'partition': 'test_partition'
        } for _ in range(3)
    ]
    mock_snapshot.process_query_batch.side_effect = [
        FAKE_ROWS[:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]

    read_ops = [ReadOperation.query("Select * from users")]
    pipeline = TestPipeline()

    read = (pipeline | 'read' >> ReadFromSpanner(
        TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name(),
        sql="SELECT * FROM users"))

    readall = (pipeline | 'read all' >> ReadFromSpanner(
        TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name(),
        read_operations=read_ops))

    readpipeline = (pipeline
                    | 'create reads' >> beam.Create(read_ops)
                    | 'reads' >> ReadFromSpanner(
                        TEST_PROJECT_ID, TEST_INSTANCE_ID,
                        _generate_database_name()))

    pipeline.run()
    assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
    assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
    assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
# Example 3
    def test_read_with_table_batch(self, mock_batch_snapshot_class,
                                   mock_client_class):
        """Table reads work via kwargs, read_operations, and a
        pipeline-supplied ReadOperation source; mixing the last two fails."""
        mock_snapshot = mock.MagicMock()

        # Three partitions, each returning a slice of the fake rows.
        mock_snapshot.generate_read_batches.return_value = [
            {
                'read': {
                    'table': 'users',
                    'keyset': {'all': True},
                    'columns': ['Key', 'Value'],
                    'index': ''
                },
                'partition': 'test_partition'
            } for _ in range(3)
        ]
        mock_snapshot.process_read_batch.side_effect = [
            FAKE_ROWS[:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]]

        read_ops = [ReadOperation.table("users", ["Key", "Value"])]
        pipeline = TestPipeline()

        read = (pipeline | 'read' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            table="users",
            columns=["Key", "Value"]))

        readall = (pipeline | 'read all' >> ReadFromSpanner(
            TEST_PROJECT_ID,
            TEST_INSTANCE_ID,
            _generate_database_name(),
            read_operations=read_ops))

        readpipeline = (pipeline
                        | 'create reads' >> beam.Create(read_ops)
                        | 'reads' >> ReadFromSpanner(
                            TEST_PROJECT_ID, TEST_INSTANCE_ID,
                            _generate_database_name()))

        pipeline.run()
        assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
        assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
        assert_that(readpipeline, equal_to(FAKE_ROWS),
                    label='checkReadPipeline')

        # Supplying read operations in the constructor while also feeding
        # them from the pipeline must raise.
        with self.assertRaises(ValueError):
            _ = (pipeline | 'reads error' >> ReadFromSpanner(
                project_id=TEST_PROJECT_ID,
                instance_id=TEST_INSTANCE_ID,
                database_id=_generate_database_name(),
                table="users"))
            pipeline.run()
# Example 4
 def test_read_via_sql(self):
     """Read the Users table through a SQL query and check the rows."""
     _LOGGER.info("Running Spanner via sql")
     with beam.Pipeline(argv=self.args) as p:
         rows = p | ReadFromSpanner(
             self.project, self.instance, self.TEST_DATABASE,
             sql="select * from Users")
     assert_that(rows, equal_to(self._data))
# Example 5
 def test_read_via_table(self):
     """Read selected columns directly from the Users table."""
     _LOGGER.info("Spanner Read via table")
     with beam.Pipeline(argv=self.args) as p:
         rows = p | ReadFromSpanner(
             self.project, self.instance, self.TEST_DATABASE,
             table="Users", columns=["UserId", "Key"])
     assert_that(rows, equal_to(self._data))
# Example 6
 def test_invalid_transaction(self, mock_batch_snapshot_class,
                              mock_client_class):
     """A malformed transaction side input must raise ValueError."""
     with self.assertRaises(ValueError):
         p = TestPipeline()
         # Not the mapping produced by create_transaction, hence invalid.
         transaction = p | beam.Create([{"invalid": "transaction"}])
         _ = (p | 'with query' >> ReadFromSpanner(
             project_id=TEST_PROJECT_ID,
             instance_id=TEST_INSTANCE_ID,
             database_id=_generate_database_name(),
             transaction=transaction,
             sql="Select * from users"))
         p.run()
# Example 7
 def test(self):
     """Count the records read from Spanner while recording load metrics."""
     counted = (
         self.pipeline
         | 'Read from Spanner' >> ReadFromSpanner(
             self.project,
             self.spanner_instance,
             self.spanner_database,
             sql="select data from test_data")
         | 'Count messages' >> ParDo(CountMessages(self.metrics_namespace))
         | 'Measure time' >> ParDo(MeasureTime(self.metrics_namespace))
         | 'Count' >> Count.Globally())
     # The global count must match the number of generated input records.
     assert_that(counted, equal_to([self.input_options['num_records']]))
    def test_sql_metrics_ok_call(self):
        """A successful SQL read emits an 'ok' service call metric."""
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Start from a clean slate so earlier tests cannot inflate the count.
        MetricsEnvironment.process_wide_container().reset()

        with beam.Pipeline(argv=self.args) as p:
            rows = p | ReadFromSpanner(
                self.project,
                self.instance,
                self.TEST_DATABASE,
                sql="select * from Users",
                query_name='query-1')

        assert_that(rows, equal_to(self._data))
        self.verify_sql_read_call_metric(self.project, self.TEST_DATABASE,
                                         'query-1', 'ok', 1)
    def test_table_metrics_ok_call(self):
        """A successful table read emits an 'ok' service call metric."""
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Start from a clean slate so earlier tests cannot inflate the count.
        MetricsEnvironment.process_wide_container().reset()

        with beam.Pipeline(argv=self.args) as p:
            rows = p | ReadFromSpanner(
                self.project,
                self.instance,
                self.TEST_DATABASE,
                table="Users",
                columns=["UserId", "Key"])

        assert_that(rows, equal_to(self._data))
        self.verify_table_read_call_metric(self.project, self.TEST_DATABASE,
                                           'Users', 'ok', 1)
    def test_sql_metrics_error_call(self):
        """A failed SQL read should emit a '400' service call metric.

        The metric verification must run *after* the assertRaises block:
        the pipeline raises inside the `with`, so any statement placed
        after the failing run in that block can never execute.
        """
        if 'DirectRunner' not in self.runner_name:
            raise unittest.SkipTest('This test only runs with DirectRunner.')

        # Reset process-wide metrics so earlier tests cannot affect the count.
        MetricsEnvironment.process_wide_container().reset()

        with self.assertRaises(Exception):
            p = beam.Pipeline(argv=self.args)
            _ = p | ReadFromSpanner(self.project,
                                    self.instance,
                                    self.TEST_DATABASE,
                                    sql="select * from NonExistent",
                                    query_name='query-2')

            res = p.run()
            res.wait_until_finish()

        # Previously this sat inside the `with` after the raising call and
        # was unreachable, so the error metric was never actually checked.
        self.verify_sql_read_call_metric(self.project, self.TEST_DATABASE,
                                         'query-2', '400', 1)
# Example 11
  def test_table_metrics_error_call(self):
    """A failed table read should emit a '404' service call metric.

    The metric verification must run *after* the assertRaises block: the
    pipeline raises inside the `with`, so any statement placed after the
    failing run in that block can never execute.
    """
    if 'DirectRunner' not in self.runner_name:
      raise unittest.SkipTest('This test only runs with DirectRunner.')

    # Reset process-wide metrics so earlier tests cannot affect the count.
    MetricsEnvironment.process_wide_container().reset()

    with self.assertRaises(Exception):
      p = beam.Pipeline(argv=self.args)
      _ = p | ReadFromSpanner(
          self.project,
          self.instance,
          self.TEST_DATABASE,
          table="INVALID_TABLE",
          columns=["UserId", "Key"])

      res = p.run()
      res.wait_until_finish()

    # Previously this sat inside the `with` after the raising call and was
    # unreachable, so the '404' metric was never actually checked.
    self.verify_table_read_call_metric(
        self.project, self.TEST_DATABASE, 'INVALID_TABLE', '404', 1)
# Example 12
    def test_read_with_transaction(self, mock_batch_snapshot_class,
                                   mock_client_class):
        """Reads sharing a create_transaction side input.

        Covers sql, table, indexed-table, the read_operations kwarg, and a
        pipeline-supplied ReadOperation source, then checks the ValueError
        raised when read operations arrive both via the constructor and
        from the pipeline.
        """
        # Wire the mocked Spanner client chain:
        # client -> instance -> database -> batch_snapshot.
        mock_client = mock.MagicMock()
        mock_instance = mock.MagicMock()
        mock_database = mock.MagicMock()
        mock_snapshot = mock.MagicMock()

        mock_client_class.return_value = mock_client
        mock_client.instance.return_value = mock_instance
        mock_instance.database.return_value = mock_database
        mock_database.batch_snapshot.return_value = mock_snapshot
        mock_batch_snapshot_class.return_value = mock_snapshot
        mock_batch_snapshot_class.from_dict.return_value = mock_snapshot
        # create_transaction serializes the snapshot via to_dict().
        mock_snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO

        # The transactional read path runs the query on a session
        # transaction used as a context manager, so stub that chain too.
        mock_session = mock.MagicMock()
        mock_transaction_ctx = mock.MagicMock()
        mock_transaction = mock.MagicMock()

        mock_snapshot._get_session.return_value = mock_session
        mock_session.transaction.return_value = mock_transaction
        mock_transaction.__enter__.return_value = mock_transaction_ctx
        mock_transaction_ctx.execute_sql.return_value = FAKE_ROWS

        ro = [ReadOperation.query("Select * from users")]
        p = TestPipeline()

        # One transaction side input shared by every read below.
        transaction = (p | create_transaction(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            exact_staleness=datetime.timedelta(seconds=10)))

        read_query = (p | 'with query' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            sql="Select * from users"))

        read_table = (p | 'with table' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            columns=["Key", "Value"]))

        read_indexed_table = (p | 'with index' >> ReadFromSpanner(
            project_id=TEST_PROJECT_ID,
            instance_id=TEST_INSTANCE_ID,
            database_id=_generate_database_name(),
            transaction=transaction,
            table="users",
            index="Key",
            columns=["Key", "Value"]))

        read = (p | 'read all' >> ReadFromSpanner(TEST_PROJECT_ID,
                                                  TEST_INSTANCE_ID,
                                                  _generate_database_name(),
                                                  transaction=transaction,
                                                  read_operations=ro))

        read_pipeline = (p
                         | 'create read operations' >> beam.Create(ro)
                         |
                         'reads' >> ReadFromSpanner(TEST_PROJECT_ID,
                                                    TEST_INSTANCE_ID,
                                                    _generate_database_name(),
                                                    transaction=transaction))

        p.run()

        # Every read variant sees the same mocked result set.
        assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')
        assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')
        assert_that(read_indexed_table,
                    equal_to(FAKE_ROWS),
                    label='checkTableIndex')
        assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')
        assert_that(read_pipeline,
                    equal_to(FAKE_ROWS),
                    label='checkReadPipeline')

        with self.assertRaises(ValueError):
            # Test the exception raised when user passes the read operations in the
            # constructor and also in the pipeline.
            _ = (p
                 | 'create read operations2' >> beam.Create(ro)
                 | 'reads with error' >> ReadFromSpanner(
                     TEST_PROJECT_ID,
                     TEST_INSTANCE_ID,
                     _generate_database_name(),
                     transaction=transaction,
                     read_operations=ro))
            p.run()