Example #1
 def Import(self, session: sqlutil.Session, config: corpus_pb2.Corpus) -> None:
   with self.GetContentFileRoot(config) as contentfile_root:
     relpaths = set(self.GetImportRelpaths(contentfile_root))
     done = set(
         [x[0] for x in session.query(PreprocessedContentFile.input_relpath)])
     todo = relpaths - done
     app.Log(1, 'Preprocessing %s of %s content files',
             humanize.Commas(len(todo)), humanize.Commas(len(relpaths)))
     jobs = [
         internal_pb2.PreprocessorWorker(
             contentfile_root=str(contentfile_root),
             relpath=t,
             preprocessors=config.preprocessor) for t in todo
     ]
     pool = multiprocessing.Pool()
     bar = progressbar.ProgressBar(max_value=len(jobs))
     last_commit = time.time()
     wall_time_start = time.time()
     for preprocessed_cf in bar(pool.imap_unordered(PreprocessorWorker, jobs)):
       wall_time_end = time.time()
       preprocessed_cf.wall_time_ms = (int(
           (wall_time_end - wall_time_start) * 1000))
       wall_time_start = wall_time_end
       session.add(preprocessed_cf)
       if wall_time_end - last_commit > 10:
         session.commit()
         last_commit = wall_time_end
Example #2
  @contextlib.contextmanager
  def GetContentFileRoot(self, config: corpus_pb2.Corpus) -> pathlib.Path:
    """Get the path of the directory containing content files.

    If the corpus is a local directory, this simply returns the path. Otherwise,
    this method creates a temporary copy of the files which can be used within
    the scope of this context.

    Args:
      config: The corpus config proto.

    Returns:
      The path of a directory containing content files.
    """
    if config.HasField('local_directory'):
      yield pathlib.Path(ExpandConfigPath(config.local_directory))
    elif config.HasField('local_tar_archive'):
      with tempfile.TemporaryDirectory(prefix='clgen_corpus_') as d:
        start_time = time.time()
        cmd = [
            'tar', '-xf',
            str(ExpandConfigPath(config.local_tar_archive)), '-C', d
        ]
        subprocess.check_call(cmd)
        app.Log(1, 'Unpacked %s in %s ms',
                ExpandConfigPath(config.local_tar_archive).name,
                humanize.Commas(int((time.time() - start_time) * 1000)))
        yield pathlib.Path(d)
    else:
      raise NotImplementedError
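
A minimal usage sketch of GetContentFileRoot() as a context manager. Here `db` stands in for an instance of the class that defines the method, and the import path and archive path are assumptions for illustration only:

from deeplearning.clgen.proto import corpus_pb2  # Assumed import path.

config = corpus_pb2.Corpus(local_tar_archive='/path/to/corpus.tar.bz2')
with db.GetContentFileRoot(config) as content_root:
    # For a tar archive, content_root is a temporary unpacked copy that only
    # exists for the duration of this block.
    print(sorted(p.name for p in content_root.iterdir()))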
Example #3
    def GetLinterErrors(self, abspath: str, relpath: str) -> CacheLookupResult:
        """Looks up the given directory and returns cached results (if any)."""
        relpath_md5 = common.Md5String(relpath).digest()

        # Get a checksum of the directory, derived from the time of the
        # most-recently modified file in it.
        checksum = GetDirectoryChecksum(abspath).digest()

        ret = CacheLookupResult(exists=False,
                                checksum=checksum,
                                relpath=relpath,
                                relpath_md5=relpath_md5,
                                errors=[])

        directory = self.session \
          .query(Directory) \
          .filter(Directory.relpath_md5 == ret.relpath_md5) \
          .first()

        if directory and directory.checksum == ret.checksum:
            ret.exists = True
            ret.errors = self.session \
              .query(CachedError) \
              .filter(CachedError.dir == ret.relpath_md5)
        elif directory:
            app.Log(2, "Removing stale directory cache: `%s`", relpath)

            # Delete all existing cache entries.
            self.session.delete(directory)
            self.session \
              .query(CachedError) \
              .filter(CachedError.dir == ret.relpath_md5) \
              .delete()

        return ret
Example #4
def JavaRewrite(text: str) -> str:
    """Run the Java rewriter on the text.

  Args:
    text: The source code to rewrite.

  Returns:
    Source code with identifier names normalized.

  Raises:
    RewriterException: If the rewriter fails, or if it does not complete
      within the 60 second timeout.
  """
    cmd = ['timeout', '-s9', '60', str(JAVA_REWRITER)]
    process = subprocess.Popen(cmd,
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               universal_newlines=True)
    app.Log(2, '$ %s', ' '.join(cmd))
    stdout, stderr = process.communicate(text)
    if process.returncode == 9:
        raise errors.RewriterException(
            'JavaRewriter failed to complete after 60s')
    elif process.returncode:
        raise errors.RewriterException(stderr)
    return stdout.strip() + '\n'
Example #5
def Javac(text: str,
          class_name: str,
          cflags: typing.List[str],
          timeout_seconds: int = 60) -> str:
    """Run code through javac.

  Args:
    text: The code to compile.
    class_name: The name of the class defined in the file.
    cflags: Additional options passed to javac.
    timeout_seconds: The number of seconds to wait before killing javac.

  Returns:
    The unmodified input code.

  Raises:
    BadCodeException: If compilation fails, or if javac does not complete
      within timeout_seconds.
  """
    with tempfile.TemporaryDirectory('w', prefix='clgen_javac_') as d:
        path = pathlib.Path(d) / (class_name + '.java')
        with open(path, 'w') as f:
            f.write(text)
        cmd = ['timeout', '-s9',
               str(timeout_seconds), 'javac', f.name] + cflags
        app.Log(2, '$ %s', ' '.join(cmd))
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   universal_newlines=True)
        stdout, stderr = process.communicate()
    if process.returncode == 9:
        raise errors.BadCodeException(
            f'Javac timed out after {timeout_seconds}s')
    elif process.returncode != 0:
        raise errors.BadCodeException(stderr)
    return text
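
A hypothetical invocation of Javac(). It assumes `javac` and GNU `timeout` are available on PATH; the source string and flags are placeholders:

source = 'public class HelloWorld { public static void main(String[] args) {} }'
# On success the input is returned unmodified; a compile error or timeout
# raises BadCodeException.
assert Javac(source, 'HelloWorld', cflags=['-nowarn']) == source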
Example #6
def LlvmBytecodeIterator(
  base_path: pathlib.Path, source_name: str
) -> typing.Iterable[ml4pl_pb2.LlvmBytecode]:
  """Extract LLVM bytecodes from contentfiles.

  Args:
    base_path: The root directory containing IR codes.
    source_name: The name of the source which is attributed to bytecodes.

  Returns:
    An iterator of LlvmBytecode protos.
  """
  for entry in base_path.iterdir():
    if entry.is_dir() and not entry.name.endswith("_preprocessed"):
      for path in entry.iterdir():
        if path.name.endswith(".ll"):
          relpath = os.path.relpath(path, base_path)
          app.Log(1, "Read %s:%s", source_name, relpath)
          yield ml4pl_pb2.LlvmBytecode(
            source_name=source_name,
            relpath=relpath,
            lang="cpp",
            cflags="",
            bytecode=fs.Read(path),
            clang_returncode=0,
            error_message="",
          )
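
An illustrative loop over LlvmBytecodeIterator(); the root directory and source name below are placeholders:

import pathlib

for bytecode in LlvmBytecodeIterator(pathlib.Path('/path/to/ir_root'), 'my_source'):
    # Each yielded proto carries the .ll file's path relative to the root and
    # its bytecode text.
    print(bytecode.relpath, len(bytecode.bytecode))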
Example #7
def ExtractJavaMethods(text: str, static_only: bool = True) -> typing.List[str]:
  """Extract Java methods from a file.

  Args:
    text: The text of the target file.
    static_only: If true, only static methods are returned.

  Returns:
    A list of method implementations.

  Raises:
    ValueError: In case method extraction fails.
  """
  app.Log(2, '$ %s', JAVA_METHODS_EXTRACTOR)
  process = subprocess.Popen([JAVA_METHODS_EXTRACTOR],
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True,
                             env=STATIC_ONLY_ENV if static_only else None)
  stdout, stderr = process.communicate(text)
  if process.returncode:
    raise ValueError("JavaMethodsExtractor exited with non-zero "
                     f"status {process.returncode}")
  methods_list = pbutil.FromString(stdout, scrape_repos_pb2.ListOfStrings())
  return list(methods_list.string)
Example #8
  def GetDependentFiles(
      self,
      target: str,
      excluded_targets: typing.Iterable[str],
  ) -> typing.List[pathlib.Path]:
    """Get the file dependencies of the target.

    Args:
      target: The target to get the dependencies of.
      excluded_targets: Targets to exclude from the dependency set.

    Returns:
      A list of paths, relative to the root of the workspace.

    Raises:
      OSError: If bazel query fails.
    """
    # First run through bazel query to expand globs.
    bazel = self.BazelQuery([target], stdout=subprocess.PIPE)
    grep = subprocess.Popen(
        ['grep', '^/'],
        stdout=subprocess.PIPE,
        stdin=bazel.stdout,
        universal_newlines=True,
    )

    stdout, _ = grep.communicate()
    if bazel.returncode:
      raise OSError('bazel query failed')
    if grep.returncode:
      raise OSError('grep of bazel query output failed')
    targets = stdout.rstrip().split('\n')

    # Now get the transitive dependencies of each target.
    targets = [target for target in targets if target not in excluded_targets]
    all_targets = targets.copy()
    for i, target in enumerate(targets):
      app.Log(1, 'Collecting transitive deps for target %d of %d: %s', i + 1,
              len(targets), target)
      bazel = self.BazelQuery([f'deps({target})'], stdout=subprocess.PIPE)
      grep = subprocess.Popen(
          ['grep', '^/'],
          stdout=subprocess.PIPE,
          stdin=bazel.stdout,
          universal_newlines=True,
      )

      stdout, _ = grep.communicate()
      if bazel.returncode:
        raise OSError('bazel query failed')
      if grep.returncode:
        raise OSError('grep of bazel query output failed')

      deps = stdout.rstrip().split('\n')
      all_targets += [
          target for target in deps if target not in excluded_targets
      ]

    paths = [self.MaybeTargetToPath(target) for target in all_targets]
    return [path for path in paths if path]
Example #9
def CheckCallOrDie(cmd: typing.List[str]) -> None:
  """Run the given command and exit fatally on error."""
  try:
    app.Log(2, '$ %s', ' '.join(cmd))
    subprocess.check_call(cmd)
  except subprocess.CalledProcessError as e:
    app.FatalWithoutStackTrace("Command: `%s` failed with error: %s",
                               ' '.join(cmd), e)
Example #10
def GreedyAtomizerFromEncodedDb(encoded_db: encoded.EncodedContentFiles):
  """Create a greedy atomizer for the vocabulary of a given encoded_db."""
  # TODO: This depends on the embedded "meta" table vocabulary from:
  # //experimental/deeplearning/deepsmith/java_fuzz/encode_java_corpus.py
  with encoded_db.Session() as s:
    vocab = GetVocabFromMetaTable(s)
  app.Log(1, 'Loaded vocabulary of %s tokens from meta table', len(vocab))
  return atomizers.GreedyAtomizer(vocab)
Example #11
    def ExportToRepo(self,
                     repo: git.Repo,
                     targets: typing.List[str],
                     src_files: typing.List[str],
                     extra_files: typing.List[str],
                     file_move_mapping: typing.Dict[str, str],
                     resume_export: bool = True) -> typing.Optional[int]:
        """Export the requested targets to the destination directory."""
        # The timestamp for the export.
        timestamp = datetime.datetime.utcnow()

        # Export the git history.
        app.Log(1, 'Exporting git history for %s files',
                humanize.Commas(len(src_files)))
        for file in src_files:
            print(file)

        exported_commit_count = source_tree.ExportGitHistoryForFiles(
            source=self.git_repo,
            destination=repo,
            files_of_interest=src_files,
            resume_export=resume_export)
        if not exported_commit_count:
            return

        # Make manual adjustments.
        exported_workspace = bazelutil.Workspace(
            pathlib.Path(repo.working_tree_dir))
        self.CreatePythonRequirementsFileForTargets(exported_workspace,
                                                    targets)
        self.CopyFilesToDestination(exported_workspace, extra_files)
        self.MoveFilesToDestination(exported_workspace, file_move_mapping)

        if not repo.is_dirty(untracked_files=True):
            return exported_commit_count

        app.Log(1, 'Creating automated subtree export commit')
        repo.git.add('.')
        author = git.Actor(name='[Git export bot]', email='/dev/null')
        repo.index.commit(
            f'Automated subtree export at {timestamp.isoformat()}',
            author=author,
            committer=author,
            skip_hooks=True)
        return exported_commit_count
Example #12
@contextlib.contextmanager
def CoverageContext(
    file_path: str,
    pytest_args: typing.List[str],
) -> typing.List[str]:
    """Wrap the pytest args with coverage.py options for the module under test."""
    # Record coverage of module under test.
    module = GuessModuleUnderTest(file_path)
    if not module:
        app.Log(1, 'Coverage disabled - no module under test')
        yield pytest_args
        return

    with tempfile.TemporaryDirectory(prefix='phd_test_') as d:
        # If a coverage data directory was requested, use it. Otherwise write
        # the coverage data to the temporary directory.
        if FLAGS.test_coverage_data_dir:
            datadir = pathlib.Path(FLAGS.test_coverage_data_dir)
            datadir.mkdir(parents=True, exist_ok=True)
        else:
            datadir = pathlib.Path(d)
        # Create a coverage.py config file.
        # See: https://coverage.readthedocs.io/en/coverage-4.3.4/config.html
        config_path = f'{d}/coveragerc'
        with open(config_path, 'w') as f:
            f.write(f"""\
[run]
data_file = {datadir}/.coverage
parallel = True
# disable_warnings =
#   module-not-imported
#   no-data-collected
#   module-not-measured

[report]
ignore_errors = True
# Regexes for lines to exclude from consideration
exclude_lines =
  # Have to re-enable the standard pragma
  pragma: no cover

  # Don't complain about missing debug-only code:
  def __repr__
  if self\.debug

  # Don't complain if tests don't hit defensive assertion code:
  raise AssertionError
  raise NotImplementedError

  # Don't complain if non-runnable code isn't run:
  if 0:
  if __name__ == .__main__.:
""")

        pytest_args += [
            f'--cov={module}',
            f'--cov-config={config_path}',
        ]
        yield pytest_args
Example #13
  def GetInferenceModel(self) -> 'keras.models.Sequential':
    """Like training model, but with different batch size."""
    if self._inference_model:
      return self._inference_model

    # Deferred importing of Keras so that we don't have to activate the
    # TensorFlow backend every time we import this module.
    import keras

    app.Log(1, 'Building inference model.')
    model = self.GetTrainingModel()
    config = model.get_config()
    app.Log(1, 'Sampling with batch size %d', sampler.batch_size)
    config[0]['config']['batch_input_shape'] = (sampler.batch_size, 1)
    inference_model = keras.models.Sequential.from_config(config)
    inference_model.trainable = False
    inference_model.set_weights(model.get_weights())
    self._inference_model = inference_model
    self._inference_batch_size = sampler.batch_size
    return inference_model
Example #14
  def Clear(self):
    """Empty the cache.

    If the HashCache was created with keep_in_memory=True, this clears the
    in-memory cache. Note that the in-memory cache is shared between all
    instances of HashCache.
    """
    IN_MEMORY_CACHE.clear()
    with self.Session(commit=True) as session:
      session.query(HashCacheRecord).delete()
    app.Log(2, 'Emptied cache')
Example #15
 def _DoHash(self, absolute_path: pathlib.Path, last_modified: int,
             hash_fn: typing.Callable[[pathlib.Path], str]) -> str:
   with self.Session() as session:
     cached_entry = session.query(HashCacheRecord).filter(
         HashCacheRecord.absolute_path == str(absolute_path)).first()
     if cached_entry and cached_entry.last_modified == last_modified:
       app.Log(2, "Cache hit: '%s'", absolute_path)
       return cached_entry.hash
     elif cached_entry:
       app.Log(2, "Cache miss: '%s'", absolute_path)
       session.delete(cached_entry)
     start_time = time.time()
     checksum = hash_fn(absolute_path)
     app.Log(2, "New cache entry '%s' in %s ms.", absolute_path,
             humanize.Commas(int((time.time() - start_time) * 1000)))
     new_entry = HashCacheRecord(absolute_path=str(absolute_path),
                                 last_modified=last_modified,
                                 hash=checksum)
     session.add(new_entry)
     session.commit()
     return new_entry.hash
Example #16
def Preprocess(contentfiles: pathlib.Path, outdir: pathlib.Path,
               preprocessor_names):
    # Error early if preprocessors are bad.
    [preprocessors.GetPreprocessorFunction(f) for f in preprocessor_names]

    # This is basically the same code as:
    # deeplearning.clgen.corpuses.preprocessed.PreprocessedContentFiles:Import()
    # Only it's writing the results of preprocessing to files rather than to a
    # database. Consider refactoring.
    relpaths = {f.name for f in contentfiles.iterdir()}
    done = {f.name for f in outdir.iterdir()}
    todo = relpaths - done
    app.Log(1, 'Preprocessing %s of %s content files',
            humanize.Commas(len(todo)), humanize.Commas(len(relpaths)))
    jobs = [
        internal_pb2.PreprocessorWorker(contentfile_root=str(contentfiles),
                                        relpath=t,
                                        preprocessors=preprocessor_names)
        for t in todo
    ]
    pool = multiprocessing.Pool()
    bar = progressbar.ProgressBar(max_value=len(jobs))
    wall_time_start = time.time()
    workers = pool.imap_unordered(preprocessed.PreprocessorWorker, jobs)
    succeeded_count = 0
    for preprocessed_cf in bar(workers):
        wall_time_end = time.time()
        preprocessed_cf.wall_time_ms = (int(
            (wall_time_end - wall_time_start) * 1000))
        wall_time_start = wall_time_end
        if preprocessed_cf.preprocessing_succeeded:
            succeeded_count += 1
            with open(outdir / preprocessed_cf.input_relpath, 'w') as f:
                f.write(preprocessed_cf.text)

    app.Log(1, "Successfully preprocessed %s of %s files (%.2f %%)",
            humanize.Commas(succeeded_count), humanize.Commas(len(todo)),
            (succeeded_count / max(len(todo), 1)) * 100)
Example #17
    def _GetIgnoredNames(self, abspath: str) -> typing.Set[str]:
        """Get the set of file names within a directory to ignore."""
        ignore_file_names = set()

        ignore_file = os.path.join(abspath, common.IGNORE_FILE_NAME)
        if os.path.isfile(ignore_file):
            app.Log(2, 'Reading ignore file %s', ignore_file)
            with open(ignore_file) as f:
                for line in f:
                    line = line.split('#')[0].strip()
                    if line:
                        ignore_file_names.add(line)

        return ignore_file_names
Example #18
@contextlib.contextmanager
def Profile(name: str = '',
            print_to: typing.Callable[[str],
                                      None] = lambda msg: app.Log(1, msg)):
    """A context manager which prints the elapsed time upon exit.

  Args:
    name: The name of the task being profiled.
    print_to: The function to print the result to.
  """
    name = name or 'completed'
    start_time = time.time()
    yield
    elapsed = time.time() - start_time
    print_to(f"{name} in {humanize.Duration(elapsed)}")
Example #19
 def _InMemoryWrapper(self, absolute_path: pathlib.Path,
                      last_modified_fn: typing.Callable[[pathlib.Path], int],
                      hash_fn: typing.Callable[[pathlib.Path], str]) -> str:
   """A wrapper around the persistent hashing to support in-memory cache."""
   if self.keep_in_memory:
     in_memory_key = InMemoryCacheKey(self.hash_fn_name, absolute_path)
     if in_memory_key in IN_MEMORY_CACHE:
       app.Log(2, "In-memory cache hit: '%s'", absolute_path)
       return IN_MEMORY_CACHE[in_memory_key]
   hash_ = self._DoHash(absolute_path, last_modified_fn(absolute_path),
                        hash_fn)
   if self.keep_in_memory:
     IN_MEMORY_CACHE[in_memory_key] = hash_
   return hash_
Example #20
def ExportCommitsThatTouchFiles(commits_in_order: typing.List[git.Commit],
                                destination: git.Repo,
                                files_of_interest: typing.Set[str]) -> int:
  """Filter and apply the commits that touch the given files of interest.

  The commits are applied in the order provided.
  """
  exported_commit_count = 0
  total_commit_count = humanize.Commas(len(commits_in_order))
  for i, commit in enumerate(commits_in_order):
    app.Log(1, 'Processing commit %s of %s (%.2f%%) %s', humanize.Commas(i + 1),
            total_commit_count, ((i + 1) / len(commits_in_order)) * 100, commit)
    if MaybeExportCommitSubset(commit, destination, files_of_interest):
      exported_commit_count += 1
  return exported_commit_count
Example #21
def RunPytestOnFileAndExit(file_path: str, argv: typing.List[str]):
    """Run pytest on a file and exit.

  This is invoked by absl.app.RunWithArgs(), and has access to absl flags.

  This function does not return.

  Args:
    file_path: The path of the file to test.
    argv: Positional arguments not parsed by absl. No additional arguments are
      supported.
  """
    if len(argv) > 1:
        raise app.UsageError("Unknown arguments: '{}'.".format(' '.join(
            argv[1:])))

    # Test files must end with _test.py suffix. This is a code style choice, not
    # a hard requirement.
    if not file_path.endswith('_test.py'):
        app.Fatal('File `%s` does not end in suffix _test.py', file_path)

    # Assemble the arguments to run pytest with. Note that the //:conftest file
    # performs some additional configuration not captured here.
    pytest_args = [
        file_path,
        # Run pytest verbosely.
        '-vv',
        '-p',
        'no:cacheprovider',
    ]

    if FLAGS.test_color:
        pytest_args.append('--color=yes')

    if FLAGS.test_maxfail != 0:
        pytest_args.append(f'--maxfail={FLAGS.test_maxfail}')

    # Print the slowest test durations at the end of execution.
    if FLAGS.test_print_durations:
        pytest_args.append(f'--durations={FLAGS.test_durations}')

    # Capture stdout and stderr by default.
    if not FLAGS.test_capture_output:
        pytest_args.append('-s')

    with CoverageContext(file_path, pytest_args) as pytest_args:
        app.Log(1, 'Running pytest with arguments: %s', pytest_args)
        ret = pytest.main(pytest_args)
    sys.exit(ret)
Example #22
    def __init__(self, workspace_: workspace.Workspace,
                 toplevel_dir_relpath: str, dirlinters: typing.List[DirLinter],
                 filelinters: typing.List[FileLinter], timers: Timers):
        super(ToplevelLinter, self).__init__(workspace_)
        self.toplevel_dir = self.workspace.workspace_root / toplevel_dir_relpath
        self.dirlinters = GetLinters(dirlinters, self.workspace)
        self.filelinters = GetLinters(filelinters, self.workspace)
        self.errors_cache = lintercache.LinterCache(self.workspace)
        self.xmp_cache = xmp_cache.XmpCache(self.workspace)
        self.timers = timers

        linter_names = list(
            type(lin).__name__ for lin in self.dirlinters + self.filelinters)
        app.Log(2, "Running //%s linters: %s", self.toplevel_dir,
                ", ".join(linter_names))
Example #23
def ExportGitHistoryForFiles(source: git.Repo,
                             destination: git.Repo,
                             files_of_interest: typing.Set[str],
                             head_ref: str = 'HEAD',
                             resume_export: bool = True) -> int:
    """Apply the parts of the git history from the given source repo """
    if destination.is_dirty():
        raise OSError("Repo `{destination.working_tree_dir}` is dirty")

    with TemporaryGitRemote(destination, source.working_tree_dir) as remote:
        destination.remote(remote).fetch()
        tail = None
        if resume_export:
            tail = MaybeGetHexShaOfLastExportedCommit(destination)
        commits_in_order = GetCommitsInOrder(source,
                                             head_ref=head_ref,
                                             tail_ref=tail)
        if not commits_in_order:
            app.Log(1, 'Nothing to export!')
            return 0
        app.Log(1, 'Exporting history from %s commits',
                humanize.Commas(len(commits_in_order)))
        return ExportCommitsThatTouchFiles(commits_in_order, destination,
                                           files_of_interest)
Example #24
 def Bazel(self,
           command: str,
           args: typing.List[str],
           timeout_seconds: int = 360,
           **subprocess_kwargs):
   cmd = [
       'timeout',
       '-s9',
       str(timeout_seconds),
       'bazel',
       command,
       '--noshow_progress',
   ] + args
   app.Log(2, '$ %s', ' '.join(cmd))
   with fs.chdir(self.workspace_root):
     return subprocess.Popen(cmd, **subprocess_kwargs)
Example #25
    def Import(self, session: sqlutil.Session,
               preprocessed_db: preprocessed.PreprocessedContentFiles,
               atomizer: atomizers.AtomizerBase,
               contentfile_separator: str) -> None:
        with preprocessed_db.Session() as p_session:
            query = p_session.query(
                preprocessed.PreprocessedContentFile).filter(
                    preprocessed.PreprocessedContentFile.
                    preprocessing_succeeded == True,
                    ~preprocessed.PreprocessedContentFile.id.in_(
                        session.query(EncodedContentFile.id).all()))
            jobs = [
                internal_pb2.EncoderWorker(
                    id=x.id,
                    text=x.text,
                    contentfile_separator=contentfile_separator,
                    pickled_atomizer=pickle.dumps(atomizer)) for x in query
            ]
            if not jobs:
                raise errors.EmptyCorpusException(
                    "Pre-processed corpus contains no files: "
                    f"'{preprocessed_db.url}'")

            app.Log(
                1, 'Encoding %s of %s preprocessed files',
                humanize.Commas(query.count()),
                humanize.Commas(
                    p_session.query(
                        preprocessed.PreprocessedContentFile).filter(
                            preprocessed.PreprocessedContentFile.
                            preprocessing_succeeded == True).count()))
            pool = multiprocessing.Pool()
            bar = progressbar.ProgressBar(max_value=len(jobs))
            last_commit = time.time()
            wall_time_start = time.time()
            for encoded_cf in bar(pool.imap_unordered(EncoderWorker, jobs)):
                wall_time_end = time.time()
                # TODO(cec): Remove the if check once EncoderWorker no longer returns
                # None on atomizer encode error.
                if encoded_cf:
                    encoded_cf.wall_time_ms = int(
                        (wall_time_end - wall_time_start) * 1000)
                    session.add(encoded_cf)
                wall_time_start = wall_time_end
                if wall_time_end - last_commit > 10:
                    session.commit()
                    last_commit = wall_time_end
Example #26
    @classmethod
    def FindWorkspace(cls, path: pathlib.Path) -> 'Workspace':
        """Look for and return a workspace at or above the current path.

    Args:
      path: The path to a file or directory to start the search from.

    Returns:
      The Workspace rooted at the first directory at or above `path` that
      contains a WORKSPACE file.

    Raises:
      FileNotFoundError: If no workspace is found by the time the nearest mount
        point is reached.
    """
        if (path / 'WORKSPACE').is_file():
            app.Log(2, 'Found workspace: `%s`', path)
            return Workspace(path)
        elif path.is_mount():
            raise FileNotFoundError("Workspace not found")
        else:
            return cls.FindWorkspace(path.parent)
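
A usage sketch, assuming the enclosing class is importable as Workspace and exposes the workspace_root attribute used in Example #22:

import pathlib

# Walk up from the current directory until a WORKSPACE file is found; a
# FileNotFoundError is raised if a mount point is reached first.
workspace_ = Workspace.FindWorkspace(pathlib.Path.cwd())
print(workspace_.workspace_root)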
Example #27
def GetTrainingCorpus(
    corpus: 'corpuses.Corpus', training_opts: model_pb2.TrainingOptions
) -> typing.Tuple[np.ndarray, np.ndarray, int]:
    """Get the corpus to train over.

  Args:
    corpus: A Corpus instance.
    training_opts: A TrainingOptions proto.

  Returns:
    An X, y pair of data for an epoch, and the number of steps in the epoch.

  Raises:
    UserError: If batch_size and sequence_length are too large for the corpus,
      yielding no batches.
  """
    start_time = time.time()
    encoded_corpus = corpus.GetTrainingData(
        shuffle=training_opts.shuffle_corpus_contentfiles_between_epochs)
    corpus_length = len(encoded_corpus)
    steps_per_epoch = (corpus_length - 1) // (training_opts.batch_size *
                                              training_opts.sequence_length)
    if not steps_per_epoch:
        raise errors.UserError(
            f'Requested batch size ({training_opts.batch_size}) and '
            f'sequence length ({training_opts.sequence_length}) are too large for '
            f'corpus of size {corpus_length}.')

    clipped_corpus_length = (steps_per_epoch * training_opts.batch_size *
                             training_opts.sequence_length)

    x = np.reshape(encoded_corpus[:clipped_corpus_length], [
        training_opts.batch_size,
        steps_per_epoch * training_opts.sequence_length
    ])
    y = np.reshape(encoded_corpus[1:clipped_corpus_length + 1], [
        training_opts.batch_size,
        steps_per_epoch * training_opts.sequence_length
    ])

    app.Log(1,
            'Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.',
            humanize.Commas(clipped_corpus_length),
            humanize.Commas(corpus_length - clipped_corpus_length),
            humanize.Commas(int((time.time() - start_time) * 1000)))
    return x, y, steps_per_epoch
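
A small worked sketch of the reshaping arithmetic above, using a toy token array in place of a real encoded corpus:

import numpy as np

batch_size, sequence_length = 2, 3
encoded_corpus = np.arange(20)  # 20 toy tokens.
steps_per_epoch = (len(encoded_corpus) - 1) // (batch_size * sequence_length)  # -> 3
clipped = steps_per_epoch * batch_size * sequence_length  # -> 18; the last 2 tokens are clipped.
x = np.reshape(encoded_corpus[:clipped], [batch_size, steps_per_epoch * sequence_length])
y = np.reshape(encoded_corpus[1:clipped + 1], [batch_size, steps_per_epoch * sequence_length])
# x.shape == y.shape == (2, 9); y holds the next-token targets for x.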
Example #28
def main(argv):  # pylint: disable=missing-docstring
    paths_to_lint = [pathlib.Path(arg) for arg in argv[1:]]
    if not paths_to_lint:
        raise app.UsageError("Usage: photolint <directory...>")

    for path in paths_to_lint:
        if not path.exists():
            app.FatalWithoutStackTrace(
                f"File or directory not found: '{path}'")

    # Linting is on a per-directory level, not per-file.
    directories_to_lint = {
        path if path.is_dir() else path.parent
        for path in paths_to_lint
    }

    for directory in sorted(directories_to_lint):
        directory = directory.absolute()
        app.Log(2, 'Linting directory `%s`', directory)
        workspace_ = workspace.Workspace.FindWorkspace(directory)
        linters.Lint(workspace_, directory)

    # Print the carriage return once we've done updating the counts line.
    if FLAGS.counts:
        if linters.ERROR_COUNTS:
            print("", file=sys.stderr)
    else:
        linters.PrintErrorCounts(end="\n")

    # Print the profiling timers once we're done.
    if FLAGS.profile:
        total_time = linters.TIMERS.total_seconds
        linting_time = linters.TIMERS.linting_seconds
        cached_time = linters.TIMERS.cached_seconds
        overhead = total_time - linting_time - cached_time

        print(
            f'timings: linting={humanize.Duration(linting_time)} '
            f'({linting_time / total_time:.1%}), '
            f'cached={humanize.Duration(cached_time)} '
            f'({cached_time / total_time:.1%}), '
            f'overhead={humanize.Duration(overhead)} '
            f'({overhead / total_time:.1%}), '
            f'total={humanize.Duration(total_time)}.',
            file=sys.stderr)
Example #29
    def RefreshLintersVersion(self):
        """Check that """
        meta_key = "version"

        cached_version = self.session.query(Meta) \
          .filter(Meta.key == meta_key) \
          .first()
        cached_version_str = (cached_version.value if cached_version else "")

        actual_version = Meta(key=meta_key, value=build_info.Version())

        if cached_version_str != actual_version.value:
            app.Log(1, "Version has changed, emptying cache ...")
            self.Empty(commit=False)
            if cached_version:
                self.session.delete(cached_version)
            self.session.add(actual_version)
            self.session.commit()
Example #30
def ExportGitHistoryForFiles(source: git.Repo,
                             destination: git.Repo,
                             files_of_interest: typing.Set[str],
                             head_ref: str = 'HEAD') -> int:
  """Apply the parts of the git history from the given source repo """
  if destination.is_dirty():
    raise OSError("Repo `{destination.working_tree_dir}` is dirty")

  with TemporaryGitRemote(destination, source.working_tree_dir) as remote:
    destination.remote(remote).fetch()
    tail = MaybeGetHexShaOfLastExportedCommit(destination)
    if tail:
      app.Log(1, 'Resuming export from commit `%s`', tail)
    commits_in_order = GetCommitsInOrder(source,
                                         head_ref=head_ref,
                                         tail_ref=tail)
    return ExportCommitsThatTouchFiles(commits_in_order, destination,
                                       files_of_interest)