def get(self, table_name, *keys, **kwargs):
    num_uuids = self._table_names[table_name]
    ## A mapping between shards and columns to retrieve from each shard
    shards = defaultdict(list)
    ## Track the keys that we find
    found_keys = set()
    ## Determine all the shards that we need to contact
    for key in keys:
        assert len(key) == num_uuids
        joined_key = join_uuids(*key)
        row_names = [self._make_shard_name(table_name, joined_key)]
        for row_name in row_names:
            shards[row_name].append(joined_key)
    ## For each shard we contact, get all the keys it may hold
    for row_name, columns in shards.iteritems():
        try:
            for key, val in self.tables[table_name].get(
                    row_name, columns=columns).iteritems():
                key = split_uuids(key)
                logger.critical('cassandra get(%r) yielding %r %d'
                                % (table_name, key, len(val)))
                ## key is now a list of uuids, e.g. [uuid, uuid]
                assert tuple(key) not in found_keys
                found_keys.add(tuple(key))
                yield tuple(key), val
        except pycassa.NotFoundException:
            pass
    ## Yield None for any requested keys that we did not find
    missing_keys = set(keys) - found_keys
    for k in missing_keys:
        yield k, None
def scan(self, table_name, *key_ranges, **kwargs):
    kwargs.pop('batch_size', 100)
    if not key_ranges:
        ## get all columns
        key_ranges = [['', '']]
    # TODO: s/num_uuids/key_spec/g
    num_uuids = self._table_names[table_name]
    for start, finish in key_ranges:
        specific_key_range = bool(start or finish)
        if specific_key_range and start == finish and len(start) == num_uuids:
            logger.warn('doing a scan on a single element, what?')
            #logger.info('specific_key_range: %r %r' % (start, finish))
            assert len(start) == num_uuids
            joined_key = join_uuids(*start)
            columns = [joined_key]
            row_names = [self._make_shard_name(table_name, joined_key)]
            start = None
            finish = None
        else:
            columns = None
            # TODO: require_uuid and num_uuids is obsolete. fix.
            start = make_start_key(start, uuid_mode=self._require_uuid,
                                   num_uuids=num_uuids)
            finish = make_end_key(finish, uuid_mode=self._require_uuid,
                                  num_uuids=num_uuids)
            row_names = self._make_shard_names(table_name, start, finish)
        total_count = 0
        hit_empty = False
        for row_name in row_names:
            try:
                for key, val in self._get_from_one_row(
                        table_name, row_name, columns, start, finish, num_uuids):
                    assert len(key) == num_uuids
                    yield key, val
                    if start:
                        assert start <= join_uuids(*key)
                    if finish:
                        assert finish >= join_uuids(*key)
                    total_count += 1
                    #logger.critical('total_count: %d' % total_count)
            except pycassa.NotFoundException:
                hit_empty = True
def delete(self, table_name, *keys, **kwargs):
    num_uuids = self._table_names[table_name]
    batch_size = kwargs.pop('batch_size', 1000)
    batch = self.tables[table_name].batch(queue_size=batch_size)
    count = 0
    for key in keys:
        assert len(key) == num_uuids
        joined_key = join_uuids(*key)
        row_name = self._make_shard_name(table_name, joined_key)
        columns = [joined_key]
        #logger.critical('C* delete: table_name=%r columns=%r' % (table_name, columns))
        batch.remove(row_name, columns=columns)
        count += 1
    batch.send()
    logger.info('deleted %d tree_ids from %r' % (count, table_name))
def _get_from_one_row(self, table_name, row_name, columns, start, finish, num_uuids):
    logger.debug('c* get: table_name=%r row_name=%r columns=%r start=%r finish=%r' % (
        table_name, row_name, columns, start, finish))
    if not columns:
        assert start is not None and finish is not None
        assert start <= finish
    num_yielded = 0
    while True:
        ## remember where this pass started so we can tell when paging stops advancing
        prev_start = start
        logger.debug('cassandra get(%r...)' % row_name)
        for key, val in self.tables[table_name].get(
                row_name,
                columns=columns,
                column_start=start,
                column_finish=finish,
                column_count=1,
                ).iteritems():
            key = split_uuids(key)
            logger.critical('cassandra get(%r) yielding %r %d' % (table_name, key, len(val)))
            yield key, val
            num_yielded += 1
        logger.debug('c* get: table_name=%r row_name=%r columns=%r start=%r finish=%r' % (
            table_name, row_name, columns, start, finish))
        ## prepare to page ahead to next batch
        if columns:
            break
        start = list(key)
        start[-1] = uuid.UUID(int=key[-1].int + 1)
        assert len(start) == num_uuids
        start = join_uuids(*start)
        if start == prev_start or start > finish:
            break
        logger.debug('paging forward from %r to %r' % (prev_start, start))
    ## We need to raise a not found exception if the caller asked for
    ## a specific column and we didn't yield any results
    if not columns and num_yielded == 0:
        raise pycassa.NotFoundException
def put(self, table_name, *keys_and_values, **kwargs):
    batch_size = kwargs.pop('batch_size', None)
    tot_bytes = 0
    cur_bytes = 0
    tot_rows = 0
    cur_rows = 0
    num_uuids = self._table_names[table_name]
    start = time.time()
    logger.debug('starting save')
    batch = self.tables[table_name].batch(queue_size=batch_size)
    for key, blob in keys_and_values:
        self.check_put_key_value(key, blob, table_name, num_uuids)
        ## 2**19 bytes = half of 1 MB, i.e. convert MB to bytes and then cut in half
        if len(blob) + cur_bytes >= self.thrift_framed_transport_size_in_mb * 2**19:
            logger.critical(
                'len(blob)=%d + cur_bytes=%d >= thrift_framed_transport_size_in_mb / 2 = %d'
                % (len(blob), cur_bytes,
                   self.thrift_framed_transport_size_in_mb * 2**19))
            if cur_rows > 0:
                logger.critical('pre-emptively sending only what has been batched, '
                                'and will send this item in next batch.')
                batch.send()
                cur_bytes = 0
                cur_rows = 0
        cur_bytes += len(blob)
        tot_bytes += len(blob)
        cur_rows += 1
        tot_rows += 1
        if not isinstance(key, tuple):
            ## for consistency, always pass keys through join_uuids,
            ## even if there is only one
            key = (key,)
        joined_key = join_uuids(*key)
        row_name = self._make_shard_name(table_name, joined_key)
        if len(blob) >= self.thrift_framed_transport_size_in_mb * 2**19:
            logger.critical(
                'len(blob)=%d >= thrift_framed_transport_size_in_mb / 2 = %d, so there '
                'is a risk that the total payload will exceed the full '
                'thrift_framed_transport_size_in_mb, and the only solution to this is to '
                'change the Cassandra server-side config to allow larger frames...'
                % (len(blob), self.thrift_framed_transport_size_in_mb * 2**19))
        batch.insert(row_name, {joined_key: blob})
        #logger.critical('saving %s %r %r' % (table_name, key, blob))
        if tot_rows % 500 == 0:
            logger.debug('num rows=%d, num MB=%d, thrift_framed_transport_size_in_mb=%d' % (
                tot_rows, float(tot_bytes) / 2**20, self.thrift_framed_transport_size_in_mb))
    batch.send()
    elapsed = time.time() - start
    row_rate = float(tot_rows) / elapsed
    MB_rate = float(tot_bytes) / elapsed / 2**20
    logger.info('%s.insert(%d rows, %d bytes) in %.1f sec --> %.1f rows/sec %.3f MBps' % (
        table_name, tot_rows, tot_bytes, elapsed, row_rate, MB_rate))
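## Hedged usage sketch (not part of the original module): assuming `store` is an
## instance of this storage class with a table registered in self._table_names
## whose key spec is two uuids, the methods above compose roughly like this.
## The table name 'mytable' and the variable `store` are illustrative assumptions.
##
##     import uuid
##     k = (uuid.uuid4(), uuid.uuid4())
##     store.put('mytable', (k, 'some blob of bytes'))
##     for key, val in store.get('mytable', k):
##         assert key == k and val == 'some blob of bytes'
##     ## a (start, finish) pair with start == finish hits the single-element
##     ## scan branch and its logger.warn above
##     for key, val in store.scan('mytable', (k, k)):
##         pass
##     store.delete('mytable', k)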