def __init__(self):
    """Initialize empty benchmark-result state (nothing started yet)."""
    # Benchmark wall-clock boundaries; filled in by start/stop bookkeeping.
    self.start = None
    self.stop = None
    # Monotonically increasing id handed out to each started transaction.
    self.txn_id = 0
    # Total operation count accumulated over completed transactions.
    self.opCount = 0
    # Finished transactions as (txnName, timestamp) tuples.
    self.completed = []
    # Per-transaction-name invocation counts and accumulated durations.
    self.txn_counters = Histogram()
    self.txn_times = {}
    # In-flight transactions: id -> (txnName, startTime).
    self.running = {}
def __init__(self, collections, num_nodes):
    """Create an estimator over the given collection catalog and cluster size."""
    assert isinstance(collections, dict)
    # LOG.setLevel(logging.DEBUG)
    self.debug = LOG.isEnabledFor(logging.DEBUG)
    self.collections = collections
    self.num_nodes = num_nodes
    # Histogram of node ids touched by the operations estimated so far.
    self.nodeCounts = Histogram()
    # Number of operations that have been run through the estimator.
    self.op_count = 0
def computeInStats(query, h=None):
    """Collect a histogram of '#in' clause sizes found anywhere in a query.

    Recursively walks the (possibly nested) query document; for every
    "#in" key encountered, records the number of elements in its value.

    :param query: dict representing a converted query document
    :param h: optional Histogram to accumulate into; lazily created on the
              first "#in" hit, so the result is None when no "#in" occurs
    :returns: Histogram of "#in" list lengths, or None if none were found
    """
    # items() instead of iteritems() works on both Python 2 and Python 3
    for k, v in query.items():
        if k == "#in":
            if h is None:
                h = Histogram()
            h.put(len(v))
        elif isinstance(v, list):
            # Lists may contain nested sub-documents (e.g. $or clauses)
            for inner in v:
                if isinstance(inner, dict):
                    h = computeInStats(inner, h)
        elif isinstance(v, dict):
            h = computeInStats(v, h)
    return h
def computeInStats(query, h=None):
    """Recursively count the sizes of all '#in' clauses in a query document.

    Walks nested dicts and lists-of-dicts. A Histogram is created only when
    the first "#in" key is seen, so the return value is None for queries
    without any "#in" clause.

    :param query: dict query document to scan
    :param h: Histogram to accumulate into (created lazily when None)
    :returns: Histogram of "#in" value lengths, or None
    """
    for key, value in query.items():  # items() is Python 2/3 compatible
        if key == "#in":
            if h is None:
                h = Histogram()
            h.put(len(value))
        elif isinstance(value, list):
            for element in value:
                if isinstance(element, dict):
                    h = computeInStats(element, h)
        elif isinstance(value, dict):
            h = computeInStats(value, h)
    return h
def fixInvalidCollections(self):
    """Attempt to repair operations whose collection was marked invalid.

    For every stored Session containing an operation tagged with
    INVALID_COLLECTION_MARKER, build a histogram over all known collections
    counting how many of the operation's referenced fields each collection
    also has. An unambiguous single winner is taken as the real collection;
    misses and ties are logged and skipped. Fixed sessions are saved back.
    """
    searchKey = {"operations.collection": constants.INVALID_COLLECTION_MARKER}
    for session in self.metadata_db.Session.find(searchKey):
        # BUGFIX: 'dirty' used to be reset inside the per-operation loop but
        # checked after it, so a fix on any operation other than the last was
        # never saved (and an empty operation list raised NameError).
        # Track it per session instead.
        dirty = False
        for op in session["operations"]:
            if op["collection"] != constants.INVALID_COLLECTION_MARKER:
                continue
            if self.debug:
                LOG.debug("Attempting to fix corrupted Operation:\n%s" % pformat(op))

            # For each field referenced in the query, build a histogram of
            # which collections have a field with the same name
            fields = workload.getReferencedFields(op)
            h = Histogram()
            for c in self.metadata_db.Collection.find():
                for f in c['fields']:
                    if f in fields:
                        h.put(c['name'])
                ## FOR
            ## FOR

            matches = h.getMaxCountKeys()
            if len(matches) == 0:
                LOG.warn("No matching collection was found for corrupted operation\n%s" % pformat(op))
                continue
            elif len(matches) > 1:
                LOG.warn("More than one matching collection was found for corrupted operation %s\n%s" % (matches, pformat(op)))
                continue
            else:
                op["collection"] = matches[0]
                dirty = True
                self.fix_ctr += 1
                LOG.info("Fix corrupted collection in operation\n%s" % pformat(op))
            ## IF
        ## FOR (operations)
        if dirty:
            session.save()
    ## FOR (sessions)
def testPickle(self):
    """Round-trip a randomly populated Histogram through pickle and check
    that the clone matches the original exactly."""
    h = Histogram()
    alphabet = [c for c in string.letters] + ["-"]
    # Load the histogram with 100 random 10-character keys.
    for _ in xrange(0, 100):
        key = "".join(random.choice(alphabet) for _ in xrange(0, 10))
        assert len(key) > 0
        h.put(key, delta=random.randint(1, 10))
        assert h[key] > 0
    ## FOR

    # Serialize with the highest available pickle protocol.
    import pickle
    payload = pickle.dumps(h, -1)
    assert payload

    # Deserialize and compare the clone against the original.
    clone = pickle.loads(payload)
    assert clone
    for key in h.keys():
        self.assertEquals(h[key], clone[key])
    ## FOR
    self.assertEquals(h.getSampleCount(), clone.getSampleCount())
    self.assertEquals(sorted(h.getMinCountKeys()), sorted(clone.getMinCountKeys()))
def __init__(self, collections, workload, config):
    """Set up cost-model state from the collection catalog, the workload,
    and the configuration dict."""
    assert isinstance(collections, dict)
    # LOG.setLevel(logging.DEBUG)
    self.debug = LOG.isEnabledFor(logging.DEBUG)

    self.collections = collections
    self.col_names = list(collections.iterkeys())
    self.workload = None  # the workload currently being evaluated
    self.originalWorload = workload  # points at the untouched input workload

    # Cost-component weights; each defaults to 1.0 when absent from config.
    self.weight_network = config.get('weight_network', 1.0)
    self.weight_disk = config.get('weight_disk', 1.0)
    self.weight_skew = config.get('weight_skew', 1.0)
    self.num_nodes = config.get('nodes', 1)

    # max_memory arrives in MB; keep it internally in bytes.
    self.max_memory = config['max_memory'] * 1024 * 1024
    self.skew_segments = config['skew_intervals']
    self.address_size = config['address_size'] / 4
    self.estimator = NodeEstimator(collections, self.num_nodes)
    self.window_size = config['window_size']

    # Point the working workload at the original and build the
    # collection -> session/operation cross references.
    self.restoreOriginalWorkload()

    # Total operation count of the original workload; all later
    # calculations are normalized against this number.
    self.orig_op_count = sum(len(sess["operations"]) for sess in self.originalWorload)

    ## ----------------------------------------------
    ## CACHING
    ## ----------------------------------------------
    self.cache_enable = True
    self.cache_miss_ctr = Histogram()
    self.cache_hit_ctr = Histogram()
    self.cache_handles = {}  # ColName -> CacheHandle
def __init__(self, collections, workload, config):
    """Initialize the cost model's shared state."""
    assert isinstance(collections, dict)
    self.debug = LOG.isEnabledFor(logging.DEBUG)
    self.collections = collections
    self.col_names = [name for name in collections.iterkeys()]

    # The working workload starts out unset; restoreOriginalWorkload()
    # further down points it at the original.
    self.workload = None
    self.originalWorload = workload

    # Weights for the individual cost components.
    self.weight_network = config.get("weight_network", 1.0)
    self.weight_disk = config.get("weight_disk", 1.0)
    self.weight_skew = config.get("weight_skew", 1.0)
    self.num_nodes = config.get("nodes", 1)

    self.max_memory = config["max_memory"] * 1024 * 1024  # MB -> bytes
    self.skew_segments = config["skew_intervals"]
    self.address_size = config["address_size"] / 4
    self.estimator = NodeEstimator(collections, self.num_nodes)
    self.window_size = config["window_size"]

    # Build the collection -> session/operation indexes. These do not
    # change dynamically; the cost components account for denormalization.
    self.restoreOriginalWorkload()

    # Count the operations in the original workload once, up front.
    self.orig_op_count = 0
    for session in self.originalWorload:
        self.orig_op_count += len(session["operations"])
    ## FOR

    # Caching of per-collection lookups.
    self.cache_enable = True
    self.cache_miss_ctr = Histogram()
    self.cache_hit_ctr = Histogram()
    self.cache_handles = {}  # ColName -> CacheHandle
def __init__(self):
    """Create a fresh, empty benchmark result set."""
    self.start = None   # benchmark start timestamp
    self.stop = None    # benchmark stop timestamp
    self.txn_id = 0     # next transaction id to hand out
    self.opCount = 0    # operations accumulated across completed txns
    self.completed = []              # list of (txnName, timestamp)
    self.txn_counters = Histogram()  # txnName -> number of completions
    self.txn_times = {}              # txnName -> accumulated duration (sec)
    self.running = {}                # txn id -> (txnName, startTime)
def fixInvalidCollections(self):
    """Repair operations whose collection name was marked invalid.

    Scans every Session whose operations carry INVALID_COLLECTION_MARKER and
    infers the real collection by matching the operation's referenced fields
    against the fields of all known collections. Only an unambiguous single
    best match is applied; zero or multiple matches are logged and skipped.
    """
    searchKey = {"operations.collection": constants.INVALID_COLLECTION_MARKER}
    for session in self.metadata_db.Session.find(searchKey):
        # BUGFIX: 'dirty' must be tracked per session, not per operation.
        # Previously it was reset each iteration, so only a fix on the last
        # operation triggered a save, and an empty operation list raised
        # NameError at the save check below.
        dirty = False
        for op in session["operations"]:
            if op["collection"] != constants.INVALID_COLLECTION_MARKER:
                continue
            if self.debug:
                LOG.debug("Attempting to fix corrupted Operation:\n%s" % pformat(op))

            # For each field referenced in the query, build a histogram of
            # which collections have a field with the same name
            fields = workload.getReferencedFields(op)
            h = Histogram()
            for c in self.metadata_db.Collection.find():
                for f in c["fields"]:
                    if f in fields:
                        h.put(c["name"])
                ## FOR
            ## FOR

            matches = h.getMaxCountKeys()
            if len(matches) == 0:
                LOG.warn("No matching collection was found for corrupted operation\n%s" % pformat(op))
                continue
            elif len(matches) > 1:
                LOG.warn(
                    "More than one matching collection was found for corrupted operation %s\n%s" %
                    (matches, pformat(op))
                )
                continue
            else:
                op["collection"] = matches[0]
                dirty = True
                self.fix_ctr += 1
                LOG.info("Fix corrupted collection in operation\n%s" % pformat(op))
            ## IF
        ## FOR (operations)
        if dirty:
            session.save()
    ## FOR (sessions)
def __init__(self):
    """Initialize the operation hasher with an empty signature histogram."""
    # Histogram of computed operation signatures (populated by hash()).
    self.histogram = Histogram()
    self.debug = LOG.isEnabledFor(logging.DEBUG)
    # NOTE: removed a dead trailing 'pass' statement that followed the
    # assignments in the original.
class State():
    """Cost Model State

    Shared state for the cost components: the collection catalog, the
    working/original workloads, configuration weights, the NodeEstimator,
    and per-collection lookup caches.
    """

    ## -----------------------------------------------------------------------
    ## INTERNAL CACHE STATE
    ## -----------------------------------------------------------------------

    class Cache():
        """
            Internal cache for a single collection.
            Note that this is different than the LRUBuffer cache stuff.
            These are cached look-ups that the CostModel uses for figuring
            out what operations do.
        """

        def __init__(self, col_info, num_nodes):
            # The number of pages needed to do a full scan of this collection
            # The worst case for all other operations is if we have to do
            # a full scan that requires us to evict the entire buffer
            # Hence, we multiple the max pages by two
            # self.fullscan_pages = (col_info['max_pages'] * 2)
            self.fullscan_pages = col_info['doc_count'] * 2
            assert self.fullscan_pages > 0,\
                "Zero max_pages for collection '%s'" % col_info['name']

            # Cache of Best Index Tuples
            # QueryHash -> BestIndex
            self.best_index = {}

            # Cache of Regex Operations
            # QueryHash -> Boolean
            self.op_regex = {}

            # Cache of Touched Node Ids
            # QueryId -> [NodeId]
            self.op_nodeIds = {}

            # Cache of Document Ids
            # QueryId -> Index/Collection DocumentIds
            self.collection_docIds = {}
            self.index_docIds = {}
        ## DEF

        def reset(self):
            # Clear every cached lookup table.
            self.best_index.clear()
            self.op_regex.clear()
            self.op_nodeIds.clear()
            self.collection_docIds.clear()
            self.index_docIds.clear()
            # NOTE(review): op_count / msg_count / network_reset are first
            # assigned here, not in __init__ — before the first reset() these
            # attributes do not exist on the instance; confirm intended.
            self.op_count = 0
            self.msg_count = 0
            self.network_reset = True
        ## DEF

        def __str__(self):
            # Pretty-print each attribute; dict-valued caches are shown as
            # an entry count rather than their full contents.
            ret = ""
            max_len = max(map(len, self.__dict__.iterkeys())) + 1
            f = " %-" + str(max_len) + "s %s\n"
            for k, v in self.__dict__.iteritems():
                if isinstance(v, dict):
                    v_str = "[%d entries]" % len(v)
                else:
                    v_str = str(v)
                ret += f % (k + ":", v_str)
            return ret
        ## DEF
    ## CLASS

    def __init__(self, collections, workload, config):
        assert isinstance(collections, dict)
        # LOG.setLevel(logging.DEBUG)
        self.debug = LOG.isEnabledFor(logging.DEBUG)
        self.collections = collections
        self.col_names = [col_name for col_name in collections.iterkeys()]
        self.workload = None # working workload
        self.originalWorload = workload # points to the original workload

        # Weights for the individual cost components (default 1.0 each).
        self.weight_network = config.get('weight_network', 1.0)
        self.weight_disk = config.get('weight_disk', 1.0)
        self.weight_skew = config.get('weight_skew', 1.0)
        self.num_nodes = config.get('nodes', 1)

        # Convert MB to bytes
        self.max_memory = config['max_memory'] * 1024 * 1024
        self.skew_segments = config['skew_intervals'] # Why? "- 1"
        self.address_size = config['address_size'] / 4
        self.estimator = NodeEstimator(collections, self.num_nodes)
        self.window_size = config['window_size']

        # Build indexes from collections to sessions/operations
        # Note that this won't change dynamically based on denormalization schemes
        # It's up to the cost components to figure things out based on that
        self.restoreOriginalWorkload()

        # We need to know the number of operations in the original workload
        # so that all of our calculations are based on that
        self.orig_op_count = 0
        for sess in self.originalWorload:
            self.orig_op_count += len(sess["operations"])
        ## FOR

        ## ----------------------------------------------
        ## CACHING
        ## ----------------------------------------------
        self.cache_enable = True
        self.cache_miss_ctr = Histogram()
        self.cache_hit_ctr = Histogram()

        # ColName -> CacheHandle
        self.cache_handles = {}
    ## DEF

    def init_xref(self, workload):
        '''
            Initialize the cross reference based on the current working
            workload (collection -> sessions and collection -> operations).
        '''
        self.col_sess_xref = dict([(col_name, []) for col_name in self.col_names])
        self.col_op_xref = dict([(col_name, []) for col_name in self.col_names])
        self.__buildCrossReference__(workload)
    ## DEF

    def updateWorkload(self, workload):
        # Swap in a new working workload and rebuild the cross references.
        self.workload = workload
        self.init_xref(workload)
    ## DEF

    def restoreOriginalWorkload(self):
        # Point the working workload back at the original and re-index.
        self.workload = self.originalWorload
        self.init_xref(self.workload)
    ## DEF

    def __buildCrossReference__(self, workload):
        # Populate col_op_xref with every operation that touches a known
        # collection, and col_sess_xref with each session once per collection.
        for sess in workload:
            cols = set()
            for op in sess["operations"]:
                col_name = op["collection"]
                if col_name in self.col_sess_xref:
                    self.col_op_xref[col_name].append(op)
                    cols.add(col_name)
            ## FOR (op)
            for col_name in cols:
                self.col_sess_xref[col_name].append(sess)
        ## FOR (sess)

    def invalidateCache(self, col_name):
        # Drop any cached lookups for the given collection.
        if col_name in self.cache_handles:
            if self.debug: LOG.debug("Invalidating cache for collection '%s'", col_name)
            self.cache_handles[col_name].reset()
    ## DEF

    def getCacheHandleByName(self, col_info):
        """
            Return a cache handle for the given collection name.
            This is the preferred method because it requires fewer hashes.
            A new Cache is created lazily on first request.
        """
        cache = self.cache_handles.get(col_info['name'], None)
        if cache is None:
            cache = State.Cache(col_info, self.num_nodes)
            self.cache_handles[col_info['name']] = cache
        return cache
    ## DEF

    def getCacheHandle(self, col_info):
        # Alias for getCacheHandleByName().
        return self.getCacheHandleByName(col_info)
    ## DEF

    def reset(self):
        """
            Reset all of the internal state and cache information
        """
        # Clear out caches for all collections
        self.cache_handles.clear()
        self.estimator.reset()

    ## -----------------------------------------------------------------------
    ## UTILITY CODE
    ## -----------------------------------------------------------------------

    def __getIsOpRegex__(self, cache, op):
        # Memoized check of whether an operation uses a regex predicate,
        # keyed by the operation's query_hash.
        isRegex = cache.op_regex.get(op["query_hash"], None)
        if isRegex is None:
            isRegex = workload.isOpRegex(op)
            if self.cache_enable:
                if self.debug: self.cache_miss_ctr.put("op_regex")
                cache.op_regex[op["query_hash"]] = isRegex
        elif self.debug:
            self.cache_hit_ctr.put("op_regex")
        return isRegex
    ## DEF

    def __getNodeIds__(self, cache, design, op):
        # Memoized estimate of which node ids an operation touches,
        # keyed by the operation's query_id.
        node_ids = cache.op_nodeIds.get(op['query_id'], None)
        if node_ids is None:
            try:
                node_ids = self.estimator.estimateNodes(design, op)
            except:
                # Log context before propagating — estimateNodes can raise
                # for malformed operations.
                if self.debug:
                    LOG.error("Failed to estimate touched nodes for op #%d\n%s",
                              op['query_id'], pformat(op))
                raise
            if self.cache_enable:
                if self.debug: self.cache_miss_ctr.put("op_nodeIds")
                cache.op_nodeIds[op['query_id']] = node_ids
            if self.debug:
                LOG.debug("Estimated Touched Nodes for Op #%d: %d",
                          op['query_id'], len(node_ids))
        elif self.debug:
            self.cache_hit_ctr.put("op_nodeIds")
        return node_ids
    ## DEF
## CLASS
class NodeEstimator(object):
    """Estimates which cluster nodes each workload operation will touch,
    given a design's sharding keys."""

    def __init__(self, collections, num_nodes):
        assert isinstance(collections, dict)
        # LOG.setLevel(logging.DEBUG)
        self.debug = LOG.isEnabledFor(logging.DEBUG)
        self.collections = collections
        self.num_nodes = num_nodes
        # Keep track of how many times that we accessed each node
        self.nodeCounts = Histogram()
        self.op_count = 0
    ## DEF

    def reset(self):
        """
            Reset internal counters for this estimator.
            This should be called everytime we start evaluating a new design
        """
        self.nodeCounts.clear()
        self.op_count = 0
    ## DEF

    def estimateNodes(self, design, op):
        """
            For the given operation and a design object,
            return an estimate of a list of node ids that we think that
            the query will be executed on
        """
        results = set()
        # Assume a broadcast to every node until a sharding-key predicate
        # proves otherwise.
        broadcast = True
        shardingKeys = design.getShardKeys(op['collection'])
        if self.debug:
            LOG.debug("Computing node estimate for Op #%d [sharding=%s]", \
                      op['query_id'], shardingKeys)

        # Inserts always go to a single node
        if op['type'] == constants.OP_TYPE_INSERT:
            # Get the documents that they're trying to insert and then
            # compute their hashes based on the sharding key
            # Because there is no logical replication, each document will
            # be inserted in one and only one node
            for content in workload.getOpContents(op):
                values = catalog.getFieldValues(shardingKeys, content)
                results.add(self.computeTouchedNode(values))
            ## FOR
            broadcast = False

        # Network costs of SELECT, UPDATE, DELETE queries are based off
        # of using the sharding key in the predicate
        elif len(op['predicates']) > 0:
            predicate_types = set()
            for k,v in op['predicates'].iteritems() :
                if design.inShardKeyPattern(op['collection'], k) :
                    broadcast = False
                    predicate_types.add(v)
            if self.debug:
                LOG.debug("Op #%d %s Predicates: %s [broadcast=%s / predicateTypes=%s]",\
                          op['query_id'], op['collection'], op['predicates'], broadcast, list(predicate_types))

            ## ----------------------------------------------
            ## PRED_TYPE_REGEX
            ## ----------------------------------------------
            if not broadcast and constants.PRED_TYPE_REGEX in predicate_types:
                # Any query that is using a regex on the sharding key must be broadcast to every node
                # It's not complete accurate but it's just easier that way
                broadcast = True

            ## ----------------------------------------------
            ## PRED_TYPE_RANGE
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_RANGE in predicate_types:
                # If it's a scan, then we need to first figure out what
                # node they will start the scan at, and then just approximate
                # what it will do by adding N nodes to the touched list starting
                # from that first node. We will wrap around to zero
                # NOTE(review): 'k' here is whatever key the predicate loop
                # above ended on, not necessarily the range-predicate key —
                # looks fragile when an op has multiple predicates; confirm.
                num_touched = self.guessNodes(design, op['collection'], k)
                if self.debug:
                    LOG.info("Estimating that Op #%d on '%s' touches %d nodes",\
                             op["query_id"], op["collection"], num_touched)
                for content in workload.getOpContents(op):
                    values = catalog.getFieldValues(shardingKeys, content)
                    if self.debug: LOG.debug("%s -> %s", shardingKeys, values)
                    try:
                        node_id = self.computeTouchedNode(values)
                    except:
                        if self.debug:
                            LOG.error("Unexpected error when computing touched nodes\n%s" % pformat(values))
                        raise
                    # Add num_touched consecutive node ids, wrapping at zero.
                    for i in xrange(num_touched):
                        if node_id >= self.num_nodes: node_id = 0
                        results.add(node_id)
                        node_id += 1
                    ## FOR
                ## FOR

            ## ----------------------------------------------
            ## PRED_TYPE_EQUALITY
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_EQUALITY in predicate_types:
                broadcast = False
                for content in workload.getOpContents(op):
                    values = catalog.getFieldValues(shardingKeys, content)
                    results.add(self.computeTouchedNode(values))
                ## FOR

            ## ----------------------------------------------
            ## BUSTED!
            ## ----------------------------------------------
            elif not broadcast:
                raise Exception("Unexpected predicate types '%s' for op #%d" % (list(predicate_types), op['query_id']))
        ## IF

        if broadcast:
            if self.debug: LOG.debug("Op #%d on '%s' is a broadcast query to all nodes",\
                                     op["query_id"], op["collection"])
            # NOTE: map() is used for its side effect here (Python 2 eager map).
            map(results.add, xrange(0, self.num_nodes))

        map(self.nodeCounts.put, results)
        self.op_count += 1
        return results
    ## DEF

    def computeTouchedNode(self, values):
        """
            Compute which node the given set of values will need to go
            This is just a simple (hash % N), where N is the
            number of nodes in the cluster
        """
        assert isinstance(values, tuple)
        return hash(values) % self.num_nodes
    ## DEF

    def guessNodes(self, design, colName, fieldName):
        """
            Return the number of nodes that a query accessing a collection
            using the given field will touch.
            This serves as a stand-in for the EXPLAIN function referenced in the paper
        """
        col_info = self.collections[colName]
        if not fieldName in col_info['fields']:
            raise Exception("Invalid field '%s.%s" % (colName, fieldName))
        field = col_info['fields'][fieldName]

        # TODO: How do we use the statistics to determine the selectivity of this particular
        #       attribute and thus determine the number of nodes required to answer the query?
        return int(math.ceil(field['selectivity'] * self.num_nodes))
    ## DEF

    def getOpCount(self):
        """Return the number of operations evaluated"""
        return self.op_count
## CLASS
class Results:
    """Collects benchmark timing results: per-transaction counters and
    durations, per-operation latencies, and report formatting."""

    def __init__(self, config=None):
        self.start = None
        self.stop = None
        self.txn_id = 0
        self.opCount = 0
        self.completed = [ ] # (txnName, timestamp)
        self.txn_counters = Histogram()
        self.txn_times = { }
        self.running = { }
        self.config = config

    def startBenchmark(self):
        """Mark the benchmark as having been started"""
        assert self.start == None
        LOG.debug("Starting benchmark statistics collection")
        self.start = time.time()
        return self.start

    def stopBenchmark(self):
        """Mark the benchmark as having been stopped"""
        assert self.start != None
        assert self.stop == None
        LOG.debug("Stopping benchmark statistics collection")
        self.stop = time.time()

    def startTransaction(self, txn):
        # Hand out a fresh id and remember the transaction's start time.
        self.txn_id += 1
        id = self.txn_id
        self.running[id] = (txn, time.time())
        return id

    def abortTransaction(self, id):
        """Abort a transaction and discard its times"""
        assert id in self.running
        txn_name, txn_start = self.running[id]
        del self.running[id]

    def stopTransaction(self, id, opCount, latencies=[]):
        """Record that the benchmark completed an invocation of the given transaction"""
        # NOTE(review): mutable default for 'latencies' — safe only if callers
        # never mutate the default; confirm call sites.
        assert id in self.running
        timestamp = time.time()
        txn_name, txn_start = self.running[id]
        del self.running[id]
        self.completed.append((txn_name, timestamp, latencies))
        duration = timestamp - txn_start
        total_time = self.txn_times.get(txn_name, 0)
        self.txn_times[txn_name] = total_time + duration
        # OpCount
        if opCount is not None:
            self.opCount += opCount
        else:
            LOG.debug("ithappens")
        # Txn Counter Histogram
        self.txn_counters.put(txn_name)
        assert self.txn_counters[txn_name] > 0
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("Completed %s in %f sec" % (txn_name, duration))
    ## DEF

    @staticmethod
    def show_table(title, headers, table, line_width):
        # Render a fixed-width text table; returns (output, line_width) where
        # line_width may have grown to fit the widest row.
        cols_width = [len(header) for header in headers]
        for row in table:
            row_width = 0
            for i in range(len(headers)):
                if len(row[i]) > cols_width[i]:
                    cols_width[i] = len(row[i])
                row_width += cols_width[i]
            row_width += 4 * (len(headers) - 1)
            if row_width > line_width:
                line_width = row_width
        output = ("%s\n" % ("=" * line_width))
        output += ("%s\n" % title)
        output += ("%s\n" % ("-" * line_width))
        for i in range(len(headers)):
            header = headers[i]
            output += ("%s%s" % (header, " " * (cols_width[i] - len(header))))
            if i != len(headers) - 1:
                output += " " * 4
        output += "\n"
        for row in table:
            for i in range(len(headers)):
                cell = row[i]
                output += ("%s%s" % (cell, " " * (cols_width[i] - len(cell))))
                if i != len(headers) - 1:
                    output += " " * 4
            output += "\n"
        output += ("%s\n" % ("-" * line_width))
        return output, line_width

    def show_latencies(self, line_width):
        # Build a latency percentile report (and, when configured, a table of
        # the slowest operations) from the per-transaction latency samples.
        latencies = []
        output = ""
        for txn_stats in self.completed:
            latencies.extend(txn_stats[2])
        if len(latencies) > 0:
            latencies = sorted(latencies, key=itemgetter(0))
            percents = [0.1, 0.2, 0.5, 0.8, 0.9, 0.999]
            latency_table = []
            slowest_ops = []
            for percent in percents:
                index = int(math.floor(percent * len(latencies)))
                percent_str = "%0.1f%%" % (percent * 100)
                millis_sec_str = "%0.4f" % (latencies[index][0])
                latency_table.append((percent_str, millis_sec_str))
            latency_headers = ["Queries(%)", "Latency(ms)"]
            output, line_width = \
                Results.show_table("Latency Report", latency_headers, latency_table, line_width)
            if self.config is not None and self.config["default"]["slow_ops_num"] > 0:
                num_ops = self.config["default"]["slow_ops_num"]
                slowest_ops_headers = ["#", "Latency(ms)", "Session Id", "Operation Id", "Type", "Collection", "Predicates"]
                for i in range(num_ops):
                    if i < len(latencies):
                        # Walk the sorted list from the slowest end.
                        slowest_ops.append([
                            "%d" % i,
                            "%0.4f" % (latencies[len(latencies) - i - 1][0]),
                            str(latencies[len(latencies) - i - 1][1]),
                            str(latencies[len(latencies) - i - 1][2]),
                            latencies[len(latencies) - i - 1][3],
                            latencies[len(latencies) - i - 1][4],
                            json.dumps(latencies[len(latencies) - i - 1][5])
                        ])
                slowest_ops_output, line_width = \
                    Results.show_table("Top %d Slowest Operations" % num_ops, slowest_ops_headers, slowest_ops, line_width)
                output += ("\n%s" % slowest_ops_output)
        return output

    def append(self, r):
        # Merge another Results object (e.g. from a worker) into this one.
        self.opCount += r.opCount
        for txn_name in r.txn_counters.keys():
            self.txn_counters.put(txn_name, delta=r.txn_counters[txn_name])
            orig_time = self.txn_times.get(txn_name, 0)
            self.txn_times[txn_name] = orig_time + r.txn_times[txn_name]
            #LOG.info("resOps="+str(r.opCount))
            #LOG.debug("%s [cnt=%d, time=%d]" % (txn_name, self.txn_counters[txn_name], self.txn_times[txn_name]))
        ## HACK
        if type(r.completed) == list:
            self.completed.extend(r.completed)
        if not self.start:
            self.start = r.start
        else:
            self.start = min(self.start, r.start)
        if not self.stop:
            self.stop = r.stop
        else:
            self.stop = max(self.stop, r.stop)
    ## DEF

    def __str__(self):
        return self.show()

    def show(self, load_time = None):
        if self.start == None:
            msg = "Attempting to get benchmark results before it was started"
            raise Exception(msg)
            # NOTE(review): the two lines below are unreachable after the
            # raise above — dead code left from an earlier variant.
            LOG.warn(msg)
            return "Benchmark not started"
        if self.stop == None:
            duration = time.time() - self.start
        else:
            duration = self.stop - self.start

        col_width = 18
        total_width = (col_width*4)+2
        f = "\n " + (("%-" + str(col_width) + "s")*4)
        line = "-"*total_width

        ret = u"" + "="*total_width + "\n"
        if load_time != None:
            ret += "Data Loading Time: %d seconds\n\n" % (load_time)

        ret += "Execution Results after %d seconds\n%s" % (duration, line)
        ret += f % ("", "Executed", u"Total Time (ms)", "Rate")

        total_time = duration
        total_cnt = self.txn_counters.getSampleCount()
        #total_running_time = 0

        for txn in sorted(self.txn_counters.keys()):
            txn_time = self.txn_times[txn]
            txn_cnt = "%6d - %4.1f%%" % (self.txn_counters[txn], (self.txn_counters[txn] / float(total_cnt))*100)
            # NOTE(review): the txn/s rate computed here is immediately
            # overwritten by the op/s rate below.
            rate = u"%.02f txn/s" % ((self.txn_counters[txn] / total_time))
            #total_running_time +=txn_time
            #rate = u"%.02f op/s" % ((self.txn_counters[txn] / total_time))
            rate = u"%.02f op/s" % ((self.opCount / total_time))
            ret += f % (txn, txn_cnt, str(txn_time * 1000), rate)
            #LOG.info("totalOps="+str(self.totalOps))
        # total_time += txn_time

        ret += "\n" + ("-"*total_width)
        rate = 0
        if total_time > 0:
            rate = total_cnt / float(total_time)
            # TXN RATE
            rate = total_cnt / float(total_time)
        #total_rate = "%.02f txn/s" % rate
        total_rate = "%.02f op/s" % rate
        #total_rate = str(rate)
        ret += f % ("TOTAL", str(total_cnt), str(total_time*1000), total_rate)
        return ("%s\n%s" % (ret, self.show_latencies(total_width))).encode('utf-8')
def main():
    """Entry point for the queueing simulator: parse CLI options, build the
    simulated host and request generators, run the simulation, and print the
    collected latency histograms."""
    # parser = optparse.OptionParser()
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-v', '--verbose', dest='verbose', action='count',
                        help='Increase verbosity (specify'
                             ' multiple times for more)')
    parser.add_argument('-g', '--print-hist', action='store_true', dest='hist',
                        help='Print request latency histogram', default=False)
    parser.add_argument('-c', '--cores', dest='cores', action='store',
                        help='Set the number of cores of the system', default=8)
    parser.add_argument('-n', '--network-cores', dest='network_cores', action='store',
                        help='Set the number of networking'
                             ' cores of the system', default=0)
    parser.add_argument('-s', '--seed', dest='seed', action='store',
                        help='Set the seed for request generator')
    parser.add_argument('-t', '--sim_time', dest='sim_time', action='store',
                        help='Set the simulation time', default=500000)
    parser.add_argument('--workload-conf', dest='work_conf', action='store',
                        help='Configuration file for the load generation'
                             ' functions', default="../config/work.json")
    group = parser.add_argument_group('Host Options')
    group.add_argument('--host-type', dest='host_type', action='store',
                       help=('Set the host configuration (global queue,'
                             ' local queue, shinjuku, per flow queues,'
                             ' static core allocation)'), default='global')
    group.add_argument('--deq-cost', dest='deq_cost', action='store',
                       help='Set the dequeuing cost', default=0.0)
    group.add_argument('--queue-policy', dest='queue_policy', action='store',
                       help=('Set the queue policy to be followed by the per'
                             ' flow queue, ignored in any other queue'
                             ' configuration'), default='FlowQueues')
    parser.add_argument_group(group)
    group = parser.add_argument_group('Print Options')
    group.add_argument('--print-values', dest='print_values', action='store_true',
                       help='Print all the latencies for'
                            ' each flow', default=False)
    group.add_argument('--output-file', dest='output_file', action='store',
                       help='File to print all latencies', default=None)
    opts = parser.parse_args()

    # Seeding
    if opts.seed:
        random.seed(int(opts.seed))
        np.random.seed(int(opts.seed))

    # Setup logging
    log_level = logging.WARNING
    if opts.verbose == 1:
        log_level = logging.INFO
    elif opts.verbose >= 2:
        log_level = logging.DEBUG
    logging.basicConfig(level=log_level)

    # Initialize the different components of the system
    env = simpy.Environment()

    # Parse the configuration file
    flow_config = json.loads(open(opts.work_conf).read())

    # Create a histogram per flow and a global histogram
    histograms = Histogram(len(flow_config), float(opts.cores), flow_config, opts)

    # Get the queue configuration
    host_conf = getattr(sys.modules[__name__], gen_dict[opts.host_type])
    sim_host = host_conf(env, int(opts.cores), histograms, float(opts.deq_cost),
                         flow_config, opts)

    # TODO:Update so that it's parametrizable
    # print "Warning: Need to update sim.py for parameterization and Testing"
    # First list is time slice, second list is load
    # sim_host = StaticCoreAllocationHost(env, int(opts.cores),
    #                                     float(opts.deq_cost), [0.0, 0.0],
    #                                     histograms, len(flow_config),
    #                                     [0.4, 0.4])

    multigenerator = MultipleRequestGenerator(env, sim_host)

    # Create one object per flow
    for flow in flow_config:
        params = flow
        #work_gen = getattr(sys.modules[__name__],
        #                   gen_dict[params["work_gen"]])
        # Need to generate less load when we have shinjuku because one
        # of the cores is just the dispatcher
        # NOTE(review): this decrement sits inside the per-flow loop, so with
        # multiple flows and host_type == "shinjuku" opts.cores shrinks once
        # per flow — confirm that is intended.
        if (opts.host_type == "shinjuku"):
            opts.cores = int(opts.cores) - 1
        multigenerator.add_generator(
            RequestGenerator(env, sim_host, int(opts.cores), params))
    multigenerator.begin_generation()

    # Run the simulation
    env.run(until=opts.sim_time)

    # Print results in json format
    histograms.print_info()
def hash(self, op):
    """Compute a deterministic signature for the given operation based on its keys

    The signature is derived from (collection, op type, hash of referenced
    fields, hash of update fields). Every computed signature is also recorded
    in self.histogram.

    :param op: operation document with 'type', 'collection', 'query_id' and
               'query_content' entries
    :returns: the numeric signature of the operation
    :raises Exception: for malformed queries or unknown operation types
    """
    fields = None
    updateFields = None

    # QUERY
    if op["type"] == constants.OP_TYPE_QUERY:
        # The query field has our where clause
        if not "#query" in op["query_content"][0]:
            msg = "Missing query field in query_content for operation #%d" % op["query_id"]
            if self.debug: LOG.warn(pformat(op))
            raise Exception(msg)
        fields = op["query_content"][0][constants.REPLACE_KEY_DOLLAR_PREFIX + "query"]

    # UPDATE
    elif op["type"] == constants.OP_TYPE_UPDATE:
        # The first element in the content field is the WHERE clause
        fields = op["query_content"][0]
        # We use a separate field for the updated columns so that
        updateFields = op['query_content'][1]

    # INSERT
    elif op["type"] == constants.OP_TYPE_INSERT:
        # They could be inserting more than one document here,
        # which all may have different fields...
        # So we will need to build a histogram for which keys are referenced
        # and use the onese that appear the most
        # XXX: We'll only consider keys in the first-level
        h = Histogram()
        for doc in op["query_content"]:
            assert type(doc) == dict, "Unexpected insert value:\n%s" % pformat(doc)
            for k in doc.keys():
                h.put(k)
        ## FOR
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("Insert '%s' Keys Histogram:\n%s" % (op["collection"], h))
        maxKeys = h.getMaxCountKeys()
        assert len(maxKeys) > 0, \
            "No keys were found in %d insert documents?" % len(op["query_content"])

        # Keep only the values for the most frequently referenced keys.
        fields = { }
        for doc in op["query_content"]:
            for k, v in doc.iteritems():
                if k in maxKeys:
                    fields[k] = v
            ## FOR
        ## FOR

    # DELETE
    elif op["type"] == constants.OP_TYPE_DELETE:
        # The first element in the content field is the WHERE clause
        fields = op["query_content"][0]
    # UNKNOWN!
    else:
        raise Exception("Unexpected query type: %s" % op["type"])

    # Extract the list of fields that are used
    try:
        fieldsHash = self.computeFieldsHash(fields)
    except:
        # Add context before propagating the failure.
        LOG.error("Unexpected error when processing operation %d [fields=%s]" % (op["query_id"], str(fields)))
        raise
    updateHash = self.computeFieldsHash(updateFields) if updateFields else None

    # 'hash(t)' below is the Python builtin, not this method.
    t = (op["collection"], op["type"], fieldsHash, updateHash)
    h = long(hash(t))
    LOG.debug("%s %s => HASH:%d" % (fields, t, h))
    self.histogram.put(h)
    return h
class Results:
    """Collects timing and counter statistics for a benchmark run.

    Tracks the overall start/stop timestamps, per-transaction invocation
    counts and cumulative durations, and the total operation count.
    """

    def __init__(self):
        self.start = None                  # benchmark start timestamp (time.time())
        self.stop = None                   # benchmark stop timestamp
        self.txn_id = 0                    # monotonically increasing txn handle
        self.opCount = 0                   # total operations completed
        self.completed = [ ]               # (txnName, timestamp)
        self.txn_counters = Histogram()    # txnName -> invocation count
        self.txn_times = { }               # txnName -> cumulative duration (sec)
        self.running = { }                 # txn handle -> (txnName, startTime)
    ## DEF

    def startBenchmark(self):
        """Mark the benchmark as having been started"""
        assert self.start == None
        LOG.debug("Starting benchmark statistics collection")
        self.start = time.time()
        return self.start
    ## DEF

    def stopBenchmark(self):
        """Mark the benchmark as having been stopped"""
        assert self.start != None
        assert self.stop == None
        LOG.debug("Stopping benchmark statistics collection")
        self.stop = time.time()
    ## DEF

    def startTransaction(self, txn):
        """Register a new in-flight transaction and return its handle"""
        self.txn_id += 1
        id = self.txn_id
        self.running[id] = (txn, time.time())
        return id
    ## DEF

    def abortTransaction(self, id):
        """Abort a transaction and discard its times"""
        assert id in self.running
        txn_name, txn_start = self.running[id]
        del self.running[id]
    ## DEF

    def stopTransaction(self, id, opCount):
        """Record that the benchmark completed an invocation of the given transaction"""
        assert id in self.running
        timestamp = time.time()

        txn_name, txn_start = self.running[id]
        del self.running[id]
        self.completed.append((txn_name, timestamp))

        duration = timestamp - txn_start
        total_time = self.txn_times.get(txn_name, 0)
        self.txn_times[txn_name] = total_time + duration

        # OpCount
        if opCount is not None:
            self.opCount += opCount
        else:
            LOG.debug("ithappens")

        # Txn Counter Histogram
        self.txn_counters.put(txn_name)
        assert self.txn_counters[txn_name] > 0

        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("Completed %s in %f sec" % (txn_name, duration))
    ## DEF

    def append(self, r):
        """Merge another Results instance into this one (for parallel workers)"""
        self.opCount += r.opCount
        for txn_name in r.txn_counters.keys():
            self.txn_counters.put(txn_name, delta=r.txn_counters[txn_name])
            orig_time = self.txn_times.get(txn_name, 0)
            self.txn_times[txn_name] = orig_time + r.txn_times[txn_name]
        ## FOR

        ## HACK
        if type(r.completed) == list:
            self.completed.extend(r.completed)

        # Widen our window to cover the other run's interval
        if not self.start:
            self.start = r.start
        else:
            self.start = min(self.start, r.start)
        if not self.stop:
            self.stop = r.stop
        else:
            self.stop = max(self.stop, r.stop)
    ## DEF

    def __str__(self):
        return self.show()

    def show(self, load_time=None):
        """Render a formatted summary table of the results.

        Raises an Exception if the benchmark was never started.
        """
        if self.start == None:
            msg = "Attempting to get benchmark results before it was started"
            # BUGFIX: the original raised first and then had an unreachable
            # LOG.warn + return; log before raising instead.
            LOG.warn(msg)
            raise Exception(msg)
        if self.stop == None:
            duration = time.time() - self.start
        else:
            duration = self.stop - self.start

        col_width = 18
        total_width = (col_width * 4) + 2
        f = "\n " + (("%-" + str(col_width) + "s") * 4)
        line = "-" * total_width

        ret = u"" + "=" * total_width + "\n"
        if load_time != None:
            ret += "Data Loading Time: %d seconds\n\n" % (load_time)

        ret += "Execution Results after %d seconds\n%s" % (duration, line)
        ret += f % ("", "Executed", u"Total Time (ms)", "Rate")

        total_time = duration
        total_cnt = self.txn_counters.getSampleCount()
        for txn in sorted(self.txn_counters.keys()):
            txn_time = self.txn_times[txn]
            txn_cnt = "%6d - %4.1f%%" % (
                self.txn_counters[txn],
                (self.txn_counters[txn] / float(total_cnt)) * 100)
            rate = u"%.02f txn/s" % ((self.txn_counters[txn] / total_time))
            ret += f % (txn, txn_cnt, str(txn_time * 1000), rate)
        ## FOR

        ret += "\n" + ("-" * total_width)

        # BUGFIX: the original guarded the division with `if total_time > 0`
        # but then unconditionally recomputed it, defeating the guard and
        # allowing a ZeroDivisionError. Compute it once, guarded.
        rate = 0
        if total_time > 0:
            rate = total_cnt / float(total_time)
        total_rate = "%.02f op/s" % rate

        ret += f % ("TOTAL", str(total_cnt), str(total_time * 1000), total_rate)
        return (ret.encode('utf-8'))
    ## DEF
## CLASS
"workload_percent", ] STRIP_FIELDS = [ "predicates", "query_hash", "query_time", "query_size", "query_type", "query_id", "orig_query", "resp_.*", ] STRIP_REGEXES = [ re.compile(r) for r in STRIP_FIELDS ] QUERY_COUNTS = Histogram() QUERY_COLLECTION_COUNTS = Histogram() QUERY_HASH_XREF = { } QUERY_TOP_LIMIT = 10 ## ============================================== ## DUMP SCHEMA ## ============================================== def dumpSchema(writer, collection, fields, spacer=""): cur_spacer = spacer if len(spacer) > 0: cur_spacer += " - " for f_name in sorted(fields.iterkeys(), key=lambda x: x != "_id"): row = [ ] f = fields[f_name] for key in SCHEMA_COLUMNS: if key == SCHEMA_COLUMNS[0]:
class NodeEstimator(object):
    """Estimates which cluster nodes an operation will touch under a design.

    Maintains a histogram of per-node access counts (self.nodeCounts) and a
    running count of evaluated operations (self.op_count).
    """

    def __init__(self, collections, num_nodes):
        # collections: dict of collection name -> catalog info
        # num_nodes: number of nodes in the (hypothetical) cluster
        assert isinstance(collections, dict)
        # LOG.setLevel(logging.DEBUG)
        self.debug = LOG.isEnabledFor(logging.DEBUG)
        self.collections = collections
        self.num_nodes = num_nodes

        # Keep track of how many times that we accessed each node
        self.nodeCounts = Histogram()
        self.op_count = 0
    ## DEF

    def reset(self):
        """
        Reset internal counters for this estimator.
        This should be called everytime we start evaluating a new design
        """
        self.nodeCounts.clear()
        self.op_count = 0
    ## DEF

    def estimateNodes(self, design, op):
        """
        For the given operation and a design object, return an estimate of
        a list of node ids that we think that the query will be executed on
        """
        results = set()
        # broadcast stays True unless we can prove the op targets specific nodes
        broadcast = True
        shardingKeys = design.getShardKeys(op['collection'])
        if self.debug:
            LOG.debug("Computing node estimate for Op #%d [sharding=%s]", \
                      op['query_id'], shardingKeys)

        # Inserts always go to a single node
        if op['type'] == constants.OP_TYPE_INSERT:
            # Get the documents that they're trying to insert and then
            # compute their hashes based on the sharding key
            # Because there is no logical replication, each document will
            # be inserted in one and only one node
            for content in workload.getOpContents(op):
                values = catalog.getFieldValues(shardingKeys, content)
                results.add(self.computeTouchedNode(values))
            ## FOR
            broadcast = False

        # Network costs of SELECT, UPDATE, DELETE queries are based off
        # of using the sharding key in the predicate
        elif len(op['predicates']) > 0:
            predicate_types = set()
            # NOTE(review): `k` (the last predicate field examined) deliberately
            # leaks out of this loop and is reused in the RANGE branch below.
            for k, v in op['predicates'].iteritems():
                if design.inShardKeyPattern(op['collection'], k):
                    broadcast = False
                    predicate_types.add(v)
            if self.debug:
                LOG.debug("Op #%d %s Predicates: %s [broadcast=%s / predicateTypes=%s]",\
                          op['query_id'], op['collection'], op['predicates'], broadcast, list(predicate_types))

            ## ----------------------------------------------
            ## PRED_TYPE_REGEX
            ## ----------------------------------------------
            if not broadcast and constants.PRED_TYPE_REGEX in predicate_types:
                # Any query that is using a regex on the sharding key must be broadcast to every node
                # It's not complete accurate but it's just easier that way
                broadcast = True

            ## ----------------------------------------------
            ## PRED_TYPE_RANGE
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_RANGE in predicate_types:
                # If it's a scan, then we need to first figure out what
                # node they will start the scan at, and then just approximate
                # what it will do by adding N nodes to the touched list starting
                # from that first node. We will wrap around to zero
                # NOTE(review): `k` here is whatever predicate field the loop
                # above ended on, which may not be the range field — confirm.
                num_touched = self.guessNodes(design, op['collection'], k)
                if self.debug:
                    LOG.info("Estimating that Op #%d on '%s' touches %d nodes",\
                             op["query_id"], op["collection"], num_touched)
                for content in workload.getOpContents(op):
                    values = catalog.getFieldValues(shardingKeys, content)
                    if self.debug:
                        LOG.debug("%s -> %s", shardingKeys, values)
                    try:
                        node_id = self.computeTouchedNode(values)
                    except:
                        if self.debug:
                            LOG.error("Unexpected error when computing touched nodes\n%s" % pformat(values))
                        raise
                    # Touch num_touched consecutive nodes, wrapping at num_nodes
                    for i in xrange(num_touched):
                        if node_id >= self.num_nodes:
                            node_id = 0
                        results.add(node_id)
                        node_id += 1
                    ## FOR
                ## FOR

            ## ----------------------------------------------
            ## PRED_TYPE_EQUALITY
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_EQUALITY in predicate_types:
                broadcast = False
                for content in workload.getOpContents(op):
                    values = catalog.getFieldValues(shardingKeys, content)
                    results.add(self.computeTouchedNode(values))
                ## FOR

            ## ----------------------------------------------
            ## BUSTED!
            ## ----------------------------------------------
            elif not broadcast:
                raise Exception("Unexpected predicate types '%s' for op #%d" % (list(predicate_types), op['query_id']))
        ## IF

        if broadcast:
            if self.debug:
                LOG.debug("Op #%d on '%s' is a broadcast query to all nodes",\
                          op["query_id"], op["collection"])
            map(results.add, xrange(0, self.num_nodes))

        # Record the per-node access counts for skew analysis
        map(self.nodeCounts.put, results)
        self.op_count += 1
        return results
    ## DEF

    def computeTouchedNode(self, values):
        """
        Compute which node the given set of values will need to go
        This is just a simple (hash % N), where N is the number of nodes
        in the cluster
        """
        assert isinstance(values, tuple)
        return hash(values) % self.num_nodes
    ## DEF

    def guessNodes(self, design, colName, fieldName):
        """
        Return the number of nodes that a query accessing a collection
        using the given field will touch.
        This serves as a stand-in for the EXPLAIN function referenced in the paper
        """
        col_info = self.collections[colName]
        if not fieldName in col_info['fields']:
            # NOTE(review): message is missing its closing quote after %s
            raise Exception("Invalid field '%s.%s" % (colName, fieldName))
        field = col_info['fields'][fieldName]

        # TODO: How do we use the statistics to determine the selectivity of this particular
        # attribute and thus determine the number of nodes required to answer the query?
        return int(math.ceil(field['selectivity'] * self.num_nodes))
    ## DEF

    def getOpCount(self):
        """Return the number of operations evaluated"""
        return self.op_count
    ## DEF
## CLASS
class Results:
    """Collects timing and counter statistics for a benchmark run.

    Tracks the overall start/stop timestamps, per-transaction invocation
    counts and cumulative durations, and the total operation count.
    """

    def __init__(self):
        self.start = None                  # benchmark start timestamp (time.time())
        self.stop = None                   # benchmark stop timestamp
        self.txn_id = 0                    # monotonically increasing txn handle
        self.opCount = 0                   # total operations completed
        self.completed = [ ]               # (txnName, timestamp)
        self.txn_counters = Histogram()    # txnName -> invocation count
        self.txn_times = { }               # txnName -> cumulative duration (sec)
        self.running = { }                 # txn handle -> (txnName, startTime)
    ## DEF

    def startBenchmark(self):
        """Mark the benchmark as having been started"""
        assert self.start == None
        LOG.debug("Starting benchmark statistics collection")
        self.start = time.time()
        return self.start
    ## DEF

    def stopBenchmark(self):
        """Mark the benchmark as having been stopped"""
        assert self.start != None
        assert self.stop == None
        LOG.debug("Stopping benchmark statistics collection")
        self.stop = time.time()
    ## DEF

    def startTransaction(self, txn):
        """Register a new in-flight transaction and return its handle"""
        self.txn_id += 1
        id = self.txn_id
        self.running[id] = (txn, time.time())
        return id
    ## DEF

    def abortTransaction(self, id):
        """Abort a transaction and discard its times"""
        assert id in self.running
        txn_name, txn_start = self.running[id]
        del self.running[id]
    ## DEF

    def stopTransaction(self, id, opCount):
        """Record that the benchmark completed an invocation of the given transaction"""
        assert id in self.running
        timestamp = time.time()

        txn_name, txn_start = self.running[id]
        del self.running[id]
        self.completed.append((txn_name, timestamp))

        duration = timestamp - txn_start
        total_time = self.txn_times.get(txn_name, 0)
        self.txn_times[txn_name] = total_time + duration

        # OpCount
        if opCount is not None:
            self.opCount += opCount
        else:
            LOG.debug("ithappens")

        # Txn Counter Histogram
        self.txn_counters.put(txn_name)
        assert self.txn_counters[txn_name] > 0

        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("Completed %s in %f sec" % (txn_name, duration))
    ## DEF

    def append(self, r):
        """Merge another Results instance into this one (for parallel workers)"""
        self.opCount += r.opCount
        for txn_name in r.txn_counters.keys():
            self.txn_counters.put(txn_name, delta=r.txn_counters[txn_name])
            orig_time = self.txn_times.get(txn_name, 0)
            self.txn_times[txn_name] = orig_time + r.txn_times[txn_name]
        ## FOR

        ## HACK
        if type(r.completed) == list:
            self.completed.extend(r.completed)

        # Widen our window to cover the other run's interval
        if not self.start:
            self.start = r.start
        else:
            self.start = min(self.start, r.start)
        if not self.stop:
            self.stop = r.stop
        else:
            self.stop = max(self.stop, r.stop)
    ## DEF

    def __str__(self):
        return self.show()

    def show(self, load_time=None):
        """Render a formatted summary table of the results.

        Raises an Exception if the benchmark was never started.
        """
        if self.start == None:
            msg = "Attempting to get benchmark results before it was started"
            # BUGFIX: the original raised first and then had an unreachable
            # LOG.warn + return; log before raising instead.
            LOG.warn(msg)
            raise Exception(msg)
        if self.stop == None:
            duration = time.time() - self.start
        else:
            duration = self.stop - self.start

        col_width = 18
        total_width = (col_width * 4) + 2
        f = "\n " + (("%-" + str(col_width) + "s") * 4)
        line = "-" * total_width

        ret = u"" + "=" * total_width + "\n"
        if load_time != None:
            ret += "Data Loading Time: %d seconds\n\n" % (load_time)

        ret += "Execution Results after %d seconds\n%s" % (duration, line)
        ret += f % ("", "Executed", u"Total Time (ms)", "Rate")

        total_time = duration
        total_cnt = self.txn_counters.getSampleCount()
        for txn in sorted(self.txn_counters.keys()):
            txn_time = self.txn_times[txn]
            txn_cnt = "%6d - %4.1f%%" % (
                self.txn_counters[txn],
                (self.txn_counters[txn] / float(total_cnt)) * 100)
            rate = u"%.02f txn/s" % ((self.txn_counters[txn] / total_time))
            ret += f % (txn, txn_cnt, str(txn_time * 1000), rate)
        ## FOR

        ret += "\n" + ("-" * total_width)

        # BUGFIX: the original guarded the division with `if total_time > 0`
        # but then unconditionally recomputed it, defeating the guard and
        # allowing a ZeroDivisionError. Compute it once, guarded.
        rate = 0
        if total_time > 0:
            rate = total_cnt / float(total_time)
        total_rate = "%.02f op/s" % rate

        ret += f % ("TOTAL", str(total_cnt), str(total_time * 1000), total_rate)
        return (ret.encode('utf-8'))
    ## DEF
## CLASS
"workload_percent", ] STRIP_FIELDS = [ "predicates", "query_hash", "query_time", "query_size", "query_type", "query_id", "orig_query", "resp_.*", ] STRIP_REGEXES = [re.compile(r) for r in STRIP_FIELDS] QUERY_COUNTS = Histogram() QUERY_COLLECTION_COUNTS = Histogram() QUERY_HASH_XREF = {} QUERY_TOP_LIMIT = 10 ## ============================================== ## DUMP SCHEMA ## ============================================== def dumpSchema(writer, collection, fields, spacer=""): cur_spacer = spacer if len(spacer) > 0: cur_spacer += " - " for f_name in sorted(fields.iterkeys(), key=lambda x: x != "_id"): row = [] f = fields[f_name] for key in SCHEMA_COLUMNS:
def processDataFields(self, col_info, fields, doc):
    """
    Recursively traverse a single document and extract out the field information.

    Mutates `fields` in place (adding per-field metadata: type, distinct_values,
    num_values, size_histogram, list_len) and updates col_info['data_size'] /
    col_info['interesting'] as a side effect. Also increments the instance
    counters self.total_field_ctr / self.err_field_ctr.
    """
    if self.debug:
        LOG.debug("Extracting fields for document:\n%s" % pformat(doc))

    # Check if the current doc has parent_col, but this will only apply to its fields
    parent_col = doc.get('parent_col', None)

    for k, v in doc.iteritems():
        # Skip if this is the _id field
        if constants.SKIP_MONGODB_ID_FIELD and k == '_id':
            continue
        if k == constants.FUNCTIONAL_FIELD:
            continue
        f_type = type(v)
        f_type_str = catalog.fieldTypeToString(f_type)

        if not k in fields:
            # This is only subset of what we will compute for each field
            # See catalog.Collection for more information
            if self.debug:
                LOG.debug("Creating new field entry for '%s'" % k)
            fields[k] = catalog.Collection.fieldFactory(k, f_type_str)
        else:
            # Field seen before: the last-seen type wins
            fields[k]['type'] = f_type_str
            # Sanity check
            # This won't work if the data is not uniform
            #if v != None:
                #assert fields[k]['type'] == f_type_str, \
                    #"Mismatched field types '%s' <> '%s' for '%s'" % (fields[k]['type'], f_type_str, k)

        # We will store the distinct values for each field in a set
        # that is embedded in the field. We will delete it when
        # we call computeFieldStats()
        if not 'distinct_values' in fields[k]:
            fields[k]['distinct_values'] = set()
        if not "num_values" in fields[k]:
            fields[k]['num_values'] = 0
        # Likewise, we will also store a histogram for the different sizes
        # of each field. We will use this later on to compute the weighted average
        if not 'size_histogram' in fields[k]:
            fields[k]['size_histogram'] = Histogram()
        # Maintain a histogram of list lengths
        if not 'list_len' in fields[k]:
            fields[k]['list_len'] = Histogram()
        # Fields referenced by queries are 'interesting' for design search
        if fields[k]['query_use_count'] > 0 and not k in col_info['interesting']:
            col_info['interesting'].append(k)

        ## ----------------------------------------------
        ## NESTED FIELDS
        ## ----------------------------------------------
        if isinstance(v, dict):
            # Check for a special data field
            if len(v) == 1 and v.keys()[0].startswith(constants.REPLACE_KEY_DOLLAR_PREFIX):
                # Unwrap the single '#'-prefixed operator key to get the value
                v = v[v.keys()[0]]
                # HACK to handle lists (hopefully dict as well)from nested IN clauses...
                all_values = v if isinstance(v, list) else [v]
                # NOTE: `v` is deliberately rebound by this inner loop
                for v in all_values:
                    if isinstance(v, dict):
                        v = v.values()[0]
                    fields[k]['type'] = catalog.fieldTypeToString(type(v))
                    try:
                        size = catalog.getEstimatedSize(fields[k]['type'], v)
                        self.total_field_ctr += 1
                    except:
                        if self.debug:
                            LOG.error("Failed to estimate size for field '%s' in collection '%s'\n%s", \
                                      k, col_info['name'], pformat(fields[k]))
                        self.err_field_ctr += 1
                        LOG.info("Total fields so far [%s], error fields [%s]",
                                 self.total_field_ctr, self.err_field_ctr)
                        continue
                    col_info['data_size'] += size
                    fields[k]['size_histogram'].put(size)
                    fields[k]['distinct_values'].add(v)
                    fields[k]['num_values'] += 1
                    if parent_col:
                        fields[k]['parent_col'] = parent_col
                ## FOR
            else:
                # Plain nested document: recurse into its fields
                if self.debug:
                    LOG.debug("Extracting keys in nested field for '%s'" % k)
                if not 'fields' in fields[k]:
                    fields[k]['fields'] = {}
                self.processDataFields(col_info, fields[k]['fields'], doc[k])

        ## ----------------------------------------------
        ## LIST OF VALUES
        ## Could be either scalars or dicts. If it's a dict, then we'll just
        ## store the nested field information in the 'fields' value
        ## If it's a list, then we'll use a special marker 'LIST_INNER_FIELD' to
        ## store the field information for the inner values.
        ## ----------------------------------------------
        elif isinstance(v, list):
            if self.debug:
                LOG.debug("Extracting keys in nested list for '%s'" % k)
            if not 'fields' in fields[k]:
                fields[k]['fields'] = {}
            list_len = len(doc[k])
            fields[k]['list_len'].put(list_len)
            for i in xrange(list_len):
                inner_type = type(doc[k][i])
                # More nested documents...
                if inner_type == dict:
                    if self.debug:
                        LOG.debug("Extracting keys in nested field in list position %d for '%s'" % (i, k))
                    self.processDataFields(col_info, fields[k]['fields'], doc[k][i])
                else:
                    # TODO: We probably should store a list of types here in case
                    # the list has different types of values
                    inner = fields[k]['fields'].get(constants.LIST_INNER_FIELD, {})
                    inner['type'] = catalog.fieldTypeToString(inner_type)
                    try:
                        inner_size = catalog.getEstimatedSize(inner['type'], doc[k][i])
                        self.total_field_ctr += 1
                    except:
                        if self.debug:
                            LOG.error("Failed to estimate size for list entry #%d for field '%s' in collection '%s'\n%s",\
                                      i, k, col_info['name'], pformat(fields[k]))
                        self.err_field_ctr += 1
                        LOG.info("Total fields so far [%s], error fields [%s]",
                                 self.total_field_ctr, self.err_field_ctr)
                        continue
                    fields[k]['fields'][constants.LIST_INNER_FIELD] = inner
                    fields[k]['size_histogram'].put(inner_size)
                    fields[k]['distinct_values'].add(doc[k][i])
                    fields[k]['num_values'] += 1
                    if parent_col:
                        fields[k]['parent_col'] = parent_col
            ## FOR (list)

        ## ----------------------------------------------
        ## SCALAR VALUES
        ## ----------------------------------------------
        else:
            try:
                size = catalog.getEstimatedSize(fields[k]['type'], v)
                self.total_field_ctr += 1
            except:
                LOG.error("Failed to estimate size for field %s in collection %s\n%s",\
                          k, col_info['name'], pformat(fields[k]))
                self.err_field_ctr += 1
                LOG.info("Total fields so far [%s], error fields [%s]",
                         self.total_field_ctr, self.err_field_ctr)
                continue
            col_info['data_size'] += size
            fields[k]['size_histogram'].put(size)
            fields[k]['distinct_values'].add(v)
            fields[k]['num_values'] += 1
            if parent_col:
                fields[k]['parent_col'] = parent_col
## DEF
class State:
    """Cost Model State"""

    ## -----------------------------------------------------------------------
    ## INTERNAL CACHE STATE
    ## -----------------------------------------------------------------------

    class Cache:
        """
        Internal cache for a single collection.
        Note that this is different than the LRUBuffer cache stuff.
        These are cached look-ups that the CostModel uses
        for figuring out what operations do.
        """

        def __init__(self, col_info, num_nodes):
            # The number of pages needed to do a full scan of this collection
            # The worst case for all other operations is if we have to do
            # a full scan that requires us to evict the entire buffer
            # Hence, we multiple the max pages by two
            # self.fullscan_pages = (col_info['max_pages'] * 2)
            self.fullscan_pages = col_info["doc_count"] * 2
            assert self.fullscan_pages > 0, "Zero max_pages for collection '%s'" % col_info["name"]

            # Cache of Best Index Tuples
            # QueryHash -> BestIndex
            self.best_index = {}

            # Cache of Regex Operations
            # QueryHash -> Boolean
            self.op_regex = {}

            # Cache of Touched Node Ids
            # QueryId -> [NodeId]
            self.op_nodeIds = {}

            # Cache of Document Ids
            # QueryId -> Index/Collection DocumentIds
            self.collection_docIds = {}
            self.index_docIds = {}
        ## DEF

        def reset(self):
            # Clear all cached look-ups.
            # NOTE(review): op_count / msg_count / network_reset are first set
            # here, not in __init__ — confirm nothing reads them before reset()
            self.best_index.clear()
            self.op_regex.clear()
            self.op_nodeIds.clear()
            self.collection_docIds.clear()
            self.index_docIds.clear()
            self.op_count = 0
            self.msg_count = 0
            self.network_reset = True
        ## DEF

        def __str__(self):
            # Debug dump: one line per attribute, dicts shown by entry count
            ret = ""
            max_len = max(map(len, self.__dict__.iterkeys())) + 1
            f = " %-" + str(max_len) + "s %s\n"
            for k, v in self.__dict__.iteritems():
                if isinstance(v, dict):
                    v_str = "[%d entries]" % len(v)
                else:
                    v_str = str(v)
                ret += f % (k + ":", v_str)
            return ret
        ## DEF
    ## CLASS

    def __init__(self, collections, workload, config):
        assert isinstance(collections, dict)
        # LOG.setLevel(logging.DEBUG)
        self.debug = LOG.isEnabledFor(logging.DEBUG)

        self.collections = collections
        self.col_names = [col_name for col_name in collections.iterkeys()]
        self.workload = None  # working workload
        # NOTE(review): attribute name is misspelled ('Worload') but used
        # consistently below; renaming would break external readers
        self.originalWorload = workload  # points to the original workload

        # Cost-component weights
        self.weight_network = config.get("weight_network", 1.0)
        self.weight_disk = config.get("weight_disk", 1.0)
        self.weight_skew = config.get("weight_skew", 1.0)
        self.max_num_nodes = config.get("nodes", 1)

        # Convert MB to bytes
        self.max_memory = config["max_memory"] * 1024 * 1024
        self.skew_segments = config["skew_intervals"]  # Why? "- 1"
        self.address_size = config["address_size"] / 4

        self.estimator = NodeEstimator(collections, self.max_num_nodes)
        self.window_size = config["window_size"]

        # Build indexes from collections to sessions/operations
        # Note that this won't change dynamically based on denormalization schemes
        # It's up to the cost components to figure things out based on that
        self.restoreOriginalWorkload()

        # We need to know the number of operations in the original workload
        # so that all of our calculations are based on that
        self.orig_op_count = 0
        for sess in self.originalWorload:
            self.orig_op_count += len(sess["operations"])
        ## FOR

        ## ----------------------------------------------
        ## CACHING
        ## ----------------------------------------------
        self.cache_enable = True
        self.cache_miss_ctr = Histogram()
        self.cache_hit_ctr = Histogram()

        # ColName -> CacheHandle
        self.cache_handles = {}
    ## DEF

    def init_xref(self, workload):
        """
        initialize the cross reference based on the current working workload
        """
        # ColName -> [session], ColName -> [operation]
        self.col_sess_xref = dict([(col_name, []) for col_name in self.col_names])
        self.col_op_xref = dict([(col_name, []) for col_name in self.col_names])
        self.__buildCrossReference__(workload)
    ## DEF

    def updateWorkload(self, workload):
        # Swap in a new working workload and rebuild the cross references
        self.workload = workload
        self.init_xref(workload)
    ## DEF

    def restoreOriginalWorkload(self):
        # Point the working workload back at the original and rebuild xrefs
        self.workload = self.originalWorload
        self.init_xref(self.workload)
    ## DEF

    def __buildCrossReference__(self, workload):
        # Populate col_op_xref (per-op) and col_sess_xref (per-session,
        # de-duplicated via the `cols` set)
        for sess in workload:
            cols = set()
            for op in sess["operations"]:
                col_name = op["collection"]
                if col_name in self.col_sess_xref:
                    self.col_op_xref[col_name].append(op)
                    cols.add(col_name)
            ## FOR (op)
            for col_name in cols:
                self.col_sess_xref[col_name].append(sess)
        ## FOR (sess)

    def invalidateCache(self, col_name):
        # Reset the cache handle for one collection (if one exists)
        if col_name in self.cache_handles:
            if self.debug:
                LOG.debug("Invalidating cache for collection '%s'", col_name)
            self.cache_handles[col_name].reset()
    ## DEF

    def getCacheHandleByName(self, col_info):
        """
        Return a cache handle for the given collection name.
        This is the preferrred method because it requires fewer hashes
        """
        cache = self.cache_handles.get(col_info["name"], None)
        if cache is None:
            cache = State.Cache(col_info, self.max_num_nodes)
            self.cache_handles[col_info["name"]] = cache
        return cache
    ## DEF

    def getCacheHandle(self, col_info):
        return self.getCacheHandleByName(col_info)
    ## DEF

    def reset(self):
        """
        Reset all of the internal state and cache information
        """
        # Clear out caches for all collections
        self.cache_handles.clear()
        self.estimator.reset()

    def calcNumNodes(self, design, maxCardinality):
        # Estimate the effective node count per collection: the lower the
        # shard-key cardinality relative to the max, the fewer nodes the
        # collection effectively spreads over (log2-scaled penalty).
        num_nodes = {}
        for col_name in self.collections.keys():
            num_nodes[col_name] = self.max_num_nodes
            if maxCardinality[col_name] is not None and design.hasCollection(col_name):
                cardinality = 1
                shard_keys = design.getShardKeys(col_name)
                if shard_keys is None or len(shard_keys) == 0:
                    continue
                for shard_key in shard_keys:
                    # Skip shard keys without recorded cardinality stats
                    if (not self.collections[col_name]["fields"].has_key(shard_key)) or (not self.collections[col_name]["fields"][shard_key].has_key("cardinality")):
                        continue
                    field_cardinality = self.collections[col_name]["fields"][shard_key]["cardinality"]
                    if field_cardinality > 0:
                        cardinality *= field_cardinality
                cardinality_ratio = maxCardinality[col_name] / float(cardinality)
                if cardinality_ratio == 1:
                    cardinality_ratio = 0
                elif cardinality_ratio < 2:
                    cardinality_ratio = 1
                else:
                    cardinality_ratio = int(math.ceil(math.log(cardinality_ratio, 2)))
                col_num_nodes = self.max_num_nodes - cardinality_ratio
                # Never drop below one node
                if col_num_nodes <= 0:
                    col_num_nodes = 1
                num_nodes[col_name] = col_num_nodes
        return num_nodes

    ## -----------------------------------------------------------------------
    ## UTILITY CODE
    ## -----------------------------------------------------------------------

    def __getIsOpRegex__(self, cache, op):
        # Memoized workload.isOpRegex keyed on query_hash
        isRegex = cache.op_regex.get(op["query_hash"], None)
        if isRegex is None:
            isRegex = workload.isOpRegex(op)
            if self.cache_enable:
                if self.debug:
                    self.cache_miss_ctr.put("op_regex")
                cache.op_regex[op["query_hash"]] = isRegex
        elif self.debug:
            self.cache_hit_ctr.put("op_regex")
        return isRegex
    ## DEF

    def __getNodeIds__(self, cache, design, op, num_nodes=None):
        # Memoized estimator.estimateNodes keyed on query_id
        node_ids = cache.op_nodeIds.get(op["query_id"], None)
        if node_ids is None:
            try:
                node_ids = self.estimator.estimateNodes(design, op, num_nodes)
            except:
                if self.debug:
                    LOG.error("Failed to estimate touched nodes for op #%d\n%s", op["query_id"], pformat(op))
                raise
            if self.cache_enable:
                if self.debug:
                    self.cache_miss_ctr.put("op_nodeIds")
                cache.op_nodeIds[op["query_id"]] = node_ids
            if self.debug:
                LOG.debug("Estimated Touched Nodes for Op #%d: %d", op["query_id"], len(node_ids))
        elif self.debug:
            self.cache_hit_ctr.put("op_nodeIds")
        return node_ids
class NodeEstimator(object):
    """
        Estimates which cluster nodes an operation will touch under a given design.

        Keeps a running Histogram of node accesses (self.nodeCounts) and a count of
        operations evaluated (self.op_count) across calls to estimateNodes().
    """

    def __init__(self, collections, max_num_nodes):
        # collections: dict of collection name -> catalog info (fields, cardinality,
        #              ranges, selectivity). max_num_nodes: upper bound on cluster size.
        assert isinstance(collections, dict)
#        LOG.setLevel(logging.DEBUG)
        self.debug = LOG.isEnabledFor(logging.DEBUG)
        self.collections = collections
        self.max_num_nodes = max_num_nodes
        # Keep track of how many times that we accessed each node
        self.nodeCounts = Histogram()
        self.op_count = 0
    ## DEF

    def reset(self):
        """
            Reset internal counters for this estimator.
            This should be called everytime we start evaluating a new design
        """
        self.nodeCounts.clear()
        self.op_count = 0
    ## DEF

    def colNumNodes(self, num_nodes, col_name):
        """Return the node count assigned to col_name, or max_num_nodes if no override exists."""
        if num_nodes is None or not num_nodes.has_key(col_name):
            return self.max_num_nodes
        return num_nodes[col_name]

    def estimateNodes(self, design, op, num_nodes=None):
        """
            For the given operation and a design object,
            return an estimate of a list of node ids that we think that
            the query will be executed on
        """
        results = set()
        broadcast = True
        shardingKeys = design.getShardKeys(op['collection'])
        if self.debug:
            LOG.debug("Computing node estimate for Op #%d [sharding=%s]", \
                      op['query_id'], shardingKeys)

        # If there are no sharding keys
        # All requests on this collection will be routed to the primary node
        # We assume the node 0 is the primary node
        if len(shardingKeys) == 0:
            broadcast = False
            results.add(0)

        # Inserts always go to a single node
        elif op['type'] == constants.OP_TYPE_INSERT:
            # Get the documents that they're trying to insert and then
            # compute their hashes based on the sharding key
            # Because there is no logical replication, each document will
            # be inserted in one and only one node
            for content in workload.getOpContents(op):
                values = catalog.getFieldValues(shardingKeys, content)
                results.add(self.computeTouchedNode(op['collection'], shardingKeys, values, num_nodes))
            ## FOR
            broadcast = False

        # Network costs of SELECT, UPDATE, DELETE queries are based off
        # of using the sharding key in the predicate
        elif len(op['predicates']) > 0:
            # Collect which predicate fields are part of the shard key pattern;
            # only a full match on every sharding key allows a targeted query
            predicate_fields = set()
            predicate_types = set()
            for k,v in op['predicates'].iteritems():
                if design.inShardKeyPattern(op['collection'], k):
                    predicate_fields.add(k)
                    predicate_types.add(v)
            if len(predicate_fields) == len(shardingKeys):
                broadcast = False
            if self.debug:
                LOG.debug("Op #%d %s Predicates: %s [broadcast=%s / predicateTypes=%s]",\
                          op['query_id'], op['collection'], op['predicates'], broadcast, list(predicate_types))

            ## ----------------------------------------------
            ## PRED_TYPE_REGEX
            ## ----------------------------------------------
            if not broadcast and constants.PRED_TYPE_REGEX in predicate_types:
                # Any query that is using a regex on the sharding key must be broadcast to every node
                # It's not complete accurate but it's just easier that way
                broadcast = True

            ## ----------------------------------------------
            ## PRED_TYPE_RANGE
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_RANGE in predicate_types:
                # Range predicates on the shard key are treated as broadcasts too
                broadcast = True

            ## ----------------------------------------------
            ## PRED_TYPE_EQUALITY
            ## ----------------------------------------------
            elif not broadcast and constants.PRED_TYPE_EQUALITY in predicate_types:
                # Pure equality on all sharding keys: route each document's key
                # values to its single owning node
                broadcast = False
                for content in workload.getOpContents(op):
                    values = catalog.getFieldValues(shardingKeys, content)
                    results.add(self.computeTouchedNode(op['collection'], shardingKeys, values, num_nodes))
                ## FOR

            ## ----------------------------------------------
            ## BUSTED!
            ## ----------------------------------------------
            elif not broadcast:
                raise Exception("Unexpected predicate types '%s' for op #%d" % (list(predicate_types), op['query_id']))
        ## IF

        if broadcast:
            if self.debug:
                LOG.debug("Op #%d on '%s' is a broadcast query to all nodes",\
                          op["query_id"], op["collection"])
            # Broadcast: touch every node assigned to this collection
            # (Py2 map used for its side effect on the set)
            map(results.add, xrange(0, self.colNumNodes(num_nodes, op["collection"])))

        # Fold this op's touched nodes into the running access histogram
        map(self.nodeCounts.put, results)
        self.op_count += 1
        return results
    ## DEF

    def computeTouchedNode(self, col_name, fields, values, num_nodes=None):
        """
            Map one (shard key fields, values) pair to a single node id.

            Fields are considered in descending cardinality order and only as many
            are used as needed for their combined cardinality to reach
            max_num_nodes — lower-cardinality fields add no routing information.
            Falls back to node 0 when fields/values lengths disagree.
        """
        if len(values) != len(fields):
            return 0
        fieldsTuple = []
        fieldsToCalc = []
        valuesToCalc = []
        for i in range(len(fields)):
            # (field name, value, cardinality of that field in the catalog)
            fieldsTuple.append((fields[i], values[i], self.collections[col_name]["fields"][fields[i]]["cardinality"]))
        fieldsTuple = sorted(fieldsTuple, key=lambda field: field[2], reverse=True)
        cardinality = 1
        for fieldTuple in fieldsTuple:
            cardinality *= fieldTuple[2]
            fieldsToCalc.append(fieldTuple[0])
            valuesToCalc.append(fieldTuple[1])
            # Enough combined cardinality to distinguish all nodes — stop early
            if cardinality >= self.max_num_nodes:
                break
        return self.computeTouchedNodeImpl(col_name, fieldsToCalc, valuesToCalc, num_nodes)

    def computeTouchedNodeImpl(self, col_name, fields, values, num_nodes=None):
        """
            Combine each field's range bucket into one mixed-radix index
            (base max_num_nodes), then scale it into [0, colNumNodes(...)).
        """
        index = 0
        factor = 1
        for i in range(len(fields)):
            index += (self.computeTouchedRange(col_name, fields[i], values[i], num_nodes) * factor)
            factor *= self.max_num_nodes
        # Normalize the multi-field index back to a single base-max_num_nodes digit
        # (Py2: /= with math.pow yields a float here)
        index /= math.pow(self.max_num_nodes, len(fields) - 1)
        return int(math.floor(index * self.colNumNodes(num_nodes, col_name) / float(self.max_num_nodes)))
    ## DEF

    def computeTouchedRange(self, col_name, field_name, value, num_nodes=None):
        """
            Return the bucket index (mod max_num_nodes) of value within the field's
            precomputed 'ranges' boundaries; hash the stringified value when no
            range statistics exist.
        """
        ranges = self.collections[col_name]['fields'][field_name]['ranges']
        if len(ranges) == 0:
            # No range stats: fall back to hash partitioning
            return hash(str(value)) % self.max_num_nodes
        index = 0
        while index < len(ranges):
            # Last boundary catches everything that did not match an earlier bucket
            if index == len(ranges) - 1:
                return index % self.max_num_nodes
            if self.inRange(value, ranges[index], ranges[index + 1]):
                return index % self.max_num_nodes
            index += 1
        return index % self.max_num_nodes

    def inRange(self, value, start, end):
        """
            Best-effort test for start <= value < end. List values are collapsed to a
            composite string and compared against stringified bounds. Any comparison
            failure is deliberately treated as a match (returns True).
        """
        try:
            if isinstance(value, list):
                value = "%s-%s-%s" % (value[0], value[1], value[2])
                return str(start) <= value < str(end)
            return start <= value < end
        except:
            # Conservative fallback: assume the value falls in this range
            return True

    def guessNodes(self, design, colName, fieldName, num_nodes=None):
        """
            Return the number of nodes that a query accessing a collection
            using the given field will touch. This serves as a stand-in for the
            EXPLAIN function referenced in the paper
        """
        col_info = self.collections[colName]
        if not fieldName in col_info['fields']:
            raise Exception("Invalid field '%s.%s" % (colName, fieldName))
        field = col_info['fields'][fieldName]

        # TODO: How do we use the statistics to determine the selectivity of this particular
        #       attribute and thus determine the number of nodes required to answer the query?
        return int(math.ceil(field['selectivity'] * self.colNumNodes(num_nodes, colName)))
    ## DEF

    def getOpCount(self):
        """Return the number of operations evaluated"""
        return self.op_count
## CLASS