Code example #1
File: sampler.py Project: chubbymaggie/clgen
    def __init__(self, sampler_opts: dict, kernel_opts: dict):
        """
        Instantiate a sampler.

        Arguments:
            sampler_opts (dict): Sampler options.
            kernel_opts (dict): Kernel options.
        """
        assert (type(sampler_opts) is dict)
        assert (type(kernel_opts) is dict)

        self.hash = self._hash(sampler_opts, kernel_opts)

        # parse sampler options
        self.max_kernels = sampler_opts.get("max_kernels", -1)
        self.batch_size = sampler_opts.get("batch_size", 1000)
        self.max_batches = sampler_opts.get("max_batches", -1)
        self.static_checker = sampler_opts.get("static_checker", True)
        self.dynamic_checker = sampler_opts.get("dynamic_checker", False)

        if self.dynamic_checker and not cfg.USE_OPENCL:
            log.warning("dynamic checking requested, but OpenCL not available")
            self.dynamic_checker = False

        self.kernel_opts = kernel_opts
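
A minimal usage sketch for the constructor above. It assumes the enclosing class is exposed as clgen.Sampler and can be instantiated directly with the two option dicts; any key left out falls back to the defaults parsed in __init__:

import clgen

sampler_opts = {
    "batch_size": 500,          # overrides the default of 1000
    "dynamic_checker": False,   # forced off if OpenCL is unavailable
}
kernel_opts = {
    "max_length": 10000,        # keys read later by the sampler's run()
    "temperature": 1.0,
    "seed": 204,
}

sampler = clgen.Sampler(sampler_opts, kernel_opts)  # assumed class name
print(sampler.hash)  # derived from both option dicts by _hash()
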
Code example #2
File: _preprocess.py Project: DhashS/clgen
def preprocess_inplace(paths: List[str],
                       max_num_workers: int = cpu_count(),
                       max_attempts: int = 100,
                       attempt: int = 1) -> None:
    """
    Preprocess a list of files in place.

    Parameters
    ----------
    paths : List[str]
        List of paths.
    max_num_workers : int, optional
        Number of processes to spawn.
    max_attempts : int, optional
        Maximum number of attempts to make if an OSError or TimeoutError is
        raised.
    attempt : int, optional
        The current attempt number; incremented internally on each retry.
    """
    if attempt > max_attempts:
        raise clgen.InternalError(
            f"Failed to process files after {max_attempts} attempts")
    elif attempt > 1:
        log.warning(f"preprocess attempt #{attempt}.")

    num_workers = min(len(paths), max_num_workers)

    try:
        log.info('spawned', num_workers, 'worker threads to process',
                 len(paths), 'files ...')
        with clgen.terminating(Pool(num_workers)) as pool:
            pool.map(_preprocess_inplace_worker, paths)
    except (OSError, TimeoutError) as e:
        log.error(e)

        # Try again with fewer threads.
        # See: https://github.com/ChrisCummins/clgen/issues/64
        max_num_workers = max(int(max_num_workers / 2), 1)
        preprocess_inplace(paths,
                           max_num_workers=max_num_workers,
                           attempt=attempt + 1,
                           max_attempts=max_attempts)
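
A short usage sketch for preprocess_inplace. The module path and the file layout here are assumptions, not taken from the clgen sources:

from glob import glob

from clgen._preprocess import preprocess_inplace

# Rewrite every OpenCL kernel under ./kernels in place, capping the worker
# count at four. Retries with fewer workers happen internally on failure.
paths = glob("kernels/**/*.cl", recursive=True)
preprocess_inplace(paths, max_num_workers=4)
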
Code example #3
File: cli.py Project: yasutakawada/clgen
        def _main() -> None:
            cache = clgen.cachepath()

            log.warning("Not Implemented: refresh corpuses")

            if fs.isdir(cache, "model"):
                cached_modeldirs = fs.ls(fs.path(cache, "model"), abspaths=True)
                for cached_modeldir in cached_modeldirs:
                    cached_model_id = fs.basename(cached_modeldir)
                    cached_meta = jsonutil.read_file(fs.path(cached_modeldir, "META"))

                    model = clgen.Model.from_json(cached_meta)

                    if cached_model_id != model.hash:
                        log.info(cached_model_id, '->', model.hash)

                        if fs.isdir(model.cache.path):
                            log.fatal("cache conflict", file=sys.stderr)

                        fs.mv(cached_modeldir, model.cache.path)

            log.warning("Not Implemented: refresh samplers")
Code example #4
File: _sampler.py Project: yasutakawada/clgen
    def run(self) -> None:
        model = self.model
        batch_size = self.model.corpus.batch_size
        max_length = self.kernel_opts["max_length"]
        temperature = self.kernel_opts["temperature"]

        if model.lock.islocked:  # model is locked during training
            raise lockfile.UnableToAcquireLockError(self.lock)

        tf = model._init_tensorflow(infer=True)

        # seed RNG
        np.random.seed(self.kernel_opts["seed"])
        tf.set_random_seed(self.kernel_opts["seed"])

        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.get_checkpoint_state(model.cache.path)

            assert (ckpt)
            assert (ckpt.model_checkpoint_path)

            saver.restore(sess, ckpt.model_checkpoint_path)

            def weighted_pick(weights, temperature):
                """
                requires that all probabilities are >= 0, i.e.:
                  assert all(x >= 0 for x in weights)
                See: https://github.com/ChrisCummins/clgen/issues/120
                """
                t = np.cumsum(weights)
                s = np.sum(weights)
                return int(np.searchsorted(t, np.random.rand(1) * s))

            def get_bracket_depth(text: str, depth: int = 0):
                """ calculate function block depth, returning (started, depth) """
                # FIXME(polyglot): support multiple sample termination criteria
                depth += text.count("{")
                started = depth > 0
                depth -= text.count("}")
                return started, depth

            init_started, init_depth = get_bracket_depth(self.start_text)
            atomize = model.corpus.atomizer.atomize
            deatomize = model.corpus.atomizer.deatomize

            while not self.stop_requested:
                buf = [StringIO() for _ in range(batch_size)]
                depth = [init_depth] * batch_size
                started = [init_started] * batch_size
                running = [True] * batch_size

                state = sess.run(model.cell.zero_state(batch_size, tf.float32))
                indices = np.zeros((batch_size, 1))

                seed_tensor = atomize(self.start_text)
                for symbol in seed_tensor[:-1]:
                    indices[:] = symbol
                    feed = {
                        model.input_data: indices,
                        model.initial_state: state
                    }
                    [state] = sess.run([model.final_state], feed)

                for item in range(batch_size):
                    buf[item].write(self.start_text)

                indices[:] = seed_tensor[-1]

                for _ in range(max_length):
                    feed = {
                        model.input_data: indices,
                        model.initial_state: state
                    }

                    try:
                        [probs,
                         state] = sess.run([model.probs, model.final_state],
                                           feed)
                    except TensorFlowInvalidArgumentError:
                        log.warning("sampling error")
                        self.run()  # restart sampling from scratch
                        return

                    # sample distribution to pick next symbol:
                    indices[:, 0] = [
                        weighted_pick(p, temperature) for p in probs
                    ]

                    for item in range(batch_size):
                        if not running[item]:
                            continue

                        # In case of decoding error, start sampling again:
                        try:
                            atom = deatomize([indices[item, 0]])
                        except clgen.VocabError:
                            log.warning("deatomizing error")
                            self.run()  # restart sampling from scratch
                            return

                        buf[item].write(atom)
                        # update function block depth
                        _started, depth[item] = get_bracket_depth(
                            atom, depth=depth[item])
                        started[item] |= _started  # you can't "unset" the started state
                        # determine whether to keep sampling:
                        _running = (not started[item]
                                    or (started[item] and depth[item] > 0))
                        running[item] = _running

                        # submit sample to processing queue
                        if not _running:
                            text = buf[item].getvalue()
                            self.queue.put(text)
                            if log.is_verbose():
                                sys.stdout.write(self.sample_header)
                                sys.stdout.write(text)
                                sys.stdout.flush()

                    # start a new batch if there's nothing left running
                    if not any(running):
                        break

            if log.is_verbose():
                sys.stdout.write('\n\n')
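
The distribution sampling inside run() can be illustrated in isolation. This standalone sketch reimplements the cumulative-sum trick used by weighted_pick above; the driver loop at the bottom is illustrative only and not part of clgen:

import numpy as np

def weighted_pick(weights):
    """Pick an index with probability proportional to its weight."""
    cumulative = np.cumsum(weights)   # e.g. [0.1, 0.4, 1.0]
    total = np.sum(weights)
    # Draw a uniform value in [0, total) and find the bin it falls into.
    return int(np.searchsorted(cumulative, np.random.rand(1) * total))

counts = [0, 0, 0]
for _ in range(10000):
    counts[weighted_pick([0.1, 0.3, 0.6])] += 1
print(counts)  # roughly in the ratio 1 : 3 : 6
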
Code example #5
File: _preprocess.py Project: DhashS/clgen
def _preprocess_db(db_path: str,
                   max_num_workers: int = cpu_count(),
                   max_attempts: int = 100,
                   attempt: int = 1,
                   **preprocess_opts) -> None:
    """
    Preprocess OpenCL dataset.

    Parameters
    ----------
    db_path : str
        OpenCL kernels dataset.
    max_num_workers : int, optional
        Number of processes to spawn.
    max_attempts : int, optional
        Maximum number of attempts to make if an OSError or TimeoutError is
        raised.
    attempt : int, optional
        The current attempt number; incremented internally on each retry.
    **preprocess_opts
        Options passed to each preprocessing job; the "lang" entry selects how
        kernels are fetched when a ContentMeta table is present.
    """
    if attempt > max_attempts:
        raise clgen.InternalError(
            f"failed to preprocess files after {max_attempts} attempts")

    log.verbose("determining jobs")

    contentfiles = set(dbutil.kernel_ids(db_path, "ContentFiles"))
    preprocessedfiles = set(dbutil.kernel_ids(db_path, "PreprocessedFiles"))

    ncontentfiles = len(contentfiles)
    npreprocessedfiles = len(preprocessedfiles)

    todo = contentfiles - preprocessedfiles
    ntodo = len(todo)

    # check we have something to do
    if not ntodo:
        return

    todo_ratio = ntodo / ncontentfiles

    log.info("{ntodo} ({todo_ratio:.1%}) samples need preprocessing".format(
        **vars()))

    log.verbose("creating jobs")

    # Determine if we need to inline kernels when creating jobs
    db = sqlite3.connect(db_path)
    c = db.cursor()
    c.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='ContentMeta';"
    )
    meta_table = c.fetchone()
    c.close()
    db.close()
    if meta_table:
        get_kernel = lambda kid: dbutil.get_inlined_kernel(
            db_path, kid, lang=preprocess_opts["lang"])
    else:
        get_kernel = lambda kid: dbutil.get_kernel(
            db_path, kid, table="ContentFiles")

    # create jobs
    jobs = [{
        "id": kid,
        "src": get_kernel(kid),
        "preprocess_opts": preprocess_opts,
    } for kid in todo]

    random.shuffle(jobs)

    # split size
    worker_njobs = math.ceil(ntodo / max_num_workers)

    # producer-consumer queue
    queue = Queue(maxsize=128)

    log.verbose(f"assigning {ntodo} jobs to {max_num_workers} threads")

    try:
        # our worker threads. these busy little bees will do the heavy lifting
        # of preprocessing the contentfiles, pushing their results onto
        # the queue
        producers = [
            PreprocessWorker(jobs[i:i + worker_njobs], queue)
            for i in range(0, ntodo, worker_njobs)
        ]

        # fly, my pretties, fly!
        for producer in producers:
            producer.start()

        # consume the results from the worker threads from the main thread
        for i in progressbar.ProgressBar()(range(ntodo)):
            # pull a fresh result from the queue (block if necessary)
            try:
                result = queue.get(timeout=90)
            except QueueEmpty as e:
                raise TimeoutError('failed to fetch result after 90 seconds. '
                                   'something went wrong') from e

            # insert result into database
            db = dbutil.connect(db_path)
            c = db.cursor()
            c.execute("INSERT INTO PreprocessedFiles VALUES(?,?,?)",
                      (result["id"], result["status"], result["contents"]))
            c.close()
            db.commit()
            db.close()

        for producer in producers:
            producer.join()

    except (OSError, TimeoutError) as e:
        log.error(e)

        if attempt > 2 and not i:
            log.warning("no progress has been made since previous attempt. "
                        "I'm not going to try another attempt.")
            return

        # Try again with fewer threads.
        # See: https://github.com/ChrisCummins/clgen/issues/64
        max_num_workers = max(int(max_num_workers / 2), 1)
        _preprocess_db(db_path,
                       max_num_workers=max_num_workers,
                       attempt=attempt + 1,
                       max_attempts=max_attempts,
                       **preprocess_opts)
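
The producer/consumer structure above can be reduced to a small standalone sketch: worker threads push results onto a bounded queue and the main thread drains it with a timeout. Every name below is illustrative, not clgen API:

from queue import Empty as QueueEmpty
from queue import Queue
from threading import Thread

def worker(jobs, queue):
    # stand-in for PreprocessWorker: process each job, push the result
    for job in jobs:
        queue.put({"id": job, "status": 0, "contents": f"processed {job}"})

jobs = list(range(10))
queue = Queue(maxsize=128)
producers = [Thread(target=worker, args=(jobs[i:i + 5], queue))
             for i in range(0, len(jobs), 5)]

for producer in producers:
    producer.start()

for _ in range(len(jobs)):
    try:
        result = queue.get(timeout=90)   # block, but give up eventually
    except QueueEmpty as e:
        raise TimeoutError("failed to fetch result after 90 seconds") from e
    print(result["id"], result["status"])

for producer in producers:
    producer.join()
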
Code example #6
File: _fetch.py Project: DhashS/clgen
def fetch_repos(db_path: Path, indir: Path, lang: clgen.Language) -> None:
    db = dbutil.connect(db_path)

    if not dbutil.is_github(db):
        raise clgen.UserError("not a GitHub database")

    c = db.cursor()

    for directory in fs.ls(indir, abspaths=True):
        # hacky hardcoded interpretation of `git remote -v`
        gitdir = fs.path(directory, ".git")
        output = subprocess.check_output(
            ["git", "--git-dir", gitdir, "remote", "-v"],
            universal_newlines=True)
        url = output.split("\n")[0].split("\t")[1].split(" ")[0]
        name = fs.basename(directory)

        output = subprocess.check_output(
            f"git --git-dir {gitdir} rev-list --format=format:'%ai' " +
            f"--max-count=1 $(git --git-dir {gitdir} rev-parse HEAD) | tail -n1",
            shell=True,
            universal_newlines=True)
        try:
            updated_at = dateutil.parser.parse(output)
        except ValueError:
            log.error(f"failed to process {name} {url}")
            continue

        c.execute("SELECT updated_at FROM Repositories WHERE url=?", (url, ))
        cached_updated_at = c.fetchone()

        # Skip repositories whose timestamp is unchanged (check currently disabled):
        # if cached_updated_at and cached_updated_at[0] >= updated_at:
        #     log.verbose(name, "already in database")
        #     continue

        c.execute("DELETE FROM Repositories WHERE url=?", (url, ))
        c.execute("INSERT INTO Repositories VALUES(?,?,?,?,?,?,?,?,?)",
                  (url, "<unknown>", name, 0, 0, 0, 0, updated_at, updated_at))

        name_str = " -o ".join(
            [f"-name '*{ext}'" for ext in clgen.file_extensions(lang)])
        output = subprocess.check_output(
            f"find {directory} -type f {name_str} | grep -v '.git/' || true",
            shell=True,
            universal_newlines=True)
        files = [x.strip() for x in output.split("\n") if x.strip()]

        # nothing to import
        if not len(files):
            # log.verbose("no files in", name)
            continue

        log.verbose("processing", len(files), "files in", name)
        for path in files:
            relpath = path[len(directory) + 1:]
            try:
                contents = inline_fs_headers(path, [], lang=lang)
                sha = crypto.sha1_str(contents)
                c.execute('INSERT OR IGNORE INTO ContentFiles VALUES(?,?)',
                          (sha, contents))
                c.execute(
                    "INSERT OR IGNORE INTO ContentMeta VALUES(?,?,?,?,?)",
                    (sha, relpath, url, sha, len(contents)))
            except UnicodeDecodeError:
                log.warning("non UTF-8 file", path)

        db.commit()
        c = db.cursor()
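
The "hacky" parsing of `git remote -v` near the top of the loop assumes output of the following shape. This standalone sketch shows what each split extracts; the output string is a made-up sample:

output = ("origin\thttps://github.com/ChrisCummins/clgen.git (fetch)\n"
          "origin\thttps://github.com/ChrisCummins/clgen.git (push)\n")

first_line = output.split("\n")[0]   # "origin\thttps://... (fetch)"
remote = first_line.split("\t")[1]   # "https://... (fetch)"
url = remote.split(" ")[0]           # "https://github.com/ChrisCummins/clgen.git"
print(url)
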
Code example #7
File: cli.py Project: yasutakawada/clgen
def _main() -> None:
    log.warning("not implemented")