Example #1
    def eval(self, t, field_names, ctx):
        """Evaluates the expression using the given tuple, the names of the fields in the tuple and the aggregate
        functions context (which holds any variables and the running result).

        :param t: Tuple to evaluate
        :param field_names: Names of the tuple fields
        :param ctx: The aggregate context
        :return: None
        """

        if self.__expression_type is AggregateExpression.SUM:
            sum_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
        elif self.__expression_type is AggregateExpression.COUNT:
            count_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
        elif self.__expression_type is AggregateExpression.AVG:
            avg_fn(self.__expr(IndexedTuple.build(t, field_names)), ctx)
        else:
            # Should not happen as it's already been checked
            raise Exception(
                "Illegal expression type '{}'. Expression type must be '{}', '{}', or '{}'"
                .format(self.__expression_type, AggregateExpression.SUM,
                        AggregateExpression.COUNT, AggregateExpression.AVG))

        if not isinstance(ctx.result, numbers.Number):
            raise Exception(
                "Illegal aggregate val '{}' of type '{}'. Aggregate expression must evaluate to number"
                .format(ctx.result, type(ctx.result)))
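
The sum_fn, count_fn and avg_fn helpers fold each evaluated value into the context's running result. A rough sketch of that calling convention, with a hypothetical stand-in context (the project's real aggregate context class may differ):

class AggregateContext(object):
    """Hypothetical stand-in: holds the running result and any scratch variables."""

    def __init__(self):
        self.result = 0
        self.vars_ = {}  # scratch space, e.g. a running count for AVG


def sum_fn(val, ctx):
    # Fold the evaluated expression value into the running sum
    ctx.result += val


def count_fn(_val, ctx):
    # The value itself is ignored; only its presence is counted
    ctx.result += 1


ctx = AggregateContext()
for v in [1.5, 2.5]:
    sum_fn(v, ctx)
assert ctx.result == 4.0
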
Example #2
def test_indexed_tuple():
    field_names = ['one', 'two']

    t1 = IndexedTuple.build_default(['A', 'B'])

    assert t1['_0'] == 'A'
    assert t1['_1'] == 'B'

    t2 = IndexedTuple.build(['A', 'B'], field_names)

    assert t2['one'] == 'A'
    assert t2['two'] == 'B'
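
These tests pin down the IndexedTuple contract: build_default names fields positionally (_0, _1, ...), build uses the supplied names, and items are looked up by field name. A minimal, self-contained sketch consistent with that behaviour (the project's actual class is more elaborate):

class IndexedTuple(object):
    """Minimal sketch: a tuple plus a field-name -> position index."""

    def __init__(self, tuple_, field_names_index):
        self.tuple_ = tuple_
        self.field_names_index = field_names_index

    @staticmethod
    def build_field_names_index(field_names):
        return {n: i for i, n in enumerate(field_names)}

    @classmethod
    def build(cls, tuple_, field_names):
        return cls(tuple_, cls.build_field_names_index(field_names))

    @classmethod
    def build_default(cls, tuple_):
        # Default field names are positional: _0, _1, ...
        return cls.build(tuple_, ['_' + str(i) for i in range(len(tuple_))])

    def field_names(self):
        return list(self.field_names_index.keys())

    def __getitem__(self, field_name):
        return self.tuple_[self.field_names_index[field_name]]


assert IndexedTuple.build(['A', 'B'], ['one', 'two'])['two'] == 'B'
assert IndexedTuple.build_default(['A', 'B']).field_names() == ['_0', '_1']
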
Example #3
    def __on_receive_tuple(self, tuple_, producer_name):
        """Event handler for a received tuple

        :param tuple_: The received tuple
        :param producer_name: The name of the producer that sent the tuple
        :return: None
        """

        if self.field_names is None:
            self.field_names = tuple_

            self.send(TupleMessage(tuple_), self.consumers)
            self.producers_received[producer_name] = True
        else:
            if producer_name not in self.producers_received.keys():
                # Will be field names, skip
                self.producers_received[producer_name] = True
            else:
                it = IndexedTuple.build(tuple_, self.field_names)

                idx = int(it[self.map_field_name]) % len(self.consumers)

                self.op_metrics.rows_mapped += 1

                self.send(TupleMessage(tuple_), [self.consumers[idx]])
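
The routing rule in the else branch, reduced to a standalone line: each data tuple goes to exactly one consumer, chosen by taking the map field's value modulo the number of consumers. The consumer list here is made up for illustration:

consumers = ['consumer_0', 'consumer_1', 'consumer_2']  # hypothetical consumers
nation_key = '7'  # value of the map field for one tuple

idx = int(nation_key) % len(consumers)  # 7 % 3 == 1
assert consumers[idx] == 'consumer_1'
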
Example #4
    def execute_py_query(op):
        cur = Cursor(op.s3).select(op.s3key, op.s3sql)
        tuples = cur.execute()
        first_tuple = True
        for t in tuples:

            if op.is_completed():
                break

            op.op_metrics.rows_returned += 1

            if first_tuple:
                # Create and send the record field names
                it = IndexedTuple.build_default(t)
                first_tuple = False

                if op.log_enabled:
                    print("{}('{}') | Sending field names: {}".format(
                        op.__class__.__name__, op.name, it.field_names()))

                op.send(TupleMessage(Tuple(it.field_names())), op.consumers)

            op.send(TupleMessage(Tuple(t)), op.consumers)
        return cur
Example #5
    def __build_field_names(self):
        """Creates the list of field names from the evaluated aggregates. Field names will just be _0, _1, etc.

        :return: The list of field names.
        """

        return IndexedTuple.build_default(self.__expressions).field_names()
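
Given the positional naming shown in Example #2, two aggregate expressions would produce the header ['_0', '_1']. A quick illustration (the expression list is a placeholder):

expressions = ['count_expr', 'sum_expr']  # placeholders for AggregateExpression objects
field_names = ['_' + str(i) for i in range(len(expressions))]
assert field_names == ['_0', '_1']
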
Example #6
def test_join_topk():
    """Tests a top k with a join

    :return: None
    """

    limit = 5

    query_plan = QueryPlan()

    # Query plan
    ts1 = query_plan.add_operator(SQLTableScan('supplier.csv',
                                               'select * from S3Object;', False, 'ts1', query_plan, False))
    ts1_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_3'], 's_nationkey')], 'ts1_project', query_plan, False))
    ts2 = query_plan.add_operator(SQLTableScan('nation.csv',
                                               'select * from S3Object;', False, 'ts2', query_plan, False))
    ts2_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'n_nationkey')], 'ts2_project', query_plan, False))
    j = query_plan.add_operator(HashJoin(JoinExpression('s_nationkey', 'n_nationkey'), 'j', query_plan, False))
    t = query_plan.add_operator(Limit(limit, 't', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts1.connect(ts1_project)
    ts2.connect(ts2_project)
    j.connect_left_producer(ts1_project)
    j.connect_right_producer(ts2_project)
    j.connect(t)
    t.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    c.print_tuples()

    field_names = ['s_nationkey', 'n_nationkey']

    assert len(c.tuples()) == limit + 1

    assert c.tuples()[0] == field_names

    num_rows = 0
    for t in c.tuples():
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['s_nationkey'] == lt['n_nationkey']

    # Write the metrics
    query_plan.print_metrics()
Example #7
    def start(self):

        self.op_metrics.timer_start()

        it = IndexedTuple.build_default(self.col_defs)

        if self.log_enabled:
            print("{}('{}') | Sending field names: {}".format(
                self.__class__.__name__, self.name, it.field_names()))

        self.send(TupleMessage(Tuple(it.field_names())), self.consumers)

        for i in range(0, self.num_rows):

            if self.is_completed():
                break

            self.op_metrics.rows_returned += 1

            t = Tuple()
            for col_def in self.col_defs:
                col_val = col_def.generate()
                t.append(col_val)

            if self.log_enabled:
                print("{}('{}') | Sending field values: {}".format(
                    self.__class__.__name__, self.name, t))

            self.send(TupleMessage(t), self.consumers)

        if not self.is_completed():
            self.complete()

        self.op_metrics.timer_stop()
Example #8
def test_group_count():
    """Tests a group by query with a count aggregate

    :return: None
    """

    num_rows = 0

    query_plan = QueryPlan()

    # Query plan
    # select s_nationkey, count(s_suppkey) from supplier.csv group by s_nationkey
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts',
                     query_plan, False))

    g = query_plan.add_operator(
        Group(
            ['_3'],
            [
                AggregateExpression(AggregateExpression.COUNT,
                                    lambda t_: t_['_0'])
                # count(s_suppkey)
            ],
            'g',
            query_plan,
            False))

    c = query_plan.add_operator(Collate('c', query_plan, False))

    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"),
                           gen_test_id())

    ts.connect(g)
    g.connect(c)

    # Start the query
    query_plan.execute()

    # Assert the results
    for _ in c.tuples():
        num_rows += 1
        # print("{}:{}".format(num_rows, t))

    field_names = ['_0', '_1']

    assert c.tuples()[0] == field_names

    assert len(c.tuples()) == 25 + 1

    nation_24 = list(filter(
        lambda t: IndexedTuple.build(t, field_names)['_0'] == '24',
        c.tuples()))[0]
    assert nation_24[1] == 393
    assert num_rows == 25 + 1

    # Write the metrics
    query_plan.print_metrics()
Example #9
    def on_producer_completed(self, producer_name):
        """Handles the event where the producer has completed producing all the tuples it will produce. Once this
        occurs the tuples can be sent to consumers downstream.

        :param producer_name: The producer that has completed
        :return: None
        """
        if producer_name in self.producer_completions.keys():
            self.producer_completions[producer_name] = True
        else:
            raise Exception(
                "Unrecognized producer {} has completed".format(producer_name))

        is_all_producers_done = all(self.producer_completions.values())
        if not is_all_producers_done:
            return

        if not self.use_pandas:
            # Send the field names
            lt = IndexedTuple.build_default(self.group_field_names +
                                            self.aggregate_expressions)
            self.send(TupleMessage(Tuple(lt.field_names())), self.consumers)

            for group_tuple, group_aggregate_contexts in \
                    self.group_contexts.items():

                if self.is_completed():
                    break

                # Convert the aggregate contexts to their results
                group_fields = list(group_tuple)

                group_aggregate_values = list(
                    v.result for v in group_aggregate_contexts.values())

                t_ = group_fields + group_aggregate_values
                self.send(TupleMessage(Tuple(t_)), self.consumers)
        else:
            # For the group-by reducer, aggregate one more time.
            if not self.is_completed() and len(self.producers) > 1:
                self.aggregate_df = self.pd_expr(self.aggregate_df)

            if not self.is_completed() and self.aggregate_df is not None:
                self.aggregate_df.reset_index(drop=True, inplace=True)

                self.send(DataFrameMessage(self.aggregate_df), self.consumers)

                del self.aggregate_df

        Operator.on_producer_completed(self, producer_name)
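
In the non-pandas branch, group_contexts maps each group-key tuple to its aggregate contexts, and flattening one entry yields one output row. A sketch with made-up data (Ctx stands in for the real aggregate context):

class Ctx(object):  # stand-in for an aggregate context
    def __init__(self, result):
        self.result = result


group_contexts = {('24',): {'count_expr': Ctx(393)}}

for group_tuple, aggregate_contexts in group_contexts.items():
    row = list(group_tuple) + [c.result for c in aggregate_contexts.values()]
    assert row == ['24', 393]
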
Example #10
def test_aggregate_count():
    """Tests a group by query with a count aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select count(*) from supplier.csv
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts',
                     query_plan, False))

    a = query_plan.add_operator(
        Aggregate(
            [
                AggregateExpression(AggregateExpression.COUNT,
                                    lambda t_: t_['_0'])
                # count(s_suppkey)
            ],
            'a',
            query_plan,
            False))

    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(a)
    a.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"),
                           gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results
    c.print_tuples()

    field_names = ['_0']

    tuples = c.tuples()

    assert tuples[0] == field_names
    assert IndexedTuple.build(tuples[1], field_names)['_0'] == 10000
    assert len(tuples) == 1 + 1

    # Write the metrics
    query_plan.print_metrics()
Example #11
    def on_receive_tuple(self, tuple_, _producer_name):

        if not self.field_names_index:
            self.field_names_index = IndexedTuple.build_field_names_index(tuple_)
            self.send(TupleMessage(tuple_), self.consumers)
            self.producers_received[_producer_name] = True
        else:

            if _producer_name not in self.producers_received.keys():
                # Will be field names, skip
                self.producers_received[_producer_name] = True
            else:

                if self.hashtable is None:
                    self.hashtable = {}

                self.op_metrics.rows_processed += 1
                it = IndexedTuple(tuple_, self.field_names_index)
                itd = self.hashtable.setdefault(it[self.key], [])
                itd.append(tuple_)
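
The build side in isolation: dict.setdefault buckets tuples by join key, so each probe-side lookup is a single dictionary access. Sample rows are invented:

hashtable = {}
rows = [('1', 'AMERICA'), ('1', 'EUROPE'), ('2', 'ASIA')]  # (key, value) pairs

for row in rows:
    hashtable.setdefault(row[0], []).append(row)

assert hashtable['1'] == [('1', 'AMERICA'), ('1', 'EUROPE')]
assert hashtable['2'] == [('2', 'ASIA')]
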
Example #12
    def send_field_names(self, tuple_):
        """Sends the field names tuple

        :param tuple_: The tuple
        :return: None
        """

        # Create and send the record field names
        lt = IndexedTuple.build_default(tuple_)
        labels = Tuple(lt.field_names())

        if self.log_enabled:
            print("{}('{}') | Sending field names [{}]".format(
                self.__class__.__name__, self.name, {'field_names': labels}))

        self.send(TupleMessage(labels), self.consumers)
Example #13
def test_group_sum():
    """Tests a group by query with a sum aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select s_nationkey, sum(float(s_acctbal)) from supplier.csv group by s_nationkey
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts',
                     query_plan, False))

    g = query_plan.add_operator(
        Group(['_3'], [
            AggregateExpression(AggregateExpression.SUM,
                                lambda t_: float(t_['_5']))
        ], 'g', query_plan, False))

    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(g)
    g.connect(c)

    # Start the query
    query_plan.execute()

    # Assert the results

    field_names = ['_0', '_1']

    assert c.tuples()[0] == field_names

    assert len(c.tuples()) == 25 + 1

    nation_24 = list(filter(
        lambda t_: IndexedTuple.build(t_, field_names)['_0'] == '24',
        c.tuples()))[0]
    assert round(nation_24[1], 2) == 1833872.56

    # Write the metrics
    query_plan.print_metrics()
Example #14
def test_aggregate_sum():
    """Tests a group by query with a sum aggregate

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    # select sum(float(s_acctbal)) from supplier.csv
    ts = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'ts',
                     query_plan, False))

    a = query_plan.add_operator(
        Aggregate([
            AggregateExpression(AggregateExpression.SUM,
                                lambda t_: float(t_['_5']))
        ], 'a', query_plan, False))

    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(a)
    a.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"),
                           gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results

    field_names = ['_0']

    assert c.tuples()[0] == field_names
    assert round(IndexedTuple.build(c.tuples()[1], field_names)['_0'],
                 2) == 45103548.65
    assert len(c.tuples()) == 1 + 1

    # Write the metrics
    query_plan.print_metrics()
Example #15
    def on_producer_completed(self, producer_name):
        """Event handler for a completed producer. When producers complete the bloom filter can be sent.

        :param producer_name: The producer that completed.
        :return: None
        """

        self.producer_completions[producer_name] = True

        if all(self.producer_completions.values()):

            # Get the SQL from a bloom use operators
            bloom_use_operators = filter(
                lambda o: isinstance(o, SQLTableScanBloomUse), self.consumers)
            bloom_use_sql_strings = map(lambda o: o.s3sql, bloom_use_operators)
            max_bloom_use_sql_strings = max(
                map(lambda s: len(s), bloom_use_sql_strings))

            # Build bloom filter
            best_possible_fp_rate = SlicedSQLBloomFilter.calc_best_fp_rate(
                len(self.__tuples), max_bloom_use_sql_strings)

            if best_possible_fp_rate > self.fp_rate:
                print("{}('{}') | Bloom filter fp rate ({}) too low, "
                      "will exceed max S3 Select SQL expression length ({}). "
                      "Raising to best possible ({})".format(
                          self.__class__.__name__, self.name, self.fp_rate,
                          MAX_S3_SELECT_EXPRESSION_LEN, best_possible_fp_rate))
                fp_rate_to_use = best_possible_fp_rate
            else:
                fp_rate_to_use = self.fp_rate

            bloom_filter = self.build_bloom_filter(len(self.__tuples),
                                                   fp_rate_to_use)

            for t in self.__tuples:
                lt = IndexedTuple.build(t, self.__field_names)
                bloom_filter.add(int(lt[self.bloom_field_name]))

            del self.__tuples

            # Send the bloom filter
            self.__send_bloom_filter(bloom_filter)

        Operator.on_producer_completed(self, producer_name)
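
Why a too-low fp rate gets raised: by the standard Bloom filter sizing formula m = -n ln(p) / (ln 2)^2, the bit count (and hence the serialized SQL expression) grows as the target false-positive rate shrinks. This sketch shows only that relationship; SlicedSQLBloomFilter's actual calculation may differ:

import math


def bloom_bits(n, p):
    # Standard sizing formula for an n-element Bloom filter with fp rate p
    return int(math.ceil(-n * math.log(p) / (math.log(2) ** 2)))


# A lower fp rate needs a bigger filter, and so a longer SQL expression
assert bloom_bits(10000, 0.01) > bloom_bits(10000, 0.1)
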
Example #16
    def on_receive_tuple(self, tuple_, producer_name):

        # Check the producers are connected
        if not self.build_producers:
            raise Exception("Left (build) producers are not connected")

        if not self.tuple_producers:
            raise Exception("Right (tuple) producers are not connected")

        # Check which producer sent the tuple
        if producer_name in self.build_producers:

            if self.build_field_names is None:
                if self.join_expr.l_field in tuple_:
                    self.build_field_names = tuple_
                    self.field_names_index = IndexedTuple.build_field_names_index(
                        tuple_)
                else:
                    raise Exception(
                        "Join Operator '{}' received invalid left field names tuple {}. "
                        "Tuple must contain join left field name '{}'.".format(
                            self.name, tuple_, self.join_expr.l_field))

        elif producer_name in self.tuple_producers:

            if self.tuple_field_names is None:
                if self.join_expr.r_field in tuple_:
                    self.tuple_field_names = tuple_
                else:
                    raise Exception(
                        "Join Operator '{}' received invalid right field names tuple {}. "
                        "Tuple must contain join right field name '{}'.".
                        format(self.name, tuple_, self.join_expr.r_field))
            else:

                self.op_metrics.tuple_rows_processed += 1

                self.tuples.append(tuple_)

        else:
            raise Exception(
                "Join Operator '{}' received invalid tuple {} from producer '{}'. "
                "Tuple must be sent from connected left producer '{}' or right producer '{}'."
                .format(self.name, tuple_, producer_name, self.build_producers,
                        self.tuple_producers))
Example #17
    def start(self):
        self.op_metrics.timer_start()

        if self.parts == 1:
            self.records = []
            self.worker_metrics = {}
            self.download_part(0, self.records, self.worker_metrics)
        else:
            processes = []
            for part in range(self.parts):
                p = Process(target=self.download_part, args=(part, self.records, self.worker_metrics))
                p.start()
                processes.append(p)

            for p in processes:
                p.join()

        print("All parts finished with {} records".format(len(self.records)))

        first_tuple = True
        for msg in self.records:

            if first_tuple:
                # Create and send the record field names
                it = IndexedTuple.build_default(msg.tuple_)
                first_tuple = False

                if self.log_enabled:
                    print("{}('{}') | Sending field names: {}"
                          .format(self.__class__.__name__, self.name, it.field_names()))

                self.send(TupleMessage(Tuple(it.field_names())), self.consumers)

            self.send(msg, self.consumers)

        self.complete()
        self.op_metrics.timer_stop()
        self.print_stats(to_file=self.s3key + '.' + str(self.parts) + '.stats.txt')

        self.records[:] = []
Example #18
    def __on_receive_tuple(self, tuple_, producer_name):
        """Event handler to handle receipt of a tuple

        :param tuple_: The tuple
        :param producer_name: The name of the producer that sent the tuple
        :return: None
        """

        assert (len(tuple_) > 0)

        if not self.field_names_index:
            self.field_names_index = IndexedTuple.build_field_names_index(
                tuple_)
            self.producers_received[producer_name] = True
            self.__send_field_names(tuple_)
        else:
            if producer_name not in self.producers_received.keys():
                # This will be the field names tuple, skip it
                self.producers_received[producer_name] = True
            else:
                if self.__evaluate_filter(tuple_):
                    self.op_metrics.rows_filtered += 1
                    self.__send_field_values(tuple_)
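
Several operators above share this header-skipping protocol: the first field-names tuple seen establishes the schema, and the first tuple from every other producer (also a header) is dropped. The protocol in miniature, with hypothetical producers:

producers_received = {}
field_names = None
rows = []


def on_tuple(tuple_, producer):
    global field_names
    if field_names is None:
        field_names = tuple_  # first header seen becomes the schema
        producers_received[producer] = True
    elif producer not in producers_received:
        producers_received[producer] = True  # duplicate header from another producer, skip
    else:
        rows.append(tuple_)  # data tuple


on_tuple(['_0'], 'p1')  # header from p1 -> schema
on_tuple(['_0'], 'p2')  # header from p2 -> skipped
on_tuple(['42'], 'p2')  # data tuple from p2
assert field_names == ['_0'] and rows == [['42']]
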
Example #19
def test_aggregate():
    """Executes a select with an aggregate.

    :return: None
    """

    num_rows = 0

    cur = Cursor(boto3.client('s3')) \
        .select('region.csv', 'select count(*) from S3Object')

    try:
        rows = cur.execute()
        for r in rows:
            num_rows += 1
            lt = IndexedTuple.build_default(r)
            assert lt['_0'] == '5'
            # print("{}:{}".format(num_rows, r))

        assert num_rows == 1
    finally:
        cur.close()
Example #20
def test_where_predicate():
    """Executes a select with a where clause on one of the attributes.

    :return: None
    """

    num_rows = 0

    cur = Cursor(boto3.client('s3'))\
        .select('region.csv', 'select * from S3Object where r_name = \'AMERICA\';')

    try:
        rows = cur.execute()
        for r in rows:
            num_rows += 1
            lt = IndexedTuple.build_default(r)
            assert lt['_1'] == 'AMERICA'
            # print("{}:{}".format(num_rows, r))

        assert num_rows == 1
    finally:
        cur.close()
Example #21
    def eval(self, tuple_, field_names_index):
        """Evaluates the predicate using the given tuple

        :param tuple_: The tuple to evaluate the expression against
        :param field_names_index: The names of the fields in the tuple
        :return: True or False
        """

        if self.expr:
            it = IndexedTuple(tuple_, field_names_index)

            v = self.expr(it)

            if type(v) is not bool and type(v) is not numpy.bool_:
                raise Exception(
                    "Illegal return type '{}'. "
                    "Predicate expression must evaluate to {} or {}".format(
                        type(v), bool, numpy.bool_))

            return v

        else:
            raise NotImplementedError
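
A minimal, self-contained sketch of the calling convention above, using a dict-backed stand-in for IndexedTuple; the predicate must return a bool (or numpy bool):

field_names_index = {'_0': 0, '_5': 5}  # name -> position
tuple_ = ['1', 'Supplier#1', 'addr', '7', 'phone', '2500.00', 'c']


class It(object):  # stand-in for IndexedTuple
    def __getitem__(self, name):
        return tuple_[field_names_index[name]]


expr = lambda it: float(it['_5']) > 1000.0  # hypothetical predicate expression

v = expr(It())
assert type(v) is bool and v is True
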
Example #22
    def on_receive_tuple(self, tuple_, producer_name):
        """Handles the receipt of a tuple. The tuple is mapped to a new tuple using the given projection expressions.
        The field names are modified according to the new field names in the projection expressions.

        :param tuple_: The received tuple
        :param producer_name: The name of the producer that sent the tuple
        :return: None
        """

        assert (len(tuple_) > 0)

        if not self.field_names_index:

            self.field_names_index = IndexedTuple.build_field_names_index(
                tuple_)

            # Map the old field names to the new
            projected_field_names = []
            for e in self.project_exprs:
                fn = e.new_field_name
                projected_field_names.append(fn)

            if self.log_enabled:
                print(
                    "{}('{}') | Sending projected field names: from: {} to: {}"
                    .format(self.__class__.__name__, self.name, tuple_,
                            projected_field_names))

            self.producers_received[producer_name] = True

            assert (len(projected_field_names) == len(self.project_exprs))

            self.send(TupleMessage(Tuple(projected_field_names)),
                      self.consumers)

        else:

            assert (len(tuple_) == len(self.field_names_index))

            if producer_name not in self.producers_received.keys():
                # This will be the field names tuple, skip it
                self.producers_received[producer_name] = True
            else:

                # Perform the projection using the given expressions
                it = IndexedTuple(tuple_, self.field_names_index)

                projected_field_values = []
                for e in self.project_exprs:
                    fv = e.expr(it)
                    projected_field_values.append(fv)

                self.op_metrics.rows_projected += 1

                if self.log_enabled:
                    print(
                        "{}('{}') | Sending projected field values: from: {} to: {}"
                        .format(self.__class__.__name__, self.name, tuple_,
                                projected_field_values))

                assert (len(projected_field_values) == len(self.project_exprs))

                self.send(TupleMessage(Tuple(projected_field_values)),
                          self.consumers)
Example #23
def test_sort_topk():
    """Executes a sorted top k query, which must use the top operator as record sorting can't be pushed into s3. The
    results are collated.

    :return: None
    """

    limit = 5

    query_plan = QueryPlan()

    # Query plan
    ts = query_plan.add_operator(SQLTableScan('supplier.csv',
                                              'select * from S3Object;',
                                              False,
                                              'ts',
                                              query_plan,
                                              False))
    s = query_plan.add_operator(Sort([
        SortExpression('_5', float, 'ASC')
    ], 's', query_plan, False))
    t = query_plan.add_operator(Limit(limit, 't', query_plan, False))
    c = query_plan.add_operator(Collate('c', query_plan, False))

    ts.connect(s)
    s.connect(t)
    t.connect(c)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results

    field_names = ['_0', '_1', '_2', '_3', '_4', '_5', '_6']

    assert len(c.tuples()) == limit + 1

    assert c.tuples()[0] == field_names

    prev = None
    num_rows = 0
    for t in c.tuples():
        num_rows += 1
        if num_rows > 1:
            if prev is None:
                prev = t
            else:
                lt = IndexedTuple.build(t, field_names)
                prev_lt = IndexedTuple.build(prev, field_names)
                # Ascending sort: each balance must be >= the previous one
                assert float(lt['_5']) >= float(prev_lt['_5'])
                prev = t

    # Write the metrics
    query_plan.print_metrics()
Example #24
def run(parallel, buffer_size):
    """Tests a with sharded operators and separate build and probe for join

    :return: None
    """

    query_plan = QueryPlan(is_async=parallel, buffer_size=buffer_size)

    # Query plan
    parts = 2

    collate = query_plan.add_operator(Collate('collate', query_plan, False))
    merge = query_plan.add_operator(Merge('merge', query_plan, False))

    hash_join_build_ops = []
    hash_join_probe_ops = []
    for p in range(1, parts + 1):
        r_region_key_lower = math.ceil((5.0 / float(parts)) * (p - 1))
        r_region_key_upper = math.ceil((5.0 / float(parts)) * p)

        region_scan = query_plan.add_operator(
            SQLTableScan('region.csv',
                         'select * '
                         'from S3Object '
                         'where cast(r_regionkey as int) >= {} and cast(r_regionkey as int) < {};'
                         .format(r_region_key_lower, r_region_key_upper),
                         False, 'region_scan' + '_' + str(p),
                         query_plan,
                         False))

        region_project = query_plan.add_operator(
            Project([
                ProjectExpression(lambda t_: t_['_0'], 'r_regionkey'),
                ProjectExpression(lambda t_: t_['_1'], 'r_name')
            ], 'region_project' + '_' + str(p), query_plan, False))

        n_nation_key_lower = math.ceil((25.0 / float(parts)) * (p - 1))
        n_nation_key_upper = math.ceil((25.0 / float(parts)) * p)

        nation_scan = query_plan.add_operator(
            SQLTableScan('nation.csv',
                         'select * from S3Object '
                         'where cast(n_nationkey as int) >= {} and cast(n_nationkey as int) < {};'
                         .format(n_nation_key_lower, n_nation_key_upper),
                         False, 'nation_scan' + '_' + str(p),
                         query_plan,
                         False))

        nation_project = query_plan.add_operator(
            Project([
                ProjectExpression(lambda t_: t_['_0'], 'n_nationkey'),
                ProjectExpression(lambda t_: t_['_1'], 'n_name'),
                ProjectExpression(lambda t_: t_['_2'], 'n_regionkey')
            ], 'nation_project' + '_' + str(p), query_plan, False))

        region_hash_join_build = query_plan.add_operator(
            HashJoinBuild('r_regionkey', 'region_hash_join_build' + '_' + str(p), query_plan, False))
        hash_join_build_ops.append(region_hash_join_build)

        region_nation_join_probe = query_plan.add_operator(
            HashJoinProbe(JoinExpression('r_regionkey', 'n_regionkey'), 'region_nation_join_probe' + '_' + str(p),
                          query_plan, False))
        hash_join_probe_ops.append(region_nation_join_probe)

        region_scan.connect(region_project)
        nation_scan.connect(nation_project)
        region_project.connect(region_hash_join_build)
        region_nation_join_probe.connect_tuple_producer(nation_project)
        region_nation_join_probe.connect(merge)

    for probe_op in hash_join_probe_ops:
        for build_op in hash_join_build_ops:
            probe_op.connect_build_producer(build_op)

    merge.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    tuples = collate.tuples()

    collate.print_tuples(tuples)

    # Write the metrics
    query_plan.print_metrics()

    # Shut everything down
    query_plan.stop()

    field_names = ['r_regionkey', 'r_name', 'n_nationkey', 'n_name', 'n_regionkey']

    assert len(tuples) == 25 + 1

    assert tuples[0] == field_names

    num_rows = 0
    for t in tuples:
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['r_regionkey'] == lt['n_regionkey']
Example #25
    def set_field_names(self, field_names):
        self.__field_name_index = IndexedTuple.build_field_names_index(
            field_names)
Example #26
def test_join_baseline_pandas():
    """Tests a join

    :return: None
    """

    query_plan = QueryPlan(is_async=True, buffer_size=0)

    # Query plan
    supplier_scan = query_plan.add_operator(
        SQLTableScan('region.csv', 'select * from S3Object;', True, False, False, 'supplier_scan', query_plan, True))

    def supplier_project_fn(df):
        df = df.filter(['_0'], axis='columns')
        df = df.rename(columns={'_0': 'r_regionkey'})
        return df

    supplier_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'r_regionkey')], 'supplier_project', query_plan, True, supplier_project_fn))

    nation_scan = query_plan.add_operator(
        SQLTableScan('nation.csv', 'select * from S3Object;', True, False, False, 'nation_scan', query_plan, True))

    def nation_project_fn(df):
        df = df.filter(['_2'], axis='columns')
        df = df.rename(columns={'_2': 'n_regionkey'})
        return df

    nation_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_2'], 'n_regionkey')], 'nation_project', query_plan, True, nation_project_fn))

    supplier_nation_join_build = query_plan.add_operator(
        HashJoinBuild('n_regionkey', 'supplier_nation_join_build', query_plan, True))

    supplier_nation_join_probe = query_plan.add_operator(
        HashJoinProbe(JoinExpression('n_regionkey', 'r_regionkey'), 'supplier_nation_join_probe', query_plan, True))

    collate = query_plan.add_operator(Collate('collate', query_plan, True))

    supplier_scan.connect(supplier_project)
    nation_scan.connect(nation_project)
    nation_project.connect(supplier_nation_join_build)
    supplier_nation_join_probe.connect_build_producer(supplier_nation_join_build)
    supplier_nation_join_probe.connect_tuple_producer(supplier_project)
    supplier_nation_join_probe.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    tuples = collate.tuples()

    collate.print_tuples(tuples)

    # Write the metrics
    query_plan.print_metrics()

    # Shut everything down
    query_plan.stop()

    field_names = ['n_regionkey', 'r_regionkey']

    assert len(tuples) == 25 + 1

    assert tuples[0] == field_names

    num_rows = 0
    for t in tuples:
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['n_regionkey'] == lt['r_regionkey']
Example #27
    def join_field_values(self):
        """Performs the join on data tuples using a nested loop joining algorithm. The joined tuples are each sent.
        Allows for the loop to be broken if the operator completes while executing.

        :return: None
        """

        # Determine which direction the hash join should run.
        # The larger relation remains a list and the smaller relation is hashed. If either
        # relation is empty then just return.
        if len(self.l_tuples) == 0 or len(self.r_tuples) == 0:
            return
        elif len(self.l_tuples) > len(self.r_tuples):
            l_to_r = True
            # r_to_l = not l_to_r
        else:
            l_to_r = False
            # r_to_l = not l_to_r

        if l_to_r:
            outer_tuples_list = self.l_tuples
            inner_tuples_list = self.r_tuples
            inner_tuple_field_name = self.join_expr.r_field
            inner_tuple_field_names = self.r_field_names
            outer_tuple_field_index = self.l_field_names.index(
                self.join_expr.l_field)
        else:
            outer_tuples_list = self.r_tuples
            inner_tuples_list = self.l_tuples
            inner_tuple_field_name = self.join_expr.l_field
            inner_tuple_field_names = self.l_field_names
            outer_tuple_field_index = self.r_field_names.index(
                self.join_expr.r_field)

        # Hash the tuples from the smaller set of tuples
        inner_tuples_dict = {}
        for t in inner_tuples_list:
            it = IndexedTuple.build(t, inner_tuple_field_names)
            itd = inner_tuples_dict.setdefault(it[inner_tuple_field_name], [])
            itd.append(t)

        for outer_tuple in outer_tuples_list:

            if self.is_completed():
                break

            outer_tuple_field_value = outer_tuple[outer_tuple_field_index]
            inner_tuples = inner_tuples_dict.get(outer_tuple_field_value, None)

            if inner_tuples is not None:

                for inner_tuple in inner_tuples:

                    if l_to_r:
                        t = outer_tuple + inner_tuple
                    else:
                        t = inner_tuple + outer_tuple

                    self.op_metrics.rows_joined += 1

                    self.send(TupleMessage(Tuple(t)), self.consumers)
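
The core of join_field_values() reduced to a few lines: hash the smaller relation, stream the larger one, and concatenate matches with the left relation's fields first. Data is illustrative only:

l_tuples = [('0', 'AFRICA'), ('1', 'AMERICA'), ('2', 'ASIA')]  # (l_key, name)
r_tuples = [('7', '1'), ('8', '1')]                            # (id, r_key)
l_key_idx, r_key_idx = 0, 1

l_to_r = len(l_tuples) > len(r_tuples)  # iterate the larger side, hash the smaller
if l_to_r:
    outer, inner = l_tuples, r_tuples
    outer_key_idx, inner_key_idx = l_key_idx, r_key_idx
else:
    outer, inner = r_tuples, l_tuples
    outer_key_idx, inner_key_idx = r_key_idx, l_key_idx

inner_dict = {}
for t in inner:
    inner_dict.setdefault(t[inner_key_idx], []).append(t)

joined = []
for o in outer:
    for i in inner_dict.get(o[outer_key_idx], []):
        joined.append(o + i if l_to_r else i + o)  # keep left fields first

assert joined == [('1', 'AMERICA', '7', '1'), ('1', 'AMERICA', '8', '1')]
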
Example #28
def test_r_to_l_join():
    """Tests a join

    :return: None
    """

    query_plan = QueryPlan()

    # Query plan
    supplier_scan = query_plan.add_operator(
        SQLTableScan('supplier.csv', 'select * from S3Object;', False, 'supplier_scan', query_plan, False))

    supplier_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_3'], 's_nationkey')], 'supplier_project', query_plan, False))

    nation_scan = query_plan.add_operator(
        SQLTableScan('nation.csv', 'select * from S3Object;', False, 'nation_scan', query_plan, False))

    nation_project = query_plan.add_operator(
        Project([ProjectExpression(lambda t_: t_['_0'], 'n_nationkey')], 'nation_project', query_plan, False))

    supplier_nation_join = query_plan.add_operator(
        HashJoin(JoinExpression('n_nationkey', 's_nationkey'), 'supplier_nation_join', query_plan, False))

    collate = query_plan.add_operator(Collate('collate', query_plan, False))

    supplier_scan.connect(supplier_project)
    nation_scan.connect(nation_project)
    supplier_nation_join.connect_left_producer(nation_project)
    supplier_nation_join.connect_right_producer(supplier_project)
    supplier_nation_join.connect(collate)

    # Write the plan graph
    query_plan.write_graph(os.path.join(ROOT_DIR, "../tests-output"), gen_test_id())

    # Start the query
    query_plan.execute()

    # Assert the results

    field_names = ['n_nationkey', 's_nationkey']

    assert len(collate.tuples()) == 10000 + 1

    assert collate.tuples()[0] == field_names

    num_rows = 0
    for t in collate.tuples():
        num_rows += 1
        # Assert that the nation_key in table 1 has been joined with the record in table 2 with the same nation_key
        if num_rows > 1:
            lt = IndexedTuple.build(t, field_names)
            assert lt['s_nationkey'] == lt['n_nationkey']

    # Write the metrics
    query_plan.print_metrics()