Example #1
    def load_cur_versions(self):
        """Load current version numbers from .upgrade.sql files."""

        vrc = re.compile(r"^ \s+ return \s+ '([0-9.]+)';", re.X | re.I | re.M)
        for s in AUTO_UPGRADE:
            fn = '%s.upgrade.sql' % s
            fqfn = skytools.installer_find_file(fn)
            try:
                f = open(fqfn, 'r')
            except IOError, d:
                raise skytools.UsageError(
                    '%s: cannot find upgrade file: %s [%s]' %
                    (s, fqfn, str(d)))

            sql = f.read()
            f.close()
            m = vrc.search(sql)
            if not m:
                raise skytools.UsageError('%s: failed to detect version' %
                                          fqfn)

            ver = m.group(1)
            cur = [s, ver, fn, None]
            self.log.info("Loaded %s %s from %s", s, ver, fqfn)
            version_list.append(cur)
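The verbose, case-insensitive, multiline regex above expects each *.upgrade.sql file to contain a version function whose body has an indented return '<version>'; line. A minimal check of that assumption, with a made-up pgq.version() fragment standing in for a real upgrade file:

import re

vrc = re.compile(r"^ \s+ return \s+ '([0-9.]+)';", re.X | re.I | re.M)

# hypothetical excerpt of what a *.upgrade.sql version function might look like
sql = """
create or replace function pgq.version()
returns text as $$
begin
    return '3.2.6';
end;
$$ language plpgsql;
"""

print(vrc.search(sql).group(1))   # -> 3.2.6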
Example #2
def find_copy_source(script, queue_name, copy_table_name, node_name,
                     node_location):
    """Find source node for table.

    @param script: DbScript
    @param queue_name: name of the cascaded queue
    @param copy_table_name: name of the table
    @param node_name: target node name
    @param node_location: target node location
    @returns (node_name, node_location) of source node
    """
    while 1:
        src_db = script.get_database('_source_db',
                                     connstr=node_location,
                                     autocommit=1)
        src_curs = src_db.cursor()

        q = "select * from pgq_node.get_node_info(%s)"
        src_curs.execute(q, [queue_name])
        info = src_curs.fetchone()
        if info['ret_code'] >= 400:
            raise skytools.UsageError("Node does not exists")

        script.log.info("Checking if %s can be used for copy",
                        info['node_name'])

        q = "select table_name, local, table_attrs from londiste.get_table_list(%s) where table_name = %s"
        src_curs.execute(q, [queue_name, copy_table_name])
        got = False
        for row in src_curs.fetchall():
            tbl = row['table_name']
            if tbl != copy_table_name:
                continue
            if not row['local']:
                script.log.debug("Problem: %s is not local", tbl)
                continue
            if not handler_allows_copy(row['table_attrs']):
                script.log.debug(
                    "Problem: %s handler does not store data [%s]", tbl,
                    row['table_attrs'])
                continue
            script.log.debug("Good: %s is usable", tbl)
            got = True
            break

        script.close_database('_source_db')

        if got:
            script.log.info("Node %s seems good source, using it",
                            info['node_name'])
            return node_name, node_location

        if info['node_type'] == 'root':
            raise skytools.UsageError("Found root and no source found")

        # walk upwards
        node_name = info['provider_node']
        node_location = info['provider_location']
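A minimal usage sketch, assuming a DbScript-style script object and made-up queue, table and node values; the helper simply follows the provider chain until it finds a node that carries the table locally:

# hypothetical call from a copy worker
src_name, src_loc = find_copy_source(
    script, 'myqueue', 'public.mytable',
    'mynode', 'dbname=mydb host=node1')
script.log.info("copying public.mytable from %s (%s)", src_name, src_loc)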
Example #3
    def add_seq(self, src_db, dst_db, seq, create_flags):
        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()
        seq_exists = skytools.exists_sequence(dst_curs, seq)
        if create_flags:
            if seq_exists:
                self.log.info('Sequence %s already exists, not creating', seq)
            else:
                if not skytools.exists_sequence(src_curs, seq):
                    # sequence not present on provider - nowhere to get the DDL from
                    self.log.warning(
                        'Sequence "%s" missing on provider, skipping', seq)
                    return
                s = skytools.SeqStruct(src_curs, seq)
                src_db.commit()
                s.create(dst_curs, create_flags, log=self.log)
        elif not seq_exists:
            if self.options.skip_non_existing:
                self.log.warning(
                    'Sequence "%s" missing on local node, skipping', seq)
                return
            else:
                raise skytools.UsageError("Sequence %r missing on local node",
                                          seq)

        q = "select * from londiste.local_add_seq(%s, %s)"
        self.exec_cmd(dst_curs, q, [self.set_name, seq])
Example #4
    def reload(self):
        super(LocalConsumer, self).reload()

        self.local_tracking_file = self.cf.getfile('local_tracking_file')
        if not os.path.exists(os.path.dirname(self.local_tracking_file)):
            raise skytools.UsageError("path does not exist: %s" %
                                      self.local_tracking_file)
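The only configuration this reload() adds is local_tracking_file, and only its parent directory is required to exist; the file itself is maintained by the consumer. A config sketch with a made-up section name and path:

[my_local_consumer]
job_name = my_local_consumer
local_tracking_file = /var/lib/myjob/last_tick.txt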
Example #5
    def execute_with_retry(self, dbname, stmt, args, exceptions=None):
        """Execute SQL and retry if it fails.

        Return number of retries and current valid cursor, or raise an exception.
        """
        sql_retry = self.cf.getbool("sql_retry", False)
        sql_retry_max_count = self.cf.getint("sql_retry_max_count", 10)
        sql_retry_max_time = self.cf.getint("sql_retry_max_time", 300)
        sql_retry_formula_a = self.cf.getint("sql_retry_formula_a", 1)
        sql_retry_formula_b = self.cf.getint("sql_retry_formula_b", 5)
        sql_retry_formula_cap = self.cf.getint("sql_retry_formula_cap", 60)
        elist = exceptions or tuple()
        stime = time.time()
        tried = 0
        dbc = None
        while True:
            try:
                if dbc is None:
                    if dbname not in self.db_cache:
                        self.get_database(dbname, autocommit=1)
                    dbc = self.db_cache[dbname]
                    if dbc.isolation_level != skytools.I_AUTOCOMMIT:
                        raise skytools.UsageError(
                            "execute_with_retry: autocommit required")
                else:
                    dbc.reset()
                curs = dbc.get_connection(dbc.isolation_level).cursor()
                curs.execute(stmt, args)
                break
            except elist, e:
                if (not sql_retry or tried >= sql_retry_max_count
                        or time.time() - stime >= sql_retry_max_time):
                    raise
                self.log.info("Job %s got error on connection %s: %s",
                              self.job_name, dbname, e)
            except:
                # anything not listed in `exceptions` is fatal
                raise
            # back-off before the next attempt (sketch; exact delays may differ):
            # wait a + b * tried seconds, capped by sql_retry_formula_cap
            delay = sql_retry_formula_a + sql_retry_formula_b * tried
            if delay > sql_retry_formula_cap:
                delay = sql_retry_formula_cap
            tried += 1
            self.log.info("Job %s retry #%d in %d seconds",
                          self.job_name, tried, delay)
            time.sleep(delay)
        return tried, curs
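All of the retry behaviour is configuration-driven and off by default (sql_retry = false). A config sketch using the default values read above; the formula parameters presumably drive a linear back-off, roughly a + b * retry seconds, capped:

[my_job]
; give up after 10 retries or 300 seconds, whichever comes first
sql_retry = true
sql_retry_max_count = 10
sql_retry_max_time = 300
; back-off between retries: roughly a + b * retry seconds, capped at 60
sql_retry_formula_a = 1
sql_retry_formula_b = 5
sql_retry_formula_cap = 60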
Example #6
    def __init__(self, service_name, db_name, args):
        """Initialize new consumer.

        @param service_name: service_name for DBScript
        @param db_name: name of database for get_database()
        @param args: cmdline args for DBScript
        """

        super(BaseConsumer, self).__init__(service_name, args)

        self.db_name = db_name

        # compat params
        self.consumer_name = self.cf.get("pgq_consumer_id", '')
        self.queue_name = self.cf.get("pgq_queue_name", '')

        # proper params
        if not self.consumer_name:
            self.consumer_name = self.cf.get("consumer_name", self.job_name)
        if not self.queue_name:
            self.queue_name = self.cf.get("queue_name")

        self.stat_batch_start = 0

        # compat vars
        self.pgq_queue_name = self.queue_name
        self.consumer_id = self.consumer_name

        # set default just once
        self.pgq_autocommit = self.cf.getint("pgq_autocommit", self.pgq_autocommit)
        if self.pgq_autocommit and self.pgq_lazy_fetch:
            raise skytools.UsageError("pgq_autocommit is not compatible with pgq_lazy_fetch")
        self.set_database_defaults(self.db_name, autocommit=self.pgq_autocommit)

        self.idle_start = time.time()
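Both the legacy pgq_* keys and the newer names are accepted, and consumer_name falls back to the job name. A config sketch with made-up names (the database connect string key depends on the db_name passed in and is omitted here):

[my_consumer]
job_name = my_consumer
queue_name = myqueue
; consumer_name defaults to job_name
; enabling both pgq_autocommit and pgq_lazy_fetch raises UsageError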
Example #7
    def work(self):
        """Loop over databases."""

        self.set_single_loop(1)

        self.load_cur_versions()

        # loop over all dbs
        dblst = self.args
        if self.options.all:
            db = self.connect_db('postgres')
            curs = db.cursor()
            curs.execute(DB_LIST)
            dblst = []
            for row in curs.fetchall():
                dblst.append(row[0])
            self.close_database('db')
        elif not dblst:
            raise skytools.UsageError(
                'Give --all or list of database names on command line')

        # loop over connstrs
        for dbname in dblst:
            if self.last_sigint:
                break
            self.log.info("%s: connecting", dbname)
            db = self.connect_db(dbname)
            self.upgrade(dbname, db)
            self.close_database('db')
Example #8
    def load_jobs(self):
        self.svc_list = []
        self.svc_map = {}
        self.config_list = []

        # load services
        svc_list = self.cf.sections()
        svc_list.remove(self.service_name)
        with_user = 0
        without_user = 0
        for svc_name in svc_list:
            cf = self.cf.clone(svc_name)
            disabled = cf.getboolean('disabled', 0)
            defscript = None
            if disabled:
                defscript = '/disabled'
            svc = {
                'service': svc_name,
                'script': cf.getfile('script', defscript),
                'cwd': cf.getfile('cwd'),
                'disabled': disabled,
                'args': cf.get('args', ''),
                'user': cf.get('user', ''),
            }
            if svc['user']:
                with_user += 1
            else:
                without_user += 1
            self.svc_list.append(svc)
            self.svc_map[svc_name] = svc
        if with_user and without_user:
            raise skytools.UsageError(
                "Invalid config - some jobs have user=, some don't")

        # generate config list
        for tmp in self.cf.getlist('config_list'):
            tmp = os.path.expanduser(tmp)
            tmp = os.path.expandvars(tmp)
            for fn in glob.glob(tmp):
                self.config_list.append(fn)

        # read jobs
        for fn in self.config_list:
            raw = ConfigParser.SafeConfigParser({
                'job_name': '?',
                'service_name': '?'
            })
            raw.read(fn)

            # skip its own config
            if raw.has_section(self.service_name):
                continue

            got = 0
            for sect in raw.sections():
                if sect in self.svc_map:
                    got = 1
                    self.add_job(fn, sect)
            if not got:
                self.log.warning('Cannot find service for %s', fn)
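load_jobs() treats every section of the manager's own config (other than its own) as a service definition, then globs config_list and attaches each matched job file to the service whose section name it contains. A config sketch, assuming the manager's own service is called scriptmgr; the service name, script and paths are made up:

[scriptmgr]
job_name = scriptmgr
config_list = ~/conf/*.ini

; one section per managed service; either all sections set user= or none do
[my_dispatcher]
script = my_dispatcher.py
cwd = ~/jobs
args = -v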
Example #9
    def get_provider_db(self):
        if self.options.copy_node:
            # use custom node for copy
            source_node = self.options.copy_node
            m = self.queue_info.get_member(source_node)
            if not m:
                raise skytools.UsageError("Cannot find node <%s>", source_node)
            if source_node == self.local_node:
                raise skytools.UsageError("Cannot use itself as provider")
            self.provider_location = m.location

        if not self.provider_location:
            db = self.get_database('db')
            q = 'select * from pgq_node.get_node_info(%s)'
            res = self.exec_cmd(db, q, [self.queue_name], quiet=True)
            self.provider_location = res[0]['provider_location']

        return self.get_database('provider_db',
                                 connstr=self.provider_location,
                                 profile='remote')
Example #10
    def find_copy_node(self, dst_db, args):
        src_db = self.get_provider_db()

        need = {}
        for t in args:
            need[t] = 1

        while 1:
            src_curs = src_db.cursor()

            q = "select * from pgq_node.get_node_info(%s)"
            src_curs.execute(q, [self.queue_name])
            info = src_curs.fetchone()
            if info['ret_code'] >= 400:
                raise UsageError("Node does not exists")

            self.log.info("Checking if %s can be used for copy", info['node_name'])

            q = "select table_name, local, table_attrs from londiste.get_table_list(%s)"
            src_curs.execute(q, [self.queue_name])
            got = {}
            for row in src_curs.fetchall():
                tbl = row['table_name']
                if tbl not in need:
                    continue
                if not row['local']:
                    self.log.debug("Problem: %s is not local", tbl)
                    continue
                if not self.handler_allows_copy(row['table_attrs']):
                    self.log.debug("Problem: %s handler does not store data [%s]", tbl, row['table_attrs'])
                    continue
                self.log.debug("Good: %s is usable", tbl)
                got[row['table_name']] = 1

            ok = 1
            for t in args:
                if t not in got:
                    self.log.info("Node %s does not have all tables", info['node_name'])
                    ok = 0
                    break

            if ok:
                self.options.copy_node = info['node_name']
                self.log.info("Node %s seems good source, using it", info['node_name'])
                break

            if info['node_type'] == 'root':
                raise skytools.UsageError("Found root and no source found")

            self.close_database('provider_db')
            src_db = self.get_database('provider_db',
                                       connstr=info['provider_location'])

        return src_db
Example #11
    def __init__(self, args):
        super(DataMaintainer, self).__init__("data_maintainer3", args)

        # source file
        self.fileread = self.cf.get("fileread", "")
        if self.fileread:
            self.fileread = os.path.expanduser(self.fileread)
            self.set_single_loop(True)  # force single run if source is file

        self.csv_delimiter = self.cf.get("csv_delimiter", ',')
        self.csv_quotechar = self.cf.get("csv_quotechar", '"')

        # query for fetching the PK-s of the data set to be maintained
        self.sql_pk = self.cf.get("sql_get_pk_list", "")

        if bool(self.sql_pk) == bool(self.fileread):
            raise skytools.UsageError(
                "Exactly one of fileread or sql_get_pk_list"
                " must be specified in the configuration file")

        # query for changing data tuple ( autocommit )
        self.sql_modify = self.cf.get("sql_modify")

        # query to be run before starting the data maintainer,
        # useful for retrieving initialization parameters of the query
        self.sql_before = self.cf.get("sql_before_run", "")

        # query to be run after finishing the data maintainer
        self.sql_after = self.cf.get("sql_after_run", "")

        # whether to run the sql_after query in case of 0 rows
        self.after_zero_rows = self.cf.getint("after_zero_rows", 1)

        # query to be run if the process crashes
        self.sql_crash = self.cf.get("sql_on_crash", "")

        # query for checking if / how much to throttle
        self.sql_throttle = self.cf.get("sql_throttle", "")

        # how many records to fetch at once
        self.fetchcnt = self.cf.getint("fetchcnt", 100)
        self.fetchcnt = self.cf.getint("fetch_count", self.fetchcnt)

        # specifies if non-transactional cursor should be created (0 -> without hold)
        self.withhold = self.cf.getint("with_hold", 1)

        # execution mode (0 -> whole batch is committed / 1 -> autocommit)
        self.autocommit = self.cf.getint("autocommit", 1)

        # delay in seconds after each commit
        self.commit_delay = self.cf.getfloat("commit_delay", 0.0)
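Putting the options together: a job is fed either from a CSV file (fileread) or from a PK-listing query (sql_get_pk_list), never both, and sql_modify is then run for the fetched rows. A config sketch using only keys visible above, with a made-up table and queries, assuming sql_modify receives each fetched row as a parameter dict (hence the doubled %% for ConfigParser); database connection settings are left out:

[data_maintainer3]
job_name = my_maintenance_job
sql_get_pk_list = select id from mytable where needs_fix
sql_modify = update mytable set fixed = true where id = %%(id)s
fetch_count = 100
autocommit = 1
commit_delay = 0.5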
Example #12
    def solve_globbing(self, args, full_list, full_map, reverse_map,
                       allow_nonexist):
        def glob2regex(s):
            s = s.replace('.', '[.]').replace('?', '.').replace('*', '.*')
            return '^%s$' % s

        res_map = {}
        res_list = []
        err = 0
        for a in args:
            if a.find('*') >= 0 or a.find('?') >= 0:
                if a.find('.') < 0:
                    a = 'public.' + a
                rc = re.compile(glob2regex(a))
                for x in full_list:
                    if rc.match(x):
                        if x not in res_map:
                            res_map[x] = 1
                            res_list.append(x)
            else:
                a = skytools.fq_name(a)
                if a in res_map:
                    continue
                elif a in full_map:
                    res_list.append(a)
                    res_map[a] = 1
                elif a in reverse_map:
                    self.log.info("%s already processed", a)
                elif allow_nonexist:
                    res_list.append(a)
                    res_map[a] = 1
                elif self.options.force:
                    self.log.warning("%s not available, but --force is used",
                                     a)
                    res_list.append(a)
                    res_map[a] = 1
                else:
                    self.log.warning("%s not available", a)
                    err = 1
        if err:
            raise skytools.UsageError("Cannot proceed")
        return res_list
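The inner glob2regex() only understands * and ?, anchors the pattern, and glob arguments without a schema get public. prepended before matching. A quick worked example of the translation:

import re

def glob2regex(s):
    s = s.replace('.', '[.]').replace('?', '.').replace('*', '.*')
    return '^%s$' % s

print(glob2regex('pgq.*tick*'))                      # ^pgq[.].*tick.*$
print(bool(re.match(glob2regex('ev_?'), 'ev_1')))    # True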
Example #13
File: jobmgr.py Project: postsql/cc
    def start(self, args_extra=[]):
        # unsure about the best way to specify target
        mod = self.jcf.get('module', '')
        script = self.jcf.get('script', '')
        cls = self.jcf.get('class', '')
        args = [self.jcf.filename, self.jname]
        args.extend(args_extra)
        if mod:
            cmd = ['python', '-m', mod] + args
        elif script:
            cmd = [script] + args
        else:
            raise skytools.UsageError(
                'JobState.start: dunno how to launch class')

        self.log.info('Launching %s: %s', self.jname, " ".join(cmd))
        if sys.platform == 'win32':
            p = subprocess.Popen(cmd, close_fds=True)
            self.proc = None
        else:
            cmd.append('-d')
            p = subprocess.Popen(cmd,
                                 close_fds=True,
                                 stdin=open(os.devnull, 'rb'),
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT)
            skytools.set_nonblocking(p.stdout, True)
            self.proc = p

        self.start_count += 1
        self.start_time = time.time()
        self.dead_since = None
        self.watchdog_reset = self.jcf.getint('watchdog-reset', 60 * 60)
        self.watchdog_formula_a = self.jcf.getint('watchdog-formula-a', 0)
        self.watchdog_formula_b = self.jcf.getint('watchdog-formula-b', 5)
        self.watchdog_formula_cap = self.jcf.getint('watchdog-formula-cap', 0)
        if self.watchdog_formula_cap <= 0:
            self.watchdog_formula_cap = None

        self.timer = PeriodicCallback(self.handle_timer, TIMER_TICK * 1000,
                                      self.ioloop)
        self.timer.start()
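The command line is built from the job's config: python -m <module>, or a direct script path, always followed by the config file name and job name; the watchdog-* keys then control how restarts are spaced. A config sketch with made-up section and module names, assuming the cc job configs are plain INI sections like the ones above:

[myjob]
module = mypkg.myjob
watchdog-reset = 3600
watchdog-formula-a = 0
watchdog-formula-b = 5
watchdog-formula-cap = 0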
Example #14
    def work(self):
        self.set_single_loop(1)

        hargs = {}
        for a in self.args[1:]:
            if a.find('=') <= 0:
                raise skytools.UsageError('need key=val')
            k, v = a.split('=', 1)
            hargs[k] = v

        tid = str(uuid.uuid4())

        task = TaskSendMessage(req='task.send.' + tid, task_id=tid, **hargs)
        if self.options.sync:
            # sync approach
            ti = self.taskmgr.send_task(task)
            self.log.info('reply: %r', ti.replies)
        else:
            # async approach
            ti = self.taskmgr.send_task_async(task, self.task_cb)
            self.ioloop.start()
            self.log.info('done')
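The trailing command-line arguments must be key=val pairs; split('=', 1) keeps any further '=' inside the value. A worked example of the parsing loop with made-up arguments (using ValueError here to keep the snippet standalone):

args = ['host=db1', 'note=a=b']
hargs = {}
for a in args:
    if a.find('=') <= 0:
        raise ValueError('need key=val')
    k, v = a.split('=', 1)
    hargs[k] = v
print(hargs)   # {'host': 'db1', 'note': 'a=b'}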
Example #15
def find_copy_source(script, queue_name, copy_table_name, node_name,
                     node_location):
    """Find source node for table.

    @param script: DbScript
    @param queue_name: name of the cascaded queue
    @param copy_table_name: name of the table (or list of names)
    @param node_name: target node name
    @param node_location: target node location
    @returns (node_name, node_location, downstream_worker_name) of source node
    """

    # None means no steps upwards were taken, so local consumer is worker
    worker_name = None

    if isinstance(copy_table_name, str):
        need = set([copy_table_name])
    else:
        need = set(copy_table_name)

    while 1:
        src_db = script.get_database('_source_db',
                                     connstr=node_location,
                                     autocommit=1,
                                     profile='remote')
        src_curs = src_db.cursor()

        q = "select * from pgq_node.get_node_info(%s)"
        src_curs.execute(q, [queue_name])
        info = src_curs.fetchone()
        if info['ret_code'] >= 400:
            raise skytools.UsageError("Node does not exist")

        script.log.info("Checking if %s can be used for copy",
                        info['node_name'])

        q = "select table_name, local, table_attrs from londiste.get_table_list(%s)"
        src_curs.execute(q, [queue_name])
        got = set()
        for row in src_curs.fetchall():
            tbl = row['table_name']
            if tbl not in need:
                continue
            if not row['local']:
                script.log.debug("Problem: %s is not local", tbl)
                continue
            if not handler_allows_copy(row['table_attrs']):
                script.log.debug(
                    "Problem: %s handler does not store data [%s]", tbl,
                    row['table_attrs'])
                continue
            script.log.debug("Good: %s is usable", tbl)
            got.add(tbl)

        script.close_database('_source_db')

        if got == need:
            script.log.info("Node %s seems good source, using it",
                            info['node_name'])
            return node_name, node_location, worker_name
        else:
            script.log.info("Node %s does not have all tables",
                            info['node_name'])

        if info['node_type'] == 'root':
            raise skytools.UsageError("Found root and no source found")

        # walk upwards
        node_name = info['provider_node']
        node_location = info['provider_location']
        worker_name = info['worker_name']
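A usage sketch of this newer variant with made-up values: it accepts either a single table name or a list, and also reports which worker consumer sits on the chosen provider (None if no step upwards was taken):

# hypothetical call asking for a source that has both tables locally
name, loc, worker = find_copy_source(
    script, 'myqueue', ['public.t1', 'public.t2'],
    'leafnode', 'dbname=leafdb host=node5')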