Exemple #1
0
class Client(object):
    def __init__(
        self,
        crosscat_host=None,
        crosscat_port=8007,
        crosscat_engine_type="multiprocessing",
        bayesdb_host=None,
        bayesdb_port=8008,
        seed=None,
        upgrade_key_column=None,
    ):
        """
        Create a client object. The client creates a parser, that is uses to parse all commands,
        and an engine, which is uses to execute all commands. The engine can be remote or local.
        If local, the engine will be created.
        """
        self.parser = Parser()
        if bayesdb_host is None or bayesdb_host == "localhost":
            self.online = False
            self.engine = Engine(crosscat_host, crosscat_port, crosscat_engine_type, seed)
            self.engine.upgrade_btables(upgrade_key_column)
        else:
            self.online = True
            self.hostname = bayesdb_host
            self.port = bayesdb_port
            self.URI = "http://" + self.hostname + ":%d" % self.port

    def call_bayesdb_engine(self, method_name, args_dict, debug=False):
        """
        Helper function used to call the BayesDB engine, whether it is remote or local.
        Accepts method name and arguments for that method as input.
        """
        if self.online:
            out, id = aqupi_utils.call(method_name, args_dict, self.URI)
        else:
            method = getattr(self.engine, method_name)
            if debug:
                out = method(**args_dict)
            else:
                # when not in debug mode, catch all BayesDBErrors
                try:
                    out = method(**args_dict)
                except utils.BayesDBError as e:
                    out = dict(message=str(e), error=True)
        return out

    def __call__(
        self,
        call_input,
        pretty=True,
        timing=False,
        wait=False,
        plots=None,
        yes=False,
        debug=False,
        pandas_df=None,
        pandas_output=True,
        key_column=None,
    ):
        """Wrapper around execute."""
        return self.execute(call_input, pretty, timing, wait, plots, yes, debug, pandas_df, pandas_output, key_column)

    def execute(
        self,
        call_input,
        pretty=True,
        timing=False,
        wait=False,
        plots=None,
        yes=False,
        debug=False,
        pandas_df=None,
        pandas_output=True,
        key_column=None,
    ):
        """
        Execute a chunk of BQL. This method breaks a large chunk of BQL (like a file)
        consisting of possibly many BQL statements, breaks them up into individual statements,
        then passes each individual line to self.execute_statement() as a string.
        
        param call_input: may be either a file object, or a string.
        If the input is a file, then we load the inputs of the file, and use those as a string.

        See self.execute_statement() for an explanation of arguments.
        """
        if type(call_input) == file:
            bql_string = call_input.read()
            path = os.path.abspath(call_input.name)
            self.parser.set_root_dir(os.path.dirname(path))
        elif type(call_input) == str:
            bql_string = call_input
        else:
            print "Invalid input type: expected file or string."

        return_list = []

        # Parse input, but catch parsing errors and abort
        try:
            lines = [bql_statement_ast for bql_statement_ast in self.parser.pyparse_input(bql_string)]
        except utils.BayesDBError as e:
            if debug:
                raise e
            else:
                print str(e)
                return

        # Iterate through lines with while loop so we can append within loop.
        while len(lines) > 0:
            line = lines.pop(0)
            if type(call_input) == file:
                print "> %s" % line
            if wait:
                user_input = raw_input()
                if len(user_input) > 0 and (user_input[0] == "q" or user_input[0] == "s"):
                    continue
            result = self.execute_statement(
                line,
                pretty=pretty,
                timing=timing,
                plots=plots,
                yes=yes,
                debug=debug,
                pandas_df=pandas_df,
                pandas_output=pandas_output,
                key_column=key_column,
            )

            if type(result) == dict and "message" in result and result["message"] == "execute_file":
                ## special case for one command: execute_file
                new_lines = self.parser.split_lines(result["bql_string"])
                lines += new_lines
            if type(call_input) == file:
                print

            return_list.append(result)

        self.parser.reset_root_dir()

        if not pretty:
            return return_list

    def execute_statement(
        self,
        bql_statement_ast,
        pretty=True,
        timing=False,
        plots=None,
        yes=False,
        debug=False,
        pandas_df=None,
        pandas_output=True,
        key_column=None,
    ):
        """
        Accepts a SINGLE BQL STATEMENT as input, parses it, and executes it if it was parsed
        successfully.

        If pretty=True, then the command output will be pretty-printed as a string.
        If pretty=False, then the command output will be returned as a python object.

        timing=True prints out how long the command took to execute.

        For commands that have visual results, plots=True will cause those to be displayed
        by matplotlib as graphics rather than being pretty-printed as text.
        (Note that the graphics will also be saved if the user added SAVE TO <filename> to the BQL.)
        """
        if timing:
            start_time = time.time()

        parser_out = None
        ##TODO move pyparsing objects out of client into parser
        if debug:
            parser_out = self.parser.parse_single_statement(bql_statement_ast)
        else:
            try:
                parser_out = self.parser.parse_single_statement(bql_statement_ast)
            except Exception as e:
                raise utils.BayesDBParseError(str(e))
        if parser_out is None:
            print "Could not parse command. Try typing 'help' for a list of all commands."
            return
        elif not parser_out:
            return

        method_name, args_dict, client_dict = parser_out
        if client_dict is None:
            client_dict = {}

        ## Do stuff now that you know the user's command, but before passing it to engine.
        if method_name == "execute_file":
            return dict(message="execute_file", bql_string=open(args_dict["filename"], "r").read())
        elif (method_name == "drop_btable") and (not yes):
            ## If dropping something, ask for confirmation.
            print "Are you sure you want to permanently delete this btable, and all associated models, without any way to get them back? Enter 'y' if yes."
            user_confirmation = raw_input()
            if "y" != user_confirmation.strip():
                return dict(message="Operation canceled by user.")
        elif (method_name == "drop_models") and (not yes):
            ## If dropping something, ask for confirmation.
            print "Are you sure you want to permanently delete model(s), without any way to get them back? Enter 'y' if yes."
            user_confirmation = raw_input()
            if "y" != user_confirmation.strip():
                return dict(message="Operation canceled by user.")
        elif method_name == "load_models":
            pklpath = client_dict["pkl_path"]
            try:
                models = pickle.load(gzip.open(self.parser.get_absolute_path(pklpath), "rb"))
            except IOError as e:
                if pklpath[-7:] != ".pkl.gz":
                    if pklpath[-4:] == ".pkl":
                        models = pickle.load(open(self.parser.get_absolute_path(pklpath), "rb"))
                    else:
                        pklpath = pklpath + ".pkl.gz"
                        models = pickle.load(gzip.open(self.parser.get_absolute_path(pklpath), "rb"))
                else:
                    raise utils.BayesDBError("Models file %s could not be found." % pklpath)
            args_dict["models"] = models
        elif method_name == "create_btable":
            if pandas_df is None:
                header, rows = data_utils.read_csv(client_dict["csv_path"])
            else:
                header, rows = data_utils.read_pandas_df(pandas_df)
            args_dict["header"] = header
            args_dict["raw_T_full"] = rows
            args_dict["key_column"] = key_column
            args_dict["subsample"] = False

            # Display warning messages and get confirmation if btable is too large.
            # Ask user if they want to turn on subsampling.
            max_columns = 200
            max_rows = 1000
            max_cells = 100000
            message = None
            if not yes:
                if len(rows[0]) > max_columns:
                    message = (
                        "The btable you are uploading has %d columns, but BayesDB is currently designed to support only %d columns. If you proceed, performance may suffer unless you set many columns' datatypes to 'ignore'. Would you like to continue? Enter 'y' if yes."
                        % (len(rows[0]), max_columns)
                    )
                if len(rows) > max_rows:
                    message = (
                        "The btable you are uploading has %d rows, but BayesDB is currently designed to support only %d rows. If you proceed, performance may suffer. Would you like to continue? Enter 'y' to continue without subsampling, 'n' to abort, 's' to continue by subsampling %d rows, or a positive integer to specify the number of rows to be subsampled."
                        % (len(rows), max_rows, max_rows)
                    )
                if len(rows[0]) * len(rows) > max_cells:
                    message = (
                        "The btable you are uploading has %d cells, but BayesDB is currently designed to support only %d cells. If you proceed, performance may suffer unless you enable subsampling. Enter 'y' to continue without subsampling, 'n' to abort, 's' to continue by subsampling %d rows, or a positive integer to specify the number of rows to be subsampled."
                        % (len(rows) * len(rows[0]), max_cells, max_rows)
                    )
                if message is not None:
                    print message
                    user_confirmation = raw_input()
                    if "y" == user_confirmation.strip():
                        pass
                    elif "n" == user_confirmation.strip():
                        return dict(message="Operation canceled by user.")
                    elif "s" == user_confirmation.strip():
                        args_dict["subsample"] = min(max_rows, len(rows))
                    elif utils.is_int(user_confirmation.strip()):
                        args_dict["subsample"] = int(user_confirmation.strip())
                    else:
                        return dict(message="Operation canceled by user.")
        elif method_name in ["label_columns", "update_metadata"]:
            if client_dict["source"] == "file":
                header, rows = data_utils.read_csv(client_dict["csv_path"])
                args_dict["mappings"] = {key: value for key, value in rows}

        ## Call engine.
        result = self.call_bayesdb_engine(method_name, args_dict, debug)

        ## If error occurred, exit now.
        if "error" in result and result["error"]:
            if pretty:
                print result["message"]
                return result["message"]
            else:
                return result

        ## Do stuff now that engine has given you output, but before printing the result.
        result = self.callback(method_name, args_dict, client_dict, result)

        assert type(result) != int

        if timing:
            end_time = time.time()
            print "Elapsed time: %.2f seconds." % (end_time - start_time)

        if plots is None:
            plots = "DISPLAY" in os.environ.keys()

        if "matrix" in result and (plots or client_dict["filename"]):
            # Plot matrices
            plotting_utils.plot_matrix(
                result["matrix"], result["column_names"], result["title"], client_dict["filename"]
            )
            if pretty:
                if "column_lists" in result:
                    print self.pretty_print(dict(column_lists=result["column_lists"]))
                return self.pretty_print(result)
            else:
                return result
        if "plot" in client_dict and client_dict["plot"]:
            if plots or client_dict["filename"]:
                # Plot generalized histograms or scatterplots
                plot_remove_key = method_name in ["select", "infer"]
                plotting_utils.plot_general_histogram(
                    result["columns"],
                    result["data"],
                    result["M_c"],
                    client_dict["filename"],
                    client_dict["scatter"],
                    remove_key=plot_remove_key,
                )
                return self.pretty_print(result)
            else:
                if "message" not in result:
                    result["message"] = ""
                result["message"] = (
                    "Your query indicates that you would like to make a plot, but in order to do so, you must either enable plotting in a window or specify a filename to save to by appending 'SAVE TO <filename>' to this command.\n"
                    + result["message"]
                )

        if pretty:
            pp = self.pretty_print(result)
            print pp

        if pandas_output and "data" in result and "columns" in result:
            result_pandas_df = data_utils.construct_pandas_df(result)
            return result_pandas_df
        else:
            return result

    def callback(self, method_name, args_dict, client_dict, result):
        """
        This method is meant to be called after receiving the result of a
        call to the BayesDB engine, and modifies the output before it is displayed
        to the user.
        """
        if method_name == "save_models":
            samples_dict = result
            ## Here is where the models get saved.
            pkl_path = client_dict["pkl_path"]
            if pkl_path[-7:] != ".pkl.gz":
                if pkl_path[-4:] == ".pkl":
                    pkl_path = pkl_path + ".gz"
                else:
                    pkl_path = pkl_path + ".pkl.gz"
            samples_file = gzip.GzipFile(pkl_path, "w")
            pickle.dump(samples_dict, samples_file)
            return dict(message="Successfully saved the samples to %s" % client_dict["pkl_path"])

        else:
            return result

    def pretty_print(self, query_obj):
        """
        Return a pretty string representing the output object.
        """
        assert type(query_obj) == dict
        result = ""
        if type(query_obj) == dict and "message" in query_obj:
            result += query_obj["message"] + "\n"
        if "data" in query_obj and "columns" in query_obj:
            """ Pretty-print data table """
            pt = prettytable.PrettyTable()
            pt.field_names = query_obj["columns"]
            for row in query_obj["data"]:
                pt.add_row(row)
            result += str(pt)
        elif "list" in query_obj:
            """ Pretty-print lists """
            result += str(query_obj["list"])
        elif "column_names" in query_obj:
            """ Pretty-print cctypes """
            colnames = query_obj["column_names"]
            zmatrix = query_obj["matrix"]
            pt = prettytable.PrettyTable(hrules=prettytable.ALL, vrules=prettytable.ALL, header=False)
            pt.add_row([""] + list(colnames))
            for row, colname in zip(zmatrix, list(colnames)):
                pt.add_row([colname] + list(row))
            result += str(pt)
        elif "columns" in query_obj:
            """ Pretty-print column list."""
            pt = prettytable.PrettyTable()
            pt.field_names = ["column"]
            for column in query_obj["columns"]:
                pt.add_row([column])
            result += str(pt)
        elif "row_lists" in query_obj:
            """ Pretty-print multiple row lists, which are just names and row sizes. """
            pt = prettytable.PrettyTable()
            pt.field_names = ("Row List Name", "Row Count")

            def get_row_list_sorting_key(x):
                """ To be used as the key function in a sort. Puts cc_2 ahead of cc_10, e.g. """
                name, count = x
                if "_" not in name:
                    return name
                s = name.split("_")
                end = s[-1]
                start = "_".join(s[:-1])
                if utils.is_int(end):
                    return (start, int(end))
                return name

            for name, count in sorted(query_obj["row_lists"], key=get_row_list_sorting_key):
                pt.add_row((name, count))
            result += str(pt)
        elif "column_lists" in query_obj:
            """ Pretty-print multiple column lists. """
            print
            clists = query_obj["column_lists"]
            for name, clist in clists:
                print "%s:" % name
                pt = prettytable.PrettyTable()
                pt.field_names = clist
                print pt
        elif "models" in query_obj:
            """ Pretty-print model info. """
            pt = prettytable.PrettyTable()
            pt.field_names = ("model_id", "iterations")
            for (id, iterations) in query_obj["models"]:
                pt.add_row((id, iterations))
            result += str(pt)

        if len(result) >= 1 and result[-1] == "\n":
            result = result[:-1]
        return result