示例#1
0
def testsuite():
    """
    Execute the CLgen test suite through pytest, collecting coverage.

    Returns
    -------
    int
        The pytest return code; 0 when every test passes.
    """
    with test_env():
        # Test discovery is performed relative to the module directory.
        with chdir(module_path()):
            assert os.path.exists(coveragerc_path())

            pytest_args = ["--doctest-modules", "--cov=clgen",
                           "--cov-config", coveragerc_path()]
            # Verbose runs get verbose pytest output; otherwise the empty
            # --cov-report value suppresses the coverage report.
            extra_flag = "--verbose" if log.is_verbose() else "--cov-report="
            pytest_args.append(extra_flag)

            exit_code = pytest.main(pytest_args)

            assert os.path.exists(coverage_report_path())

        if log.is_verbose():
            print("coverage path:", coverage_report_path())
            print("coveragerc path:", coveragerc_path())

    return exit_code
示例#2
0
    def run(self) -> None:
        """
        Consume sampled text from the queue and store the kernels it contains.

        Each item taken from ``self.queue`` is split into kernels with
        ``clutil.get_cl_kernels``. Every kernel is inserted into the
        ``ContentFiles`` table keyed by its SHA1. When the ``static_checker``
        sampler option is set, each kernel is also preprocessed and stored in
        ``PreprocessedFiles``. Progress and timing stats are flushed to the
        sampler metadata after every item.

        Returns once ``self.term_condition()`` holds. The producer is stopped
        on every exit path, including exceptions.

        Raises
        ------
        queue.Empty
            Presumably, if no sample arrives within 60 seconds (raised by
            ``self.queue.get``; assumes a ``queue.Queue``-like object).
        """
        static_checker = self.sampler_opts["static_checker"]

        if not log.is_verbose():
            bar = progressbar.ProgressBar(max_value=self.max_i)
            bar.update(self.progress())

        try:
            while True:
                start_time = time()
                sample = self.queue.get(timeout=60)

                kernels = clutil.get_cl_kernels(sample)
                ids = [crypto.sha1_str(k) for k in kernels]

                if static_checker:
                    preprocess_opts = {
                        "use_shim": False,
                        "use_gpuverify": self.sampler_opts["gpuverify"]
                    }
                    pp = [clgen.preprocess_for_db(k, **preprocess_opts)
                          for k in kernels]

                db = dbutil.connect(self.db_path)
                try:
                    c = db.cursor()

                    # insert raw samples
                    for kid, src in zip(ids, kernels):
                        dbutil.sql_insert_dict(c, "ContentFiles",
                                               {"id": kid, "contents": src},
                                               ignore_existing=True)

                    # insert preprocessed samples
                    if static_checker:
                        for kid, (status, src) in zip(ids, pp):
                            dbutil.sql_insert_dict(c, "PreprocessedFiles", {
                                "id": kid, "status": status, "contents": src
                            }, ignore_existing=True)

                    c.close()
                    db.commit()
                finally:
                    # Release the connection even if an insert fails.
                    db.close()

                # update progress bar
                progress = self.progress()
                if not log.is_verbose():
                    bar.update(progress)

                elapsed = time() - start_time
                self.sampler.stats["progress"] = progress
                self.sampler.stats["time"] += elapsed
                self.sampler._flush_meta(self.cache)

                # determine if we are done sampling; the finally clause
                # below stops the producer
                if self.term_condition():
                    return
        finally:  # always kill the sampler thread
            print()
            self.producer.stop()
示例#3
0
    def run(self) -> None:
        """
        Consume sampled text from the queue and store it in the database.

        Each item taken from ``self.queue`` is stripped of surrounding
        whitespace and inserted into the ``ContentFiles`` table, keyed by its
        SHA1. Progress and timing stats are flushed to the sampler metadata
        after every item.

        Returns once ``self.term_condition()`` holds. The producer is stopped
        on every exit path, including exceptions.

        Raises
        ------
        queue.Empty
            Presumably, if no sample arrives within 120 seconds (raised by
            ``self.queue.get``; assumes a ``queue.Queue``-like object).
        """
        if not log.is_verbose():
            bar = progressbar.ProgressBar(max_value=self.max_i)
            bar.update(self.progress())

        try:
            while True:
                start_time = time()

                # Block while waiting for a new sample to come in:
                sample = self.queue.get(timeout=120).strip()

                # Compute the sample ID:
                kid = crypto.sha1_str(sample)

                # Add the new sample to the database:
                db = dbutil.connect(self.db_path)
                try:
                    c = db.cursor()
                    dbutil.sql_insert_dict(c,
                                           "ContentFiles", {
                                               "id": kid,
                                               "contents": sample
                                           },
                                           ignore_existing=True)
                    c.close()
                    db.commit()
                finally:
                    # Release the connection even if the insert fails.
                    db.close()

                # update progress bar
                progress = self.progress()
                if not log.is_verbose():
                    bar.update(progress)

                elapsed = time() - start_time
                self.sampler.stats["progress"] = progress
                self.sampler.stats["time"] += elapsed
                self.sampler._flush_meta(self.cache)

                # determine if we are done sampling; the finally clause
                # below stops the producer
                if self.term_condition():
                    return
        finally:  # always kill the sampler thread
            print()
            self.producer.stop()
示例#4
0
    def run(self) -> None:
        """
        Sample batches of kernels from the model and push them onto the queue.

        Restores the latest model checkpoint. Then, until
        ``self.stop_requested`` is true, it samples ``batch_size`` sequences
        in parallel, each seeded with ``self.start_text``. A sequence is
        finished once it has opened a ``{`` block and its brace depth has
        returned to zero. Finished sequences are put on ``self.queue``.
        Sequences still open after ``max_length`` steps are discarded.

        Error handling:

        - If a TensorFlow invalid-argument error occurs while sampling, the
          current batch is abandoned and a new one is started.
        - If a symbol cannot be deatomized, only that sample is dropped.

        Raises
        ------
        lockfile.UnableToAcquireLockError
            If the model is locked (e.g. during training).
        """
        model = self.model
        batch_size = self.model.corpus.batch_size
        max_length = self.kernel_opts["max_length"]
        temperature = self.kernel_opts["temperature"]

        if model.lock.islocked:  # model is locked during training
            # Report the lock that was actually checked.
            raise lockfile.UnableToAcquireLockError(model.lock)

        tf = model._init_tensorflow(infer=True)

        # seed RNG
        np.random.seed(self.kernel_opts["seed"])
        tf.set_random_seed(self.kernel_opts["seed"])

        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.get_checkpoint_state(model.cache.path)

            assert ckpt
            assert ckpt.model_checkpoint_path

            saver.restore(sess, ckpt.model_checkpoint_path)

            def weighted_pick(weights, temperature):
                """
                Sample an index with probability proportional to its weight.

                Requires that all probabilities are >= 0, i.e.:
                  assert all(x >= 0 for x in weights)
                See: https://github.com/ChrisCummins/clgen/issues/120

                NOTE(review): ``temperature`` is accepted but not used here.
                """
                t = np.cumsum(weights)
                s = np.sum(weights)
                return int(np.searchsorted(t, np.random.rand(1) * s))

            def get_bracket_depth(text: str, depth: int = 0) -> tuple:
                """
                Update the brace depth with the braces found in ``text``.

                Returns ``(started, depth)``. ``started`` is true if the depth
                is positive after adding the opening braces (before
                subtracting the closing ones).
                """
                # FIXME(polyglot): support multiple sample termination criteria
                depth += text.count("{")
                started = depth > 0
                depth -= text.count("}")
                return started, depth

            init_started, init_depth = get_bracket_depth(self.start_text)
            atomize = model.corpus.atomizer.atomize
            deatomize = model.corpus.atomizer.deatomize

            while not self.stop_requested:
                buf = [StringIO() for _ in range(batch_size)]
                depth = [init_depth] * batch_size
                started = [init_started] * batch_size
                running = [True] * batch_size

                state = sess.run(model.cell.zero_state(batch_size, tf.float32))
                indices = np.zeros((batch_size, 1))

                # Feed all but the last seed symbol to prime the RNN state.
                seed_tensor = atomize(self.start_text)
                for symbol in seed_tensor[:-1]:
                    indices[:] = symbol
                    feed = {
                        model.input_data: indices,
                        model.initial_state: state
                    }
                    [state] = sess.run([model.final_state], feed)

                for item in range(batch_size):
                    buf[item].write(self.start_text)

                indices[:] = seed_tensor[-1]

                for _ in range(max_length):
                    feed = {
                        model.input_data: indices,
                        model.initial_state: state
                    }

                    try:
                        [probs,
                         state] = sess.run([model.probs, model.final_state],
                                           feed)
                    except TensorFlowInvalidArgumentError:
                        # One session call covers the whole batch, so give
                        # up on this batch; the outer loop starts a new one.
                        log.warning("sampling error")
                        break

                    # sample distribution to pick next symbol:
                    indices[:, 0] = [
                        weighted_pick(p, temperature) for p in probs
                    ]

                    for item in range(batch_size):
                        if not running[item]:
                            continue

                        # In case of decoding error, drop this sample:
                        try:
                            atom = deatomize([indices[item, 0]])
                        except clgen.VocabError:
                            log.warning("deatomizing error")
                            running[item] = False
                            continue

                        buf[item].write(atom)
                        # update function block depth
                        _started, depth[item] = get_bracket_depth(
                            atom, depth=depth[item])
                        # you can't "unset" the started state
                        started[item] |= _started
                        # keep sampling until a block has been opened and
                        # closed again
                        running[item] = not started[item] or depth[item] > 0

                        # submit finished sample to processing queue
                        if not running[item]:
                            text = buf[item].getvalue()
                            self.queue.put(text)
                            if log.is_verbose():
                                sys.stdout.write(self.sample_header)
                                sys.stdout.write(text)
                                sys.stdout.flush()

                    # start a new batch if there's nothing left running
                    if not any(running):
                        break

            if log.is_verbose():
                sys.stdout.write('\n\n')
示例#5
0
    def run(self) -> None:
        """
        Sample kernels one at a time from the model and push them onto the queue.

        Restores the latest model checkpoint. Then, until
        ``self.stop_requested`` is true, it generates a single sequence seeded
        with ``self.start_text``. Generation stops once the sequence has
        opened a ``{`` block and the brace depth has dropped back to zero, or
        after ``max_length`` symbols. Either way, the text is put on
        ``self.queue``.

        Raises
        ------
        lockfile.UnableToAcquireLockError
            If the model is locked (e.g. during training).
        """
        model = self.model
        max_length = self.kernel_opts["max_length"]
        temperature = self.kernel_opts["temperature"]

        if model.lock.islocked:  # model is locked during training
            # Report the lock that was actually checked.
            raise lockfile.UnableToAcquireLockError(model.lock)

        tf = model._init_tensorflow(infer=True)

        # seed RNG
        np.random.seed(self.kernel_opts["seed"])
        tf.set_random_seed(self.kernel_opts["seed"])

        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.get_checkpoint_state(model.cache.path)

            assert ckpt
            assert ckpt.model_checkpoint_path

            saver.restore(sess, ckpt.model_checkpoint_path)

            def weighted_pick(weights, temperature):
                """
                Sample an index with probability proportional to its weight.

                Requires that all probabilities are >= 0, i.e.:
                  assert all(x >= 0 for x in weights)
                See: https://github.com/ChrisCummins/clgen/issues/120

                NOTE(review): ``temperature`` is accepted but not used here.
                """
                t = np.cumsum(weights)
                s = np.sum(weights)
                return int(np.searchsorted(t, np.random.rand(1) * s))

            def update_bracket_depth(text, started: bool = False,
                                     depth: int = 0):
                """
                Update ``(started, depth)`` with the braces in ``text``.

                ``started`` becomes true as soon as any ``{`` is seen.
                """
                for char in text:
                    if char == '{':
                        depth += 1
                        started = True
                    elif char == '}':
                        depth -= 1

                return started, depth

            init_started, init_depth = update_bracket_depth(self.start_text)

            while not self.stop_requested:
                buf = StringIO()
                started, depth = init_started, init_depth

                state = sess.run(model.cell.zero_state(1, tf.float32))

                # Feed all but the last seed symbol to prime the RNN state.
                seed_tensor = model.corpus.atomizer.atomize(self.start_text)
                for index in seed_tensor[:-1]:
                    x = np.zeros((1, 1))
                    x[0, 0] = index
                    feed = {model.input_data: x, model.initial_state: state}
                    [state] = sess.run([model.final_state], feed)

                buf.write(self.start_text)
                if log.is_verbose():
                    sys.stdout.write("\n\n/* ==== START SAMPLE ==== */\n\n")
                    sys.stdout.write(self.start_text)
                    sys.stdout.flush()

                index = seed_tensor[-1]

                for _ in range(max_length):
                    x = np.zeros((1, 1))
                    x[0, 0] = index
                    feed = {model.input_data: x, model.initial_state: state}
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              feed)
                    p = probs[0]

                    # sample distribution to pick next:
                    index = weighted_pick(p, temperature)
                    # alternatively, select most probable:
                    # index = np.argmax(p)

                    atom = model.corpus.atomizer.deatomize([index])
                    buf.write(atom)
                    if log.is_verbose():
                        sys.stdout.write(atom)

                    # update function block depth
                    started, depth = update_bracket_depth(atom, started, depth)

                    # stop sampling once a block has been opened and closed
                    if started and depth <= 0:
                        break

                # submit sample to processing queue (also when max_length
                # was reached without closing the block)
                self.queue.put(buf.getvalue())

            if log.is_verbose():
                sys.stdout.write('\n\n')
示例#6
0
文件: cli.py 项目: yasutakawada/clgen
def run(method, *args, **kwargs):
    """
    Runs the given method as the main entrypoint to a program.

    If an exception is thrown, print error message and exit.

    If the environment variable DEBUG is set to a non-empty value (e.g.
    DEBUG=1), then exceptions are not caught.

    If profiling is enabled and logging is verbose, the method is run under
    cProfile. The profile, sorted by total time, is printed afterwards.

    Parameters
    ----------
    method : function
        Function to execute.
    *args
        Arguments for method.
    **kwargs
        Keyword arguments for method.

    Returns
    -------
    method(*args, **kwargs)
    """
    def _user_message(exception):
        # Format the exception passed in, not a name captured from the
        # enclosing scope.
        log.fatal("""\
{err} ({type})

Please report bugs at <https://github.com/ChrisCummins/clgen/issues>\
""".format(err=exception, type=type(exception).__name__))

    def _user_message_with_stacktrace(exception):
        NUM_ROWS = 5  # number of rows in traceback

        def _msg(i, frame):
            # frame is indexed as (filename, lineno, function name, ...)
            loc = "{filename}:{lineno}".format(
                filename=fs.basename(frame[0]), lineno=frame[1])
            return "      #{n}  {loc: <18} {fnname}()".format(
                n=i + 1, loc=loc, fnname=frame[2])

        # Get a limited stack trace:
        # - take the first NUM_ROWS + 1 traceback entries;
        # - drop the first one (this run() frame, where the exception was
        #   caught);
        # - list the rest in reverse order.
        entries = traceback.extract_tb(exception.__traceback__,
                                       limit=NUM_ROWS + 1)[1:]
        message = "\n".join(_msg(i, frame)
                            for i, frame in enumerate(reversed(entries)))

        log.fatal("""\
{err} ({type})

  stacktrace:
{stack_trace}

Please report bugs at <https://github.com/ChrisCummins/clgen/issues>\
""".format(err=exception, type=type(exception).__name__,
           stack_trace=message))

    # if DEBUG var set, don't catch exceptions
    if os.environ.get("DEBUG", None):
        # verbose stack traces. see: https://pymotw.com/2/cgitb/
        import cgitb
        cgitb.enable(format='text')

        return method(*args, **kwargs)

    try:
        if prof.is_enabled() and log.is_verbose():
            # Profile.runcall() returns the method's result; the previous
            # cProfile.runctx() call always returned None. A SystemExit from
            # the method now propagates, as it does when not profiling.
            profiler = cProfile.Profile()
            try:
                return profiler.runcall(method, *args, **kwargs)
            finally:
                profiler.print_stats(sort='tottime')
        return method(*args, **kwargs)
    except clgen.UserError as err:
        log.fatal(err, "(" + type(err).__name__ + ")")
    except KeyboardInterrupt:
        sys.stdout.flush()
        sys.stderr.flush()
        print("\nkeyboard interrupt, terminating", file=sys.stderr)
        sys.exit(1)
    except clgen.File404 as e:
        _user_message(e)
    except Exception as e:
        _user_message_with_stacktrace(e)