def eval(self, t, field_names, ctx):
    """Evaluates the expression using the given tuple, the names of the fields in the tuple and the aggregate
    functions context (which holds any variables and the running result).

    :param t: Tuple to evaluate
    :param field_names: Names of the tuple fields
    :param ctx: The aggregate context
    :return: None
    """

    if self.__expression_type is AggregateExpression.SUM:
        sum_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
    elif self.__expression_type is AggregateExpression.COUNT:
        count_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
    elif self.__expression_type is AggregateExpression.AVG:
        avg_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
    else:
        # Should not happen as it has already been checked
        raise Exception(
            "Illegal expression type '{}'. Expression type must be '{}', '{}', or '{}'"
            .format(self.__expression_type,
                    AggregateExpression.SUM,
                    AggregateExpression.COUNT,
                    AggregateExpression.AVG))

    if not isinstance(ctx.result, numbers.Number):
        raise Exception(
            "Illegal aggregate val '{}' of type '{}'. Aggregate expression must evaluate to a number"
            .format(ctx.result, type(ctx.result)))
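# A minimal, hypothetical sketch of how an accumulator in the spirit of sum_fn(value, ctx) could fold each
# evaluated expression value into ctx.result. The names AggregateContextSketch and sum_fn_sketch are stand-ins
# for illustration only; the project's real sum_fn/count_fn/avg_fn and context class may be implemented differently.
class AggregateContextSketch(object):
    def __init__(self):
        self.result = 0


def sum_fn_sketch(v, ctx):
    # Fold the evaluated value into the running result held by the context
    ctx.result += v


ctx = AggregateContextSketch()
for value in [1.5, 2.5]:
    sum_fn_sketch(value, ctx)
assert ctx.result == 4.0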
def test_labelled_tuple():
    field_names = ['one', 'two']

    t1 = IndexedTuple.build_default(['A', 'B'])

    assert t1['_0'] == 'A'
    assert t1['_1'] == 'B'

    t2 = IndexedTuple.build(['A', 'B'], field_names)

    assert t2['one'] == 'A'
    assert t2['two'] == 'B'
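# A minimal stand-in illustrating the IndexedTuple interface these examples rely on (build, build_default,
# build_field_names_index, item access by field name and field_names()). This is a sketch inferred from how the
# class is used here; the project's actual IndexedTuple implementation may differ.
from collections import OrderedDict


class IndexedTupleSketch(object):

    def __init__(self, tuple_, field_names_index):
        self.tuple_ = tuple_
        self.field_names_index = field_names_index  # field name -> position

    @staticmethod
    def build_field_names_index(field_names):
        # Map each field name to its position in the tuple
        return OrderedDict((name, i) for i, name in enumerate(field_names))

    @classmethod
    def build(cls, tuple_, field_names):
        return cls(tuple_, cls.build_field_names_index(field_names))

    @classmethod
    def build_default(cls, tuple_):
        # Default field names are positional: '_0', '_1', ...
        return cls.build(tuple_, ['_' + str(i) for i in range(len(tuple_))])

    def __getitem__(self, field_name):
        return self.tuple_[self.field_names_index[field_name]]

    def field_names(self):
        return list(self.field_names_index.keys())


t = IndexedTupleSketch.build(['A', 'B'], ['one', 'two'])
assert t['one'] == 'A'
assert t.field_names() == ['one', 'two']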
def __on_receive_tuple(self, tuple_, producer_name):
    """Event handler for a received tuple

    :param tuple_: The received tuple
    :param producer_name: The name of the producer that sent the tuple
    :return: None
    """

    if self.field_names is None:
        self.field_names = tuple_
        self.send(TupleMessage(tuple_), self.consumers)
        self.producers_received[producer_name] = True
    else:
        if producer_name not in self.producers_received.keys():
            # Will be field names, skip
            self.producers_received[producer_name] = True
        else:
            it = IndexedTuple.build(tuple_, self.field_names)
            idx = int(it[self.map_field_name]) % len(self.consumers)

            self.op_metrics.rows_mapped += 1

            self.send(TupleMessage(tuple_), [self.consumers[idx]])
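# A self-contained illustration of the routing rule used above: each data tuple is sent to exactly one consumer,
# chosen by taking the map field value modulo the number of consumers. The route() helper and the sample field
# names are hypothetical; only the modulo rule itself comes from the handler above.
def route(tuple_, field_names, map_field_name, num_consumers):
    it = dict(zip(field_names, tuple_))  # stand-in for IndexedTuple.build
    return int(it[map_field_name]) % num_consumers


assert route(['7', 'x'], ['key', 'payload'], 'key', 3) == 1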
def execute_py_query(op):
    cur = Cursor(op.s3).select(op.s3key, op.s3sql)
    tuples = cur.execute()

    first_tuple = True
    for t in tuples:
        if op.is_completed():
            break

        op.op_metrics.rows_returned += 1

        if first_tuple:
            # Create and send the record field names
            it = IndexedTuple.build_default(t)
            first_tuple = False

            if op.log_enabled:
                print("{}('{}') | Sending field names: {}".format(
                    op.__class__.__name__, op.name, it.field_names()))

            op.send(TupleMessage(Tuple(it.field_names())), op.consumers)

        # if op.log_enabled:
        #     print("{}('{}') | Sending field values: {}".format(op.__class__.__name__, op.name, t))

        op.send(TupleMessage(Tuple(t)), op.consumers)

    return cur
def __build_field_names(self):
    """Creates the list of field names from the evaluated aggregates. Field names will just be _0, _1, etc.

    :return: The list of field names.
    """

    return IndexedTuple.build_default(self.__expressions).field_names()
def test_join_topk():
    """Tests a top k with a join

    :return: None
    """

    limit = 5

    query_plan = QueryPlan()

    # Query plan
    ts1 = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts1', query_plan, False))
    ts1_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_3'], 's_nationkey')], 'ts1_project', query_plan, False))
    ts2 = query_plan.add_operator(
        SQLTableScan('nation.csv', 'select * from S3Object;', False, 'ts2', query_plan, False))
    ts2_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'n_nationkey')], 'ts2_project', query_plan, False))
    j = query_plan.add_operator(
        HashJoin(JoinExpression('s_nationkey', 'n_nationkey'), 'j', query_plan, False))
    t = query_plan.add_operator(Limit(limit, 't', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts1.connect(ts1_project)
    ts2.connect(ts2_project)
    j.connect_left_producer(ts1_project)
    j.connect_right_producer(ts2_project)
    j.connect(t)
    t.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in c.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t))

    c.print_tuples()

    field_names = ['s_nationkey', 'n_nationkey']

    assert len(c.tuples()) == limit + 1
    assert c.tuples()[0] == field_names

    num_rows = 0
    for t in c.tuples():
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['s_nationkey'] == lt['n_nationkey']

    # Write the metrics
    query_plan.print_metrics()
def start(self):
    self.op_metrics.timer_start()

    it = IndexedTuple.build_default(self.col_defs)

    if self.log_enabled:
        print("{}('{}') | Sending field names: {}".format(
            self.__class__.__name__, self.name, it.field_names()))

    self.send(TupleMessage(Tuple(it.field_names())), self.consumers)

    for i in range(0, self.num_rows):
        if self.is_completed():
            break

        self.op_metrics.rows_returned += 1

        t = Tuple()
        for col_def in self.col_defs:
            col_val = col_def.generate()
            t.append(col_val)

        if self.log_enabled:
            print("{}('{}') | Sending field values: {}".format(
                self.__class__.__name__, self.name, t))

        self.send(TupleMessage(t), self.consumers)

    if not self.is_completed():
        self.complete()

    self.op_metrics.timer_stop()
def test_group_count():
    """Tests a group by query with a count aggregate

    :return: None
    """

    num_rows = 0

    query_plan = QueryPlan()

    # Query plan
    # select s_nationkey, count(s_suppkey) from supplier.csv group by s_nationkey
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts', query_plan, False))
    g = query_plan.add_operator(
        Group(
            ['_3'],
            [
                AggregateExpression(AggregateExpression.COUNT, lambda t_: t_['_0'])  # count(s_suppkey)
            ],
            'g', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    ts.connect(g)
    g.connect(c)

    # Start the query
    query_plan.execute()

    # Assert the results
    for _ in c.tuples():
        num_rows += 1
        # print("{}:{}".format(num_rows, t))

    field_names = ['_0', '_1']

    assert c.tuples()[0] == field_names
    assert len(c.tuples()) == 25 + 1

    nation_24 = filter(
        lambda t: IndexedTuple.build(t, field_names)['_0'] == '24',
        c.tuples())[0]
    assert nation_24[1] == 393

    assert num_rows == 25 + 1

    # Write the metrics
    query_plan.print_metrics()
def on_producer_completed(self, producer_name):
    """Handles the event where the producer has completed producing all the tuples it will produce. Once this
    occurs the tuples can be sent to consumers downstream.

    :param producer_name: The producer that has completed
    :return: None
    """

    if producer_name in self.producer_completions.keys():
        self.producer_completions[producer_name] = True
    else:
        raise Exception(
            "Unrecognized producer {} has completed".format(producer_name))

    is_all_producers_done = all(self.producer_completions.values())
    if not is_all_producers_done:
        return

    if not self.use_pandas:
        # Send the field names
        lt = IndexedTuple.build_default(self.group_field_names + self.aggregate_expressions)
        self.send(TupleMessage(Tuple(lt.field_names())), self.consumers)

        for group_tuple, group_aggregate_contexts in self.group_contexts.items():
            if self.is_completed():
                break

            # Convert the aggregate contexts to their results
            group_fields = list(group_tuple)
            group_aggregate_values = list(v.result for v in group_aggregate_contexts.values())

            t_ = group_fields + group_aggregate_values
            self.send(TupleMessage(Tuple(t_)), self.consumers)
    else:
        # For the group by reducer, aggregate one more time.
        if not self.is_completed() and len(self.producers) > 1:
            self.aggregate_df = self.pd_expr(self.aggregate_df)

        if not self.is_completed() and self.aggregate_df is not None:
            self.aggregate_df.reset_index(drop=True, inplace=True)

            # if self.log_enabled:
            #     with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            #         print("{}('{}') | Sending grouped field values: \n{}"
            #               .format(self.__class__.__name__, self.name, self.aggregate_df))

            # self.send(TupleMessage(Tuple(list(self.aggregate_df))), self.consumers)
            self.send(DataFrameMessage(self.aggregate_df), self.consumers)

            del self.aggregate_df

    Operator.on_producer_completed(self, producer_name)
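# A small illustration of how the non-pandas branch above flattens grouped results: the group key fields are
# concatenated with each aggregate context's result to form the output tuple. Ctx and the sample data are
# hypothetical stand-ins; the value 393 mirrors the count for nation key '24' asserted in the group-count test below.
class Ctx(object):
    def __init__(self, result):
        self.result = result


group_contexts = {('24',): {'count(s_suppkey)': Ctx(393)}}

for group_tuple, aggregate_contexts in group_contexts.items():
    t_ = list(group_tuple) + [c.result for c in aggregate_contexts.values()]
    assert t_ == ['24', 393]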
def test_aggregate_count():
    """Tests an aggregate query with a count aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select count(*) from supplier.csv
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts', query_plan, False))
    a = query_plan.add_operator(
        Aggregate(
            [
                AggregateExpression(AggregateExpression.COUNT, lambda t_: t_['_0'])  # count(s_suppkey)
            ],
            'a', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(a)
    a.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in c.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t))

    c.print_tuples()

    field_names = ['_0']

    tuples = c.tuples()

    assert tuples[0] == field_names
    assert IndexedTuple.build(tuples[1], field_names)['_0'] == 10000
    assert len(tuples) == 1 + 1

    # Write the metrics
    query_plan.print_metrics()
def on_receive_tuple(self, tuple_, _producer_name):
    if not self.field_names_index:
        self.field_names_index = IndexedTuple.build_field_names_index(tuple_)
        self.send(TupleMessage(tuple_), self.consumers)
        self.producers_received[_producer_name] = True
    else:
        if _producer_name not in self.producers_received.keys():
            # Will be field names, skip
            self.producers_received[_producer_name] = True
        else:
            if self.hashtable is None:
                self.hashtable = {}

            self.op_metrics.rows_processed += 1

            it = IndexedTuple(tuple_, self.field_names_index)
            itd = self.hashtable.setdefault(it[self.key], [])
            itd.append(tuple_)
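# A self-contained illustration of the build step above: tuples are grouped by their join key using
# dict.setdefault, so duplicate keys accumulate a list of matching rows. The field names and rows here are
# hypothetical sample data used only to show the grouping behaviour.
field_names = ['n_nationkey', 'n_name']
rows = [['0', 'ALGERIA'], ['1', 'ARGENTINA'], ['0', 'DUPLICATE']]

index = {name: i for i, name in enumerate(field_names)}
hashtable = {}
for row in rows:
    key_value = row[index['n_nationkey']]
    hashtable.setdefault(key_value, []).append(row)

assert hashtable['0'] == [['0', 'ALGERIA'], ['0', 'DUPLICATE']]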
def send_field_names(self, tuple_):
    """Sends the field names tuple

    :param tuple_: The tuple
    :return: None
    """

    # Create and send the record field names
    lt = IndexedTuple.build_default(tuple_)
    labels = Tuple(lt.field_names())

    if self.log_enabled:
        print("{}('{}') | Sending field names [{}]".format(
            self.__class__.__name__, self.name, {'field_names': labels}))

    self.send(TupleMessage(labels), self.consumers)
def test_group_sum():
    """Tests a group by query with a sum aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select s_nationkey, sum(float(s_acctbal)) from supplier.csv group by s_nationkey
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts', query_plan, False))
    g = query_plan.add_operator(
        Group(['_3'], [
            AggregateExpression(AggregateExpression.SUM, lambda t_: float(t_['_5']))
        ], 'g', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(g)
    g.connect(c)

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in c.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t_))

    field_names = ['_0', '_1']

    assert c.tuples()[0] == field_names
    assert len(c.tuples()) == 25 + 1

    nation_24 = filter(
        lambda t_: IndexedTuple.build(t_, field_names)['_0'] == '24',
        c.tuples())[0]
    assert round(nation_24[1], 2) == 1833872.56

    # Write the metrics
    query_plan.print_metrics()
def test_aggregate_sum():
    """Tests an aggregate query with a sum aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select sum(float(s_acctbal)) from supplier.csv
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts', query_plan, False))
    a = query_plan.add_operator(
        Aggregate([
            AggregateExpression(AggregateExpression.SUM, lambda t_: float(t_['_5']))
        ], 'a', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(a)
    a.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in c.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t))

    field_names = ['_0']

    assert c.tuples()[0] == field_names
    assert round(IndexedTuple.build(c.tuples()[1], field_names)['_0'], 2) == 45103548.65
    assert len(c.tuples()) == 1 + 1

    # Write the metrics
    query_plan.print_metrics()
def on_producer_completed(self, producer_name):
    """Event handler for a completed producer. Once all producers have completed, the bloom filter can be sent.

    :param producer_name: The producer that completed.
    :return: None
    """

    self.producer_completions[producer_name] = True

    if all(self.producer_completions.values()):

        # Get the SQL from the bloom use operators
        bloom_use_operators = filter(
            lambda o: isinstance(o, SQLTableScanBloomUse), self.consumers)
        bloom_use_sql_strings = map(lambda o: o.s3sql, bloom_use_operators)
        max_bloom_use_sql_strings = max(map(lambda s: len(s), bloom_use_sql_strings))

        # Build the bloom filter
        best_possible_fp_rate = SlicedSQLBloomFilter.calc_best_fp_rate(
            len(self.__tuples), max_bloom_use_sql_strings)

        if best_possible_fp_rate > self.fp_rate:
            print("{}('{}') | Bloom filter fp rate ({}) too low, "
                  "will exceed max S3 Select SQL expression length ({}). "
                  "Raising to best possible ({})".format(
                      self.__class__.__name__, self.name, self.fp_rate,
                      MAX_S3_SELECT_EXPRESSION_LEN, best_possible_fp_rate))
            fp_rate_to_use = best_possible_fp_rate
        else:
            fp_rate_to_use = self.fp_rate

        bloom_filter = self.build_bloom_filter(len(self.__tuples), fp_rate_to_use)

        for t in self.__tuples:
            lt = IndexedTuple.build(t, self.__field_names)
            bloom_filter.add(int(lt[self.bloom_field_name]))

        del self.__tuples

        # Send the bloom filter
        self.__send_bloom_filter(bloom_filter)

    Operator.on_producer_completed(self, producer_name)
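# A tiny illustration of the fallback decision above, isolated from the bloom filter machinery: if the best false
# positive rate that still fits within the S3 Select SQL length limit is larger than the requested rate, the best
# achievable rate is used instead. choose_fp_rate and the sample values are hypothetical.
def choose_fp_rate(requested, best_possible):
    return best_possible if best_possible > requested else requested


assert choose_fp_rate(0.01, 0.05) == 0.05  # requested rate would produce too long an expression
assert choose_fp_rate(0.1, 0.05) == 0.1    # requested rate already fits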
def on_receive_tuple(self, tuple_, producer_name):

    # Check the producers are connected
    if not self.build_producers:
        raise Exception("Left producers are not connected")

    if not self.tuple_producers:
        raise Exception("Right producer is not connected")

    # Check which producer sent the tuple
    if producer_name in self.build_producers:
        if self.build_field_names is None:
            if self.join_expr.l_field in tuple_:
                self.build_field_names = tuple_
                self.field_names_index = IndexedTuple.build_field_names_index(tuple_)
            else:
                raise Exception(
                    "Join Operator '{}' received invalid left field names tuple {}. "
                    "Tuple must contain join left field name '{}'.".format(
                        self.name, tuple_, self.join_expr.l_field))
    elif producer_name in self.tuple_producers:
        if self.tuple_field_names is None:
            if self.join_expr.r_field in tuple_:
                self.tuple_field_names = tuple_
            else:
                raise Exception(
                    "Join Operator '{}' received invalid right field names tuple {}. "
                    "Tuple must contain join right field name '{}'.".format(
                        self.name, tuple_, self.join_expr.r_field))
        else:
            self.op_metrics.tuple_rows_processed += 1
            self.tuples.append(tuple_)
    else:
        raise Exception(
            "Join Operator '{}' received invalid tuple {} from producer '{}'. "
            "Tuple must be sent from connected left producer '{}' or right producer '{}'."
            .format(self.name, tuple_, producer_name,
                    self.build_producers, self.tuple_producers))
def start(self):
    self.op_metrics.timer_start()

    if self.parts == 1:
        self.records = []
        self.worker_metrics = {}
        self.download_part(0, self.records, self.worker_metrics)
    else:
        processes = []
        for part in range(self.parts):
            p = Process(target=self.download_part,
                        args=(part, self.records, self.worker_metrics))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

    print("All parts finished with {} records".format(len(self.records)))

    first_tuple = True
    for msg in self.records:
        if first_tuple:
            # Create and send the record field names
            it = IndexedTuple.build_default(msg.tuple_)
            first_tuple = False

            if self.log_enabled:
                print("{}('{}') | Sending field names: {}"
                      .format(self.__class__.__name__, self.name, it.field_names()))

            self.send(TupleMessage(Tuple(it.field_names())), self.consumers)

        self.send(msg, self.consumers)

    self.complete()
    self.op_metrics.timer_stop()

    self.print_stats(to_file=self.s3key + '.' + str(self.parts) + '.stats.txt')

    self.records[:] = []
def __on_receive_tuple(self, tuple_, producer_name):
    """Event handler to handle receipt of a tuple

    :param tuple_: The tuple
    :param producer_name: The name of the producer that sent the tuple
    :return: None
    """

    assert (len(tuple_) > 0)

    if not self.field_names_index:
        self.field_names_index = IndexedTuple.build_field_names_index(tuple_)
        self.producers_received[producer_name] = True
        self.__send_field_names(tuple_)
    else:
        if producer_name not in self.producers_received.keys():
            # This will be the field names tuple, skip it
            self.producers_received[producer_name] = True
        else:
            if self.__evaluate_filter(tuple_):
                self.op_metrics.rows_filtered += 1
                self.__send_field_values(tuple_)
def test_aggregate():
    """Executes a select with an aggregate.

    :return: None
    """

    num_rows = 0

    cur = Cursor(boto3.client('s3')) \
        .select('region.csv', 'select count(*) from S3Object')

    try:
        rows = cur.execute()
        for r in rows:
            num_rows += 1
            lt = IndexedTuple.build_default(r)
            assert lt['_0'] == '5'
            # print("{}:{}".format(num_rows, r))

        assert num_rows == 1
    finally:
        cur.close()
def test_where_predicate():
    """Executes a select with a where clause on one of the attributes.

    :return: None
    """

    num_rows = 0

    cur = Cursor(boto3.client('s3')) \
        .select('region.csv', 'select * from S3Object where r_name = \'AMERICA\';')

    try:
        rows = cur.execute()
        for r in rows:
            num_rows += 1
            lt = IndexedTuple.build_default(r)
            assert lt['_1'] == 'AMERICA'
            # print("{}:{}".format(num_rows, r))

        assert num_rows == 1
    finally:
        cur.close()
def eval(self, tuple_, field_names_index):
    """Evaluates the predicate using the given tuple

    :param tuple_: The tuple to evaluate the expression against
    :param field_names_index: An index mapping the field names to their positions in the tuple
    :return: True or False
    """

    if self.expr:
        it = IndexedTuple(tuple_, field_names_index)
        v = self.expr(it)

        if type(v) is not bool and type(v) is not numpy.bool_:
            raise Exception(
                "Illegal return type '{}'. "
                "Predicate expression must evaluate to {} or {}".format(
                    type(v), bool, numpy.bool_))

        return v
    else:
        raise NotImplementedError
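# A hypothetical predicate mirroring the type check above: the expression receives an indexed tuple and must
# return a bool (or numpy.bool_). The lambda, field name '_5' and sample value are illustrative only; they echo
# the acctbal-style fields used elsewhere in these examples.
predicate = lambda it: float(it['_5']) > 1000.0

it = {'_5': '1833.25'}  # stand-in for an IndexedTuple
result = predicate(it)

assert isinstance(result, bool)
assert result is True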
def on_receive_tuple(self, tuple_, producer_name):
    """Handles the receipt of a tuple. The tuple is mapped to a new tuple using the given projection expressions.
    The field names are modified according to the new field names in the projection expressions.

    :param producer_name: The name of the producer that sent the tuple
    :param tuple_: The received tuple
    :return: None
    """

    assert (len(tuple_) > 0)

    if not self.field_names_index:

        self.field_names_index = IndexedTuple.build_field_names_index(tuple_)

        # Map the old field names to the new
        projected_field_names = []
        for e in self.project_exprs:
            fn = e.new_field_name
            projected_field_names.append(fn)

        if self.log_enabled:
            print("{}('{}') | Sending projected field names: from: {} to: {}"
                  .format(self.__class__.__name__, self.name, tuple_, projected_field_names))

        self.producers_received[producer_name] = True

        assert (len(projected_field_names) == len(self.project_exprs))

        self.send(TupleMessage(Tuple(projected_field_names)), self.consumers)

    else:

        assert (len(tuple_) == len(self.field_names_index))

        if producer_name not in self.producers_received.keys():
            # This will be the field names tuple, skip it
            self.producers_received[producer_name] = True
        else:
            # Perform the projection using the given expressions
            it = IndexedTuple(tuple_, self.field_names_index)

            projected_field_values = []
            for e in self.project_exprs:
                fv = e.expr(it)
                projected_field_values.append(fv)

            self.op_metrics.rows_projected += 1

            if self.log_enabled:
                print("{}('{}') | Sending projected field values: from: {} to: {}"
                      .format(self.__class__.__name__, self.name, tuple_, projected_field_values))

            assert (len(projected_field_values) == len(self.project_exprs))

            self.send(TupleMessage(Tuple(projected_field_values)), self.consumers)
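# A self-contained illustration of the projection step above: each projection expression pulls a value out of the
# indexed input tuple, producing a narrower tuple under the new field names. The (name, expr) pairs and the sample
# tuple are hypothetical; the '_3' -> 's_nationkey' mapping echoes the Project operators used in the tests here.
project_exprs = [
    ('s_nationkey', lambda it: it['_3']),  # (new_field_name, expr)
]

it = {'_0': 'supp#1', '_3': '24'}  # stand-in for an IndexedTuple

projected_field_names = [name for name, _ in project_exprs]
projected_field_values = [expr(it) for _, expr in project_exprs]

assert projected_field_names == ['s_nationkey']
assert projected_field_values == ['24']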
def test_sort_topk():
    """Executes a sorted top k query, which must use the top operator as record sorting can't be pushed into s3.
    The results are collated.

    :return: None
    """

    limit = 5

    query_plan = QueryPlan()

    # Query plan
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts', query_plan, False))
    s = query_plan.add_operator(
        Sort([SortExpression('_5', float, 'ASC')], 's', query_plan, False))
    t = query_plan.add_operator(Limit(limit, 't', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(s)
    s.connect(t)
    t.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in c.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t))

    field_names = ['_0', '_1', '_2', '_3', '_4', '_5', '_6']

    assert len(c.tuples()) == limit + 1
    assert c.tuples()[0] == field_names

    prev = None
    num_rows = 0
    for t in c.tuples():
        num_rows += 1
        # print("{}:{}".format(num_rows, t))
        if num_rows > 1:
            if prev is None:
                prev = t
            else:
                lt = IndexedTuple.build(t, field_names)
                prev_lt = IndexedTuple.build(prev, field_names)
                assert float(lt['_5']) > float(prev_lt['_5'])
                prev = t

    # Write the metrics
    query_plan.print_metrics()
def run(parallel, buffer_size):
    """Tests a query with sharded operators and separate build and probe phases for the join

    :return: None
    """

    query_plan = QueryPlan(is_async=parallel, buffer_size=buffer_size)

    # Query plan
    parts = 2

    collate = query_plan.add_operator(Collate('collate', query_plan, False))

    merge = query_plan.add_operator(Merge('merge', query_plan, False))

    hash_join_build_ops = []
    hash_join_probe_ops = []

    for p in range(1, parts + 1):

        r_region_key_lower = math.ceil((5.0 / float(parts)) * (p - 1))
        r_region_key_upper = math.ceil((5.0 / float(parts)) * p)

        region_scan = query_plan.add_operator(
            SQLTableScan('region.csv',
                         'select * '
                         'from S3Object '
                         'where cast(r_regionkey as int) >= {} and cast(r_regionkey as int) < {};'
                         .format(r_region_key_lower, r_region_key_upper),
                         False, 'region_scan' + '_' + str(p), query_plan, False))

        region_project = query_plan.add_operator(
            Project([
                ProjectExpression(lambda t_: t_['_0'], 'r_regionkey'),
                ProjectExpression(lambda t_: t_['_1'], 'r_name')
            ], 'region_project' + '_' + str(p), query_plan, False))

        n_nation_key_lower = math.ceil((25.0 / float(parts)) * (p - 1))
        n_nation_key_upper = math.ceil((25.0 / float(parts)) * p)

        nation_scan = query_plan.add_operator(
            SQLTableScan('nation.csv',
                         'select * from S3Object '
                         'where cast(n_nationkey as int) >= {} and cast(n_nationkey as int) < {};'
                         .format(n_nation_key_lower, n_nation_key_upper),
                         False, 'nation_scan' + '_' + str(p), query_plan, False))

        nation_project = query_plan.add_operator(
            Project([
                ProjectExpression(lambda t_: t_['_0'], 'n_nationkey'),
                ProjectExpression(lambda t_: t_['_1'], 'n_name'),
                ProjectExpression(lambda t_: t_['_2'], 'n_regionkey')
            ], 'nation_project' + '_' + str(p), query_plan, False))

        region_hash_join_build = query_plan.add_operator(
            HashJoinBuild('r_regionkey',
                          'region_hash_join_build' + '_' + str(p), query_plan, False))
        hash_join_build_ops.append(region_hash_join_build)

        region_nation_join_probe = query_plan.add_operator(
            HashJoinProbe(JoinExpression('r_regionkey', 'n_regionkey'),
                          'region_nation_join_probe' + '_' + str(p), query_plan, False))
        hash_join_probe_ops.append(region_nation_join_probe)

        region_scan.connect(region_project)
        nation_scan.connect(nation_project)
        region_project.connect(region_hash_join_build)
        region_nation_join_probe.connect_tuple_producer(nation_project)
        region_nation_join_probe.connect(merge)

    for probe_op in hash_join_probe_ops:
        for build_op in hash_join_build_ops:
            probe_op.connect_build_producer(build_op)

    merge.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    tuples = collate.tuples()

    collate.print_tuples(tuples)

    # Write the metrics
    query_plan.print_metrics()

    # Shut everything down
    query_plan.stop()

    field_names = ['r_regionkey', 'r_name', 'n_nationkey', 'n_name', 'n_regionkey']

    assert len(tuples) == 25 + 1
    assert tuples[0] == field_names

    num_rows = 0
    for t in tuples:
        num_rows += 1
        # Assert that each region tuple has been joined with a nation tuple that has the same region key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['r_regionkey'] == lt['n_regionkey']
def set_field_names(self, field_names):
    self.__field_name_index = IndexedTuple.build_field_names_index(field_names)
def test_join_baseline_pandas():
    """Tests a join

    :return: None
    """

    query_plan = QueryPlan(is_async=True, buffer_size=0)

    # Query plan
    supplier_scan = query_plan.add_operator(
        SQLTableScan('region.csv', 'select * from S3Object;', True, False, False,
                     'supplier_scan', query_plan, True))

    def supplier_project_fn(df):
        df = df.filter(['_0'], axis='columns')
        df = df.rename(columns={'_0': 'r_regionkey'})
        return df

    supplier_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'r_regionkey')],
                'supplier_project', query_plan, True, supplier_project_fn))

    nation_scan = query_plan.add_operator(
        SQLTableScan('nation.csv', 'select * from S3Object;', True, False, False,
                     'nation_scan', query_plan, True))

    def nation_project_fn(df):
        df = df.filter(['_2'], axis='columns')
        df = df.rename(columns={'_2': 'n_regionkey'})
        return df

    nation_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_2'], 'n_regionkey')],
                'nation_project', query_plan, True, nation_project_fn))

    supplier_nation_join_build = query_plan.add_operator(
        HashJoinBuild('n_regionkey', 'supplier_nation_join_build', query_plan, True))

    supplier_nation_join_probe = query_plan.add_operator(
        HashJoinProbe(JoinExpression('n_regionkey', 'r_regionkey'),
                      'supplier_nation_join_probe', query_plan, True))

    collate = query_plan.add_operator(Collate('collate', query_plan, True))

    supplier_scan.connect(supplier_project)
    nation_scan.connect(nation_project)
    nation_project.connect(supplier_nation_join_build)
    supplier_nation_join_probe.connect_build_producer(supplier_nation_join_build)
    supplier_nation_join_probe.connect_tuple_producer(supplier_project)
    supplier_nation_join_probe.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    tuples = collate.tuples()

    collate.print_tuples(tuples)

    # Write the metrics
    query_plan.print_metrics()

    # Shut everything down
    query_plan.stop()

    field_names = ['n_regionkey', 'r_regionkey']

    assert len(tuples) == 25 + 1
    assert tuples[0] == field_names

    num_rows = 0
    for t in tuples:
        num_rows += 1
        # Assert that each nation tuple has been joined with a region tuple that has the same region key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['n_regionkey'] == lt['r_regionkey']
def join_field_values(self):
    """Performs the join on data tuples using a hash join algorithm: the smaller relation is hashed on its join
    field and the larger relation is scanned against it. The joined tuples are each sent. Allows for the loop to
    be broken if the operator completes while executing.

    :return: None
    """

    # Determine which direction the hash join should run
    # The larger relation should remain as a list and the smaller relation should be hashed. If either of the
    # relations are empty then just return
    if len(self.l_tuples) == 0 or len(self.r_tuples) == 0:
        return
    elif len(self.l_tuples) > len(self.r_tuples):
        l_to_r = True
        # r_to_l = not l_to_r
    else:
        l_to_r = False
        # r_to_l = not l_to_r

    if l_to_r:
        outer_tuples_list = self.l_tuples
        inner_tuples_list = self.r_tuples
        inner_tuple_field_name = self.join_expr.r_field
        inner_tuple_field_names = self.r_field_names
        outer_tuple_field_index = self.l_field_names.index(self.join_expr.l_field)
    else:
        outer_tuples_list = self.r_tuples
        inner_tuples_list = self.l_tuples
        inner_tuple_field_name = self.join_expr.l_field
        inner_tuple_field_names = self.l_field_names
        outer_tuple_field_index = self.r_field_names.index(self.join_expr.r_field)

    # Hash the tuples from the smaller set of tuples
    inner_tuples_dict = {}
    for t in inner_tuples_list:
        it = IndexedTuple.build(t, inner_tuple_field_names)
        itd = inner_tuples_dict.setdefault(it[inner_tuple_field_name], [])
        itd.append(t)

    for outer_tuple in outer_tuples_list:

        if self.is_completed():
            break

        outer_tuple_field_value = outer_tuple[outer_tuple_field_index]
        inner_tuples = inner_tuples_dict.get(outer_tuple_field_value, None)

        # if self.log_enabled:
        #     print("{}('{}') | Joining Outer: {} Inner: {}".format(
        #         self.__class__.__name__, self.name, outer_tuple, inner_tuples))

        if inner_tuples is not None:
            for inner_tuple in inner_tuples:

                if l_to_r:
                    t = outer_tuple + inner_tuple
                else:
                    t = inner_tuple + outer_tuple

                # if self.log_enabled:
                #     print("{}('{}') | Sending field values [{}]".format(
                #         self.__class__.__name__, self.name, {'data': t}))

                self.op_metrics.rows_joined += 1
                self.send(TupleMessage(Tuple(t)), self.consumers)
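# A self-contained, hypothetical sketch of the direction choice above: hash the smaller relation on its join
# field, then probe the hash table while scanning the larger relation, concatenating matching tuples in
# left-then-right order. The hash_join helper and the sample data below are illustrative only, not the
# operator's actual implementation.
def hash_join(l_tuples, l_names, l_field, r_tuples, r_names, r_field):
    l_to_r = len(l_tuples) > len(r_tuples)
    outer, inner = (l_tuples, r_tuples) if l_to_r else (r_tuples, l_tuples)
    inner_names, inner_field = (r_names, r_field) if l_to_r else (l_names, l_field)
    outer_idx = l_names.index(l_field) if l_to_r else r_names.index(r_field)

    # Build a hash table over the smaller relation
    built = {}
    for t in inner:
        built.setdefault(t[inner_names.index(inner_field)], []).append(t)

    # Probe with the larger relation
    joined = []
    for o in outer:
        for i in built.get(o[outer_idx], []):
            joined.append(o + i if l_to_r else i + o)
    return joined


left = [['0', 'a'], ['1', 'b'], ['1', 'c']]
right = [['1', 'x'], ['2', 'y']]
assert hash_join(left, ['k', 'lv'], 'k', right, ['k', 'rv'], 'k') == \
    [['1', 'b', '1', 'x'], ['1', 'c', '1', 'x']]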
def test_r_to_l_join():
    """Tests a join

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    supplier_scan = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'supplier_scan', query_plan, False))

    supplier_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_3'], 's_nationkey')],
                'supplier_project', query_plan, False))

    nation_scan = query_plan.add_operator(
        SQLTableScan('nation.csv', 'select * from S3Object;', False, 'nation_scan', query_plan, False))

    nation_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'n_nationkey')],
                'nation_project', query_plan, False))

    supplier_nation_join = query_plan.add_operator(
        HashJoin(JoinExpression('n_nationkey', 's_nationkey'),
                 'supplier_nation_join', query_plan, False))

    collate = query_plan.add_operator(Collate('collate', query_plan, False))

    supplier_scan.connect(supplier_project)
    nation_scan.connect(nation_project)
    supplier_nation_join.connect_left_producer(nation_project)
    supplier_nation_join.connect_right_producer(supplier_project)
    supplier_nation_join.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    # num_rows = 0
    # for t in collate.tuples():
    #     num_rows += 1
    #     print("{}:{}".format(num_rows, t))

    # collate.print_tuples()

    field_names = ['n_nationkey', 's_nationkey']

    assert len(collate.tuples()) == 10000 + 1
    assert collate.tuples()[0] == field_names

    num_rows = 0
    for t in collate.tuples():
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['s_nationkey'] == lt['n_nationkey']

    # Write the metrics
    query_plan.print_metrics()