Example #1
0
 def test_batch_read(self):
     """Batch-read VtUser and VtUserEmail rows for two users in one call.

     Currently disabled (see TODO). The test is reported as skipped rather
     than silently passing.
     """
     # TODO(sougou): fix
     # skipTest raises, so the runner records a skip instead of a pass.
     self.skipTest('TODO(sougou): batch read is disabled pending a fix')
     # 1. Create select queries using DB classes.
     query_list = []
     bv_list = []
     user_id_list = [self.user_id_list[0], self.user_id_list[1]]
     where_column_value_pairs = (('id', user_id_list), )
     # The cursor's entity_id_map is keyed on the VtUser ids.
     entity_id_map = dict(where_column_value_pairs)
     q, bv = db_class_sharded.VtUser.create_select_query(
         where_column_value_pairs)
     query_list.append(q)
     bv_list.append(bv)
     where_column_value_pairs = (('user_id', user_id_list), )
     q, bv = db_class_sharded.VtUserEmail.create_select_query(
         where_column_value_pairs)
     query_list.append(q)
     bv_list.append(bv)
     with database_context.ReadFromMaster(self.dc) as context:
         # 2. Cursor Creation using one of the DB classes.
         cursor = context.get_cursor(entity_id_map=entity_id_map)(
             db_class_sharded.VtUser)
         # 3. Batch execution of reads.
         results = db_object.execute_batch_read(cursor, query_list, bv_list)
         self.assertEqual(len(results), len(query_list))
         # Compare sorted ids so the check does not depend on row order.
         res_ids = sorted(row.id for row in results[0])
         res_user_ids = sorted(row.user_id for row in results[1])
         self.assertEqual(res_ids, sorted(user_id_list))
         self.assertEqual(res_user_ids, sorted(user_id_list))
Example #2
0
 def test_entity_id_read(self):
   """Fetch the first user by id through a username-keyed entity cursor."""
   with database_context.ReadFromMaster(self.dc) as context:
     cursor = context.get_cursor(entity_id_map={'username': '******'})
     lookup = [('id', self.user_id_list[0])]
     rows = db_class_sharded.VtUser.select_by_columns(cursor, lookup)
     self.assertEqual(len(rows), 1, "wrong number of rows fetched")
Example #3
0
  def test_in_clause_read(self):
    """Read two users and their emails with IN-clause lookups.

    Users are fetched by id, then again by the usernames just returned, and
    finally their VtUserEmail rows are fetched by user_id. Fetched values
    are sorted before comparison so the assertions do not depend on the
    order in which rows come back.
    """
    with database_context.ReadFromMaster(self.dc) as context:
      user_id_list = [self.user_id_list[0], self.user_id_list[1]]
      expected_ids = sorted(user_id_list)

      where_column_value_pairs = (('id', user_id_list),)
      entity_id_map = dict(where_column_value_pairs)
      rows = db_class_sharded.VtUser.select_by_ids(
          context.get_cursor(entity_id_map=entity_id_map),
          where_column_value_pairs)
      self.assertEqual(len(rows), 2, "wrong number of rows fetched")
      self.assertEqual(expected_ids, sorted(row.id for row in rows),
                       "wrong rows fetched")

      # Look the same users up again by the usernames just fetched.
      username_list = sorted(row.username for row in rows)
      where_column_value_pairs = (('username', username_list),)
      entity_id_map = dict(where_column_value_pairs)
      rows = db_class_sharded.VtUser.select_by_ids(
          context.get_cursor(entity_id_map=entity_id_map),
          where_column_value_pairs)
      self.assertEqual(len(rows), 2, "wrong number of rows fetched")
      self.assertEqual(username_list, sorted(row.username for row in rows),
                       "wrong rows fetched")

      where_column_value_pairs = (('user_id', user_id_list),)
      entity_id_map = dict(where_column_value_pairs)
      rows = db_class_sharded.VtUserEmail.select_by_ids(
          context.get_cursor(entity_id_map=entity_id_map),
          where_column_value_pairs)
      self.assertEqual(len(rows), 2, "wrong number of rows fetched")
      self.assertEqual(expected_ids, sorted(row.user_id for row in rows),
                       "wrong rows fetched")
Example #4
0
    def test_entity_id_read(self):
        """Read a user, its first song, and that song's detail row.

        Each of the three lookups must return exactly one row.
        """
        user_id = self.user_id_list[0]
        with database_context.ReadFromMaster(self.dc) as context:
            user_rows = db_class_sharded.VtUser.select_by_columns(
                context.get_cursor(entity_id_map={'username': '******'}),
                [('id', user_id)])
            self.assertEqual(len(user_rows), 1, "wrong number of rows fetched")

            first_song_id = self.user_song_map[user_id][0]
            song_lookups = (
                (db_class_sharded.VtSong, 'id'),
                (db_class_sharded.VtSongDetail, 'song_id'),
            )
            for model, column in song_lookups:
                # Route on the same column/value pair that is queried.
                pairs = [(column, first_song_id)]
                found = model.select_by_columns(
                    context.get_cursor(entity_id_map=dict(pairs)), pairs)
                self.assertEqual(len(found), 1, "wrong number of rows fetched")
Example #5
0
    def test_sharding_key_read(self):
        """Read the first user's VtUser, VtUserEmail and VtSong rows by key."""
        user_id = self.user_id_list[0]

        def fetch(context, model, column):
            # The cursor is routed on the same column/value pair queried.
            pairs = [(column, user_id)]
            return model.select_by_columns(
                context.get_cursor(entity_id_map=dict(pairs)), pairs)

        with database_context.ReadFromMaster(self.dc) as context:
            rows = fetch(context, db_class_sharded.VtUser, 'id')
            self.assertEqual(len(rows), 1, 'wrong number of rows fetched')

            rows = fetch(context, db_class_sharded.VtUserEmail, 'user_id')
            self.assertEqual(len(rows), 1, 'wrong number of rows fetched')

            rows = fetch(context, db_class_sharded.VtSong, 'user_id')
            self.assertEqual(len(rows), len(self.user_song_map[user_id]),
                             'wrong number of rows fetched')
Example #6
0
 def test_scatter_read(self):
   """Unfiltered VtUser read over the non-partial keyrange returns every user."""
   no_filter = []
   with database_context.ReadFromMaster(self.dc) as context:
     cursor = context.get_cursor(keyrange=keyrange_constants.NON_PARTIAL_KEYRANGE)
     rows = db_class_sharded.VtUser.select_by_columns(cursor, no_filter)
     got = len(rows)
     want = len(self.user_id_list)
     self.assertEqual(got, want, "wrong number of rows fetched, expecting %d got %d" % (want, got))
Example #7
0
    def delete_columns(self):
        """Delete the last user's VtUser and VtUserEmail rows and verify.

        Afterwards the deleted id is dropped from self.user_id_list, which is
        then sorted in place.
        """
        user_id = self.user_id_list[-1]
        user_pairs = [
            ('id', user_id),
        ]
        user_entity_map = {'id': user_id}
        email_pairs = [
            ('user_id', user_id),
        ]
        email_entity_map = {'user_id': user_id}

        with database_context.WriteTransaction(self.dc) as context:
            db_class_sharded.VtUser.delete_by_columns(
                context.get_cursor(entity_id_map=user_entity_map),
                user_pairs)
            db_class_sharded.VtUserEmail.delete_by_columns(
                context.get_cursor(entity_id_map=email_entity_map),
                email_pairs)

        with database_context.ReadFromMaster(self.dc) as context:
            # Verify VtUser with the same 'id' pair it was deleted by; the
            # 'user_id' pair belongs to VtUserEmail.
            rows = db_class_sharded.VtUser.select_by_columns(
                context.get_cursor(entity_id_map=user_entity_map),
                user_pairs)
            self.assertEqual(len(rows), 0, "wrong number of rows fetched")

            rows = db_class_sharded.VtUserEmail.select_by_ids(
                context.get_cursor(entity_id_map=email_entity_map),
                email_pairs)
            self.assertEqual(len(rows), 0, "wrong number of rows fetched")
        self.user_id_list = self.user_id_list[:-1]
        self.user_id_list.sort()
Example #8
0
 def test_read(self):
   """Read the VtUnsharded row with id 2 and check it is the one returned."""
   with database_context.ReadFromMaster(self.dc) as context:
     rows = db_class_unsharded.VtUnsharded.select_by_id(
         context.get_cursor(), 2)
     self.assertEqual(len(rows), 1, "wrong number of rows fetched")
     for row in rows:
       # Pass row as a logging argument: formatting is deferred, and a
       # tuple-like row is not mistaken for multiple %-arguments.
       logging.info("ROW: %s", row)
     self.assertEqual(rows[0].id, 2, "wrong row fetched")
Example #9
0
 def test_count(self):
     """get_count on VtUnsharded must equal len(self.all_ids)."""
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor()
         got = db_class_unsharded.VtUnsharded.get_count(
             cursor, msg="test message")
         want = len(self.all_ids)
         failure = "wrong count fetched; expected %d got %d" % (want, got)
         self.assertEqual(got, want, failure)
Example #10
0
 def test_count(self):
     """get_count on VtUser over the non-partial keyrange matches user_id_list."""
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor(
             keyrange=keyrange_constants.NON_PARTIAL_KEYRANGE)
         got = db_class_sharded.VtUser.get_count(cursor, msg="test message")
         want = len(self.user_id_list)
         failure = "wrong count fetched; expected %d got %d" % (want, got)
         self.assertEqual(got, want, failure)
Example #11
0
  def test_delete_and_read(self):
    """Delete the VtUnsharded row with id 2, then confirm it cannot be read."""
    target_id = 2
    pairs = [('id', target_id)]
    with database_context.WriteTransaction(self.dc) as context:
      db_class_unsharded.VtUnsharded.delete_by_columns(
          context.get_cursor(), pairs)

    with database_context.ReadFromMaster(self.dc) as context:
      remaining = db_class_unsharded.VtUnsharded.select_by_id(
          context.get_cursor(), target_id)
      self.assertEqual(len(remaining), 0, "wrong number of rows fetched")
Example #12
0
 def test_min_id(self):
     """get_min on VtUnsharded must equal the smallest of self.all_ids."""
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor()
         lowest = db_class_unsharded.VtUnsharded.get_min(cursor)
         expected = min(self.all_ids)
         failure = ("wrong min value fetched; expected %d got %d" %
                    (expected, lowest))
         self.assertEqual(lowest, expected, failure)
Example #13
0
 def test_max_id(self):
     """get_max on VtUnsharded must equal the largest of self.all_ids."""
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor()
         highest = db_class_unsharded.VtUnsharded.get_max(cursor)
         # Sorts the shared fixture list in place; max() itself doesn't need it.
         self.all_ids.sort()
         expected = max(self.all_ids)
         failure = ('wrong max value fetched; expected %d got %d' %
                    (expected, highest))
         self.assertEqual(highest, expected, failure)
Example #14
0
 def test_max_id(self):
     """get_max on VtUser over the non-partial keyrange equals max user id."""
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor(
             keyrange=keyrange_constants.NON_PARTIAL_KEYRANGE)
         highest = db_class_sharded.VtUser.get_max(cursor)
         expected = max(self.user_id_list)
         failure = ("wrong max value fetched; expected %d got %d" %
                    (expected, highest))
         self.assertEqual(highest, expected, failure)
Example #15
0
 def test_streaming_read(self):
   """Stream all VtUser rows over the non-partial keyrange and count them."""
   where_column_value_pairs = []
   with database_context.ReadFromMaster(self.dc) as context:
     rows = db_class_sharded.VtUser.select_by_columns_streaming(
         context.get_cursor(keyrange=keyrange_constants.NON_PARTIAL_KEYRANGE),
         where_column_value_pairs)
     # Consume the stream inside the with block, before the context closes.
     got_user_id_list = [r.id for r in rows]
     self.assertEqual(len(got_user_id_list), len(self.user_id_list), "wrong number of rows fetched")
Example #16
0
 def test_keyrange_read(self):
   """Row counts from the '-80' and '80-' keyranges add up to all users."""
   no_filter = []
   with database_context.ReadFromMaster(self.dc) as context:
     fetched_rows = 0
     for shard_range in ('-80', '80-'):
       rows = db_class_sharded.VtUser.select_by_columns(
           context.get_cursor(keyrange=shard_range), no_filter)
       fetched_rows += len(rows)
     expected = len(self.user_id_list)
     self.assertEqual(fetched_rows, expected, "wrong number of rows fetched expected:%d got:%d" % (expected, fetched_rows))
Example #17
0
 def test_read(self):
     """select_by_id on the first fixture id returns exactly that row."""
     id_val = self.all_ids[0]
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor()
         rows = db_class_unsharded.VtUnsharded.select_by_id(cursor, id_val)
         count = len(rows)
         expected = 1
         failure = ("wrong number of rows fetched %d, expected %d" %
                    (count, expected))
         self.assertEqual(count, expected, failure)
         self.assertEqual(rows[0].id, id_val, "wrong row fetched")
Example #18
0
  def test_sharding_key_read(self):
    """Read the first user's VtUser and VtUserEmail rows by sharding key."""
    lookups = (
        (db_class_sharded.VtUser, 'id'),
        (db_class_sharded.VtUserEmail, 'user_id'),
    )
    with database_context.ReadFromMaster(self.dc) as context:
      for model, key in lookups:
        user_id = self.user_id_list[0]
        pairs = [(key, user_id)]
        rows = model.select_by_columns(
            context.get_cursor(entity_id_map={key: user_id}), pairs)
        self.assertEqual(len(rows), 1, "wrong number of rows fetched")
Example #19
0
    def update_columns(self):
        """Update a user's username and email, verifying each change.

        Operates on self.user_id_list[1]. The new username is read back
        inside the write transaction. The new email is read back from master
        after the transaction block has exited.
        """
        with database_context.WriteTransaction(self.dc) as context:
            user_id = self.user_id_list[1]
            where_column_value_pairs = [
                ('id', user_id),
            ]
            entity_id_map = {'id': user_id}
            # NOTE(review): this literal (and the email one below) has no
            # conversion specifier, so `% user_id` raises TypeError; the
            # intended format string looks redacted -- confirm upstream.
            new_username = '******' % user_id
            update_cols = [
                ('username', new_username),
            ]
            db_class_sharded.VtUser.update_columns(
                context.get_cursor(entity_id_map=entity_id_map),
                where_column_value_pairs,
                update_column_value_pairs=update_cols)
            # verify the updated value.
            rows = db_class_sharded.VtUser.select_by_columns(
                context.get_cursor(entity_id_map=entity_id_map),
                where_column_value_pairs)
            self.assertEqual(len(rows), 1, "wrong number of rows fetched")
            self.assertEqual(new_username, rows[0].username)

            where_column_value_pairs = [
                ('user_id', user_id),
            ]
            entity_id_map = {'user_id': user_id}
            new_email = '*****@*****.**' % user_id
            # hashlib requires bytes on Python 3; for ASCII text the encoded
            # value is the same byte string Python 2 would hash.
            email_hash = hashlib.md5(new_email.encode('utf-8')).digest()
            update_cols = [('email', new_email), ('email_hash', email_hash)]
            db_class_sharded.VtUserEmail.update_columns(
                context.get_cursor(entity_id_map=entity_id_map),
                where_column_value_pairs,
                update_column_value_pairs=update_cols)

        # verify the updated value.
        with database_context.ReadFromMaster(self.dc) as context:
            where_column_value_pairs = [
                ('user_id', user_id),
            ]
            entity_id_map = dict(where_column_value_pairs)
            rows = db_class_sharded.VtUserEmail.select_by_ids(
                context.get_cursor(entity_id_map=entity_id_map),
                where_column_value_pairs)
            self.assertEqual(len(rows), 1, "wrong number of rows fetched")
            self.assertEqual(new_email, rows[0].email)
        self.user_id_list.sort()
Example #20
0
    def test_update_and_read(self):
        """Set msg on the first VtUnsharded fixture row and read it back."""
        target_id = self.all_ids[0]
        pairs = [('id', target_id)]
        with database_context.WriteTransaction(self.dc) as context:
            cursor = context.get_cursor()
            db_class_unsharded.VtUnsharded.update_columns(
                cursor, pairs, msg="test update")

        with database_context.ReadFromMaster(self.dc) as context:
            cursor = context.get_cursor()
            fetched = db_class_unsharded.VtUnsharded.select_by_id(
                cursor, target_id)
            self.assertEqual(len(fetched), 1, "wrong number of rows fetched")
            self.assertEqual(fetched[0].msg, "test update", "wrong row fetched")
Example #21
0
 def test_min_id(self):
     """get_min on VtUser over the non-partial keyrange equals min user id."""
     with database_context.ReadFromMaster(self.dc) as context:
         min_id = db_class_sharded.VtUser.get_min(
             context.get_cursor(
                 keyrange=keyrange_constants.NON_PARTIAL_KEYRANGE))
         # Sorting isn't needed for min(); it is kept because it reorders the
         # shared self.user_id_list fixture in place.
         self.user_id_list.sort()
         expected = min(self.user_id_list)
         self.assertEqual(
             min_id, expected,
             "wrong min value fetched; expected %d got %d" %
             (expected, min_id))
Example #22
0
 def test_batch_read(self):
     """Batch-read VtUser and VtUserEmail rows for two users in one call.

     One select query per DB class is built, both are executed through a
     single cursor, and each result set must contain exactly the two user ids.
     """
     query_list = []
     bv_list = []
     user_id_list = [self.user_id_list[0], self.user_id_list[1]]
     where_column_value_pairs = (('id', user_id_list), )
     # The cursor's entity_id_map is keyed on the VtUser ids.
     entity_id_map = dict(where_column_value_pairs)
     q, bv = db_class_sharded.VtUser.create_select_query(
         where_column_value_pairs)
     query_list.append(q)
     bv_list.append(bv)
     where_column_value_pairs = (('user_id', user_id_list), )
     q, bv = db_class_sharded.VtUserEmail.create_select_query(
         where_column_value_pairs)
     query_list.append(q)
     bv_list.append(bv)
     with database_context.ReadFromMaster(self.dc) as context:
         cursor = context.get_cursor(entity_id_map=entity_id_map)(
             db_class_sharded.VtUser)
         results = db_object.execute_batch_read(cursor, query_list, bv_list)
         self.assertEqual(len(results), len(query_list))
         # Compare sorted ids so the check does not depend on row order.
         res_ids = sorted(row.id for row in results[0])
         res_user_ids = sorted(row.user_id for row in results[1])
         self.assertEqual(res_ids, sorted(user_id_list))
         self.assertEqual(res_user_ids, sorted(user_id_list))
Example #23
0
    def test_in_clause_read(self):
        """Exercise IN-clause reads across users, emails, songs and details.

        Two users are looked up by id, then by the usernames just fetched,
        then their emails by user_id. Finally all of their songs and song
        details are fetched by song id. Fetched values are sorted before each
        comparison.
        """
        def fetch(context, select, pairs):
            # Route the cursor on exactly the column/value pairs queried.
            return select(
                context.get_cursor(entity_id_map=dict(pairs)), pairs)

        with database_context.ReadFromMaster(self.dc) as context:
            user_id_list = [self.user_id_list[0], self.user_id_list[1]]

            rows = fetch(context, db_class_sharded.VtUser.select_by_ids,
                         (('id', user_id_list), ))
            self.assertEqual(len(rows), 2, "wrong number of rows fetched")
            got = sorted(row.id for row in rows)
            self.assertEqual(
                user_id_list, got,
                "wrong rows fetched; expected %s got %s" % (user_id_list, got))

            username_list = sorted(row.username for row in rows)
            rows = fetch(context, db_class_sharded.VtUser.select_by_ids,
                         (('username', username_list), ))
            self.assertEqual(len(rows), 2, "wrong number of rows fetched")
            got = sorted(row.username for row in rows)
            self.assertEqual(
                username_list, got, "wrong rows fetched; expected %s got %s" %
                (username_list, got))

            rows = fetch(context, db_class_sharded.VtUserEmail.select_by_ids,
                         (('user_id', user_id_list), ))
            self.assertEqual(len(rows), 2, "wrong number of rows fetched")
            got = sorted(row.user_id for row in rows)
            self.assertEqual(
                user_id_list, got,
                "wrong rows fetched; expected %s got %s" % (user_id_list, got))

            # Every song belonging to either user, in ascending order.
            song_id_list = sorted(
                song_id
                for owner in user_id_list
                for song_id in self.user_song_map[owner])

            rows = fetch(context, db_class_sharded.VtSong.select_by_columns,
                         [('id', song_id_list)])
            got = sorted(row.id for row in rows)
            self.assertEqual(
                song_id_list, got,
                "wrong rows fetched %s got %s" % (song_id_list, got))

            rows = fetch(context,
                         db_class_sharded.VtSongDetail.select_by_columns,
                         [('song_id', song_id_list)])
            got = sorted(row.song_id for row in rows)
            self.assertEqual(
                song_id_list, got,
                "wrong rows fetched %s got %s" % (song_id_list, got))