Example #1
    def load():
        """
        Try to load a config file from known default locations.

        For the format of the config file, see 'config.sample'
        in the toplevel directory.
        """
        # FIXME: a hack to avoid circular dependencies.
        from utils import bicho_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join('/etc', 'bicho')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)

        # Then look at $HOME
        config_file = os.path.join(bicho_dot_dir(), 'config')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)
        else:
            # If there's an old file, migrate it
            old_config = os.path.join(os.environ.get('HOME'), '.bicho')
            if os.path.isfile(old_config):
                printout("Old config file found in %s, moving to %s",
                         (old_config, config_file))
                os.rename(old_config, config_file)
                Config.load_from_file(config_file)
Example #2
 def run_extension (self, name, extension, repo, uri, db):
     printout ("Executing extension %s", (name,))
     try:
         extension.run (repo, uri, db)
     except ExtensionRunError as e:
         printerr ("Error running extension %s: %s", (name, str (e)))
         return False
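The MininGit/Bicho snippets above call printout and printerr from the project's utils module with a format string and an argument tuple. A minimal sketch of what such helpers might look like (hypothetical stand-ins; the real utils implementation may differ):

import sys

# Hypothetical stand-ins for utils.printout/printerr as used in the examples above.
def printout(msg, args=None):
    # Apply %-formatting with the optional argument tuple and write to stdout.
    print(msg % args if args is not None else msg, file=sys.stdout)

def printerr(msg, args=None):
    # Same formatting, but write to stderr for error reporting.
    print(msg % args if args is not None else msg, file=sys.stderr)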
Example #3
    def run_extensions(self, repo, uri, db, backout=False):
        done = []
        list = sorted(self.exts)

        for name, extension in [(ext, self.exts[ext]()) for ext in list]:
            if name in done:
                continue
            done.append(name)

            result = True
            # Run dependencies first
            if not self.hard_order and not backout:
                for dep in extension.deps:
                    if dep in done:
                        continue

                    result = self.run_extension(dep, self.exts[dep](), repo,
                                                uri, db)
                    done.append(dep)
                    if not result:
                        break

            if not result:
                printout("Skipping extension %s since one or more of its " + \
                         "dependencies failed", (name,))
                continue

            if not backout:
                self.run_extension(name, extension, repo, uri, db)
            if backout:
                self.backout_extension(name, extension, repo, uri, db)
Example #4
    def run_extensions(self, repo, uri, db, backout=False):
        done = []
        list = sorted(self.exts)
           
        for name, extension in [(ext, self.exts[ext]()) for ext in list]:
            if name in done:
                continue
            done.append(name)

            result = True
            # Run dependencies first
            if not self.hard_order and not backout:
                for dep in extension.deps:
                    if dep in done:
                        continue
                    
                    result = self.run_extension(dep, self.exts[dep](),
                                                    repo, uri, db)
                    done.append(dep)
                    if not result:
                        break

            if not result:
                printout("Skipping extension %s since one or more of its " + \
                         "dependencies failed", (name,))
                continue
            
            if not backout: 
                self.run_extension(name, extension, repo, uri, db)
            if backout:
                self.backout_extension(name, extension, repo, uri, db)
Example #5
async def run(loop):

    async with aiohttp.ClientSession() as session:
        tasks = [loop.create_task(fetch(session, url)) for url in URLS]
        results = await asyncio.gather(*tasks)

    printout(results)
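Example #5 awaits a fetch coroutine over a URLS list, neither of which is shown. A minimal sketch under those assumptions (both names are placeholders, not from the source):

import asyncio
import aiohttp

# Hypothetical URL list and fetch coroutine assumed by the aiohttp example above.
URLS = ['https://example.com', 'https://example.org']

async def fetch(session, url):
    # Issue the GET request and return the response body as text.
    async with session.get(url) as response:
        return await response.text()

# Possible driver for run() as defined above:
# loop = asyncio.get_event_loop()
# loop.run_until_complete(run(loop))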
Example #6
def _parse_log(uri, repo, parser, reader, config, db):
    """Parse the log with the given parser, outputting to a database.

    Args:
      uri: The URI of the log to parse (this is already set in the parser)
      repo: The repositoryhandler repository object to query
      parser: The parser object that should be started
      reader: The log reader
      config: The Config object that specifies the current config
      db: The database to add the data to
    """
    # Start the parsing process
    printout("Parsing log for %s (%s)", (uri, repo.get_type()))

    def new_line(line, user_data):
        parser, writer = user_data

        parser.feed(line)
        writer and writer.add_line(line)

    writer = None
    if config.save_logfile is not None:
        writer = LogWriter(config.save_logfile)

    parser.set_content_handler(DBProxyContentHandler(db))
    reader.start(new_line, (parser, writer))
    parser.end()
    writer and writer.close()
Example #7
File: config.py Project: davidziman/Bicho
    def load():
        """
        Try to load a config file from known default locations.

        For the format of the config file, see 'config.sample'
        in the toplevel directory.
        """
        # FIXME: a hack to avoid circular dependencies.
        from utils import bicho_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join('/etc', 'bicho')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)

        # Then look at $HOME
        config_file = os.path.join(bicho_dot_dir(), 'config')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)
        else:
            # If there's an old file, migrate it
            old_config = os.path.join(os.environ.get('HOME'), '.bicho')
            if os.path.isfile(old_config):
                printout("Old config file found in %s, moving to %s",
                         (old_config, config_file))
                os.rename(old_config, config_file)
                Config.load_from_file(config_file)
Example #8
 def run_extension(self, name, extension, repo, uri, db):
     printout("Executing extension %s", (name,))
     try:
         extension.run(repo, uri, db)
     except ExtensionRunError as e:
         printerr("Error running extension %s: %s", (name, str(e)))
         return False
Example #9
def _parse_log(uri, repo, parser, reader, config, db):
    """Parse the log with the given parser, outputting to a database.

    Args:
      uri: The URI of the log to parse (this is already set in the parser)
      repo: The repositoryhandler repository object to query
      parser: The parser object that should be started
      reader: The log reader
      config: The Config object that specifies the current config
      db: The database to add the data to
    """
    # Start the parsing process
    printout("Parsing log for %s (%s)", (uri, repo.get_type()))

    def new_line(line, user_data):
        parser, writer = user_data

        parser.feed(line)
        writer and writer.add_line(line)

    writer = None
    if config.save_logfile is not None:
        writer = LogWriter(config.save_logfile)

    parser.set_content_handler(DBProxyContentHandler(db))
    reader.start(new_line, (parser, writer))
    parser.end()
    writer and writer.close()
Example #10
def main():
    results = []

    for url in URLS:
        page = fetch(url)
        results.append(page)

    printout(results)
Example #11
def main():
    results = []

    with requests.session() as session:
        for url in URLS:
            page = fetch(session, url)
            results.append(page)

    printout(results)
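Examples #10 and #11 rely on a blocking fetch helper that is not shown: the first variant takes only a URL, the second reuses a shared requests session. A sketch of the session-based variant (the helper and URLS are assumed, not taken from the source):

import requests

# Hypothetical blocking fetch used by the synchronous examples; URLS is assumed as well.
URLS = ['https://example.com', 'https://example.org']

def fetch(session, url):
    # Reuse the session's connection pool, fail on HTTP errors, return the body text.
    response = session.get(url)
    response.raise_for_status()
    return response.text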
Example #12
def main():
    loop = asyncio.get_event_loop()

    tasks = [loop.create_task(fetch(url)) for url in URLS]

    results = loop.run_until_complete(asyncio.gather(*tasks))
    printout(results)

    loop.close()
Example #13
    def repository(self, uri):
        cursor = self.cursor
        cursor.execute(statement("SELECT id from repositories where uri = ?", self.db.place_holder), (uri,))
        self.repo_id = cursor.fetchone()[0]

        last_rev = last_commit = None
        query = "SELECT rev, id from scmlog " + "where id = (select max(id) from scmlog where repository_id = ?)"
        cursor.execute(statement(query, self.db.place_holder), (self.repo_id,))
        rs = cursor.fetchone()
        if rs is not None:
            last_rev, last_commit = rs

        filename = uri.replace("/", "_")
        self.cache_file = os.path.join(cvsanaly_cache_dir(), filename)

        # if there's a previous cache file, just use it
        if os.path.isfile(self.cache_file):
            self.__load_caches_from_disk()

            if last_rev is not None:
                try:
                    commit_id = self.revision_cache[last_rev]
                except KeyError:
                    msg = (
                        "Cache file %s is not up to date or it's corrupt: " % (self.cache_file)
                        + "Revision %s was not found in the cache file" % (last_rev)
                        + "It's not possible to continue, the cache "
                        + "file should be removed and the database cleaned up"
                    )
                    raise CacheFileMismatch(msg)
                if commit_id != last_commit:
                    # Cache and db don't match, removing cache
                    msg = (
                        "Cache file %s is not up to date or it's corrupt: " % (self.cache_file)
                        + "Commit id mismatch for revision %s (File Cache:%d, Database: %d). "
                        % (last_rev, commit_id, last_commit)
                        + "It's not possible to continue, the cache "
                        + "file should be removed and the database cleaned up"
                    )
                    raise CacheFileMismatch(msg)
            else:
                # Database looks empty (or corrupt) and we have
                # a cache file. We can just remove it and continue
                # normally
                self.__init_caches()
                os.remove(self.cache_file)
                printout("Database looks empty, removing cache file %s", (self.cache_file,))
        elif last_rev is not None:
            # There are data in the database,
            # but we don't have a cache file!!!
            msg = (
                "Cache file %s is not up to date or it's corrupt: " % (self.cache_file)
                + "Cache file cannot be found"
                + "It's not possible to continue, the database "
                + "should be cleaned up"
            )
            raise CacheFileMismatch(msg)
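The repository() method above (and its reformatted twin in Example #14) raises a CacheFileMismatch exception defined elsewhere in the project. To run the snippet standalone, a minimal stand-in could be:

# Hypothetical minimal definition of the exception raised above; the project's own
# class may carry additional context about the cache file.
class CacheFileMismatch(Exception):
    pass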
Example #14
    def repository(self, uri):
        cursor = self.cursor
        cursor.execute(
            statement("SELECT id from repositories where uri = ?",
                      self.db.place_holder), (uri, ))
        self.repo_id = cursor.fetchone()[0]

        last_rev = last_commit = None
        query = "SELECT rev, id from scmlog " + \
                "where id = (select max(id) from scmlog where repository_id = ?)"
        cursor.execute(statement(query, self.db.place_holder),
                       (self.repo_id, ))
        rs = cursor.fetchone()
        if rs is not None:
            last_rev, last_commit = rs

        filename = uri.replace('/', '_')
        self.cache_file = os.path.join(cvsanaly_cache_dir(), filename)

        # if there's a previous cache file, just use it
        if os.path.isfile(self.cache_file):
            self.__load_caches_from_disk()

            if last_rev is not None:
                try:
                    commit_id = self.revision_cache[last_rev]
                except KeyError:
                    msg = "Cache file %s is not up to date or it's corrupt: " % (self.cache_file) + \
                          "Revision %s was not found in the cache file" % (last_rev) + \
                          "It's not possible to continue, the cache " + \
                          "file should be removed and the database cleaned up"
                    raise CacheFileMismatch(msg)
                if commit_id != last_commit:
                    # Cache and db don't match, removing cache
                    msg = "Cache file %s is not up to date or it's corrupt: " % (self.cache_file) + \
                          "Commit id mismatch for revision %s (File Cache:%d, Database: %d). " % (
                              last_rev, commit_id, last_commit) + \
                          "It's not possible to continue, the cache " + \
                          "file should be removed and the database cleaned up"
                    raise CacheFileMismatch(msg)
            else:
                # Database looks empty (or corrupt) and we have
                # a cache file. We can just remove it and continue
                # normally
                self.__init_caches()
                os.remove(self.cache_file)
                printout("Database looks empty, removing cache file %s",
                         (self.cache_file, ))
        elif last_rev is not None:
            # There are data in the database,
            # but we don't have a cache file!!!
            msg = "Cache file %s is not up to date or it's corrupt: " % (self.cache_file) + \
                  "Cache file cannot be found" + \
                  "It's not possible to continue, the database " + \
                  "should be cleaned up"
            raise CacheFileMismatch(msg)
Example #15
    def add_one_checkcollect_job(self, src_data_dir, dir1, dir2,
                                 recollect_record_name, tar_data_dir, x, y):
        if self.is_running:
            printout(self.flog,
                     'ERROR: cannot add a new job while DataGen is running!')
            exit(1)

        todo = ('CHECKCOLLECT', src_data_dir, recollect_record_name,
                tar_data_dir, np.random.randint(10000000), x, y, dir1, dir2)
        self.todos.append(todo)
Example #16
    def add_one_collect_job(self, data_dir, shape_id, category, cnt_id,
                            primact_type, trial_id):
        if self.is_running:
            printout(self.flog,
                     'ERROR: cannot add a new job while DataGen is running!')
            exit(1)

        todo = ('COLLECT', shape_id, category, cnt_id, primact_type, data_dir,
                trial_id, np.random.randint(10000000))
        self.todos.append(todo)
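Unlike the MininGit snippets, the DataGen examples (#15, #16 and, further down, #26/#27) pass a log file handle as the first argument to printout. A plausible sketch of that file-logging variant (an assumption; the project's utils.printout may differ):

# Hypothetical file-logging printout used by the DataGen and training snippets.
def printout(flog, strout):
    # Echo the message to the console and, when a log file handle is given, append it there too.
    print(strout)
    if flog is not None:
        flog.write(strout + '\n')
        flog.flush()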
Example #17
 def begin(self):
     statements = (
         ("tags", """DELETE FROM tags
                    WHERE id IN (SELECT tr.id 
                                 FROM tag_revisions tr, scmlog s
                                 WHERE tr.commit_id = s.id
                                 AND s.repository_id = ?)
                 """),
         ("tag_revisions", """DELETE FROM tag_revisions
                             WHERE commit_id IN (SELECT s.id 
                                                 FROM scmlog s
                                                 WHERE s.repository_id = ?)
                          """),
         ("file_copies", """DELETE FROM file_copies
                           WHERE action_id IN (SELECT a.id 
                                               FROM actions a, scmlog s
                                               WHERE a.commit_id = s.id
                                               AND s.repository_id = ?)
                        """),
         ("branches", """DELETE from branches
                        WHERE id IN (SELECT a.branch_id
                                     FROM actions a, scmlog s
                                     WHERE a.commit_id = s.id
                                     AND s.repository_id = ?)
                     """),
         ("actions", """DELETE FROM actions
                       WHERE commit_id IN (SELECT s.id
                                           FROM scmlog s
                                           WHERE s.repository_id = ?)
                    """),
         ("authors", """DELETE FROM people
                       WHERE id IN (SELECT s.author_id
                                    FROM scmlog s
                                    WHERE s.repository_id = ?)
                    """),
         ("committers", """DELETE FROM people
                          WHERE id IN (SELECT s.committer_id
                                       FROM scmlog s
                                       WHERE s.repository_id = ?)
                       """),
         ("file_links", """DELETE FROM file_links
                          WHERE commit_id IN (SELECT s.id
                                              FROM scmlog s
                                              WHERE s.repository_id = ?)
                       """),
         ("files", """DELETE FROM files WHERE repository_id = ?"""),
         ("commit log", """DELETE FROM scmlog WHERE repository_id = ?"""),
         ("repository", """DELETE FROM repositories WHERE id = ?""")
     )
     
     for (data_name, statement) in statements:
         printout("Deleting " + data_name)
         self.do_delete(statement)
     
     self.connection.commit()
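begin() above hands each DELETE to a do_delete helper that is not shown. Assuming every statement takes the repository id as its single placeholder, a sketch of that method might be:

# Hypothetical do_delete method on the same handler class as begin(); the real
# DBDeletionHandler may also convert the '?' placeholder to the driver-specific style.
def do_delete(self, query):
    cursor = self.connection.cursor()
    cursor.execute(query, (self.repo_id,))
    cursor.close()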
Example #18
    def run_extension(self, name, extension, repo, uri, db):
        # Trim off the ordering numeral before printing
        if self.hard_order:
            name = name[1:]

        printout("Executing extension %s", (name,))
        try:
            extension.run(repo, uri, db)
        except ExtensionRunError as e:
            printerr("Error running extension %s: %s", (name, str(e)))
            return False
Example #19
    def __append_message_line(self, line=None):
        if self.msg_lines <= 0:
            printout("Warning (%d): parsing svn log, unexpected line in message: %s", (self.n_line, line))
            self.msg_lines = 0
            return

        if line is not None:
            self.commit.message += line

        self.commit.message += '\n'
        self.msg_lines -= 1
Example #20
    def run_extension(self, name, extension, repo, uri, db):
        # Trim off the ordering numeral before printing
        if self.hard_order:
            name = name[1:]

        printout("Executing extension %s", (name, ))
        try:
            extension.run(repo, uri, db)
        except ExtensionRunError as e:
            printerr("Error running extension %s: %s", (name, str(e)))
            return False
Example #21
 def backout_extension(self, name, extension, repo, uri, db):
     # Trim off the ordering numeral before printing
     if self.hard_order:
         name = name[1:]
         
     printout("Backing out extension %s", (name,))
     
     try:
         extension.backout(repo, uri, db)
     except (ExtensionRunError, ExtensionBackoutError) as e:
         printerr("Error backing out extension %s: %s", (name, str(e)))
         return False
Example #22
    def backout_extension(self, name, extension, repo, uri, db):
        # Trim off the ordering numeral before printing
        if self.hard_order:
            name = name[1:]

        printout("Backing out extension %s", (name, ))

        try:
            extension.backout(repo, uri, db)
        except (ExtensionRunError, ExtensionBackoutError) as e:
            printerr("Error backing out extension %s: %s", (name, str(e)))
            return False
Example #23
    def __append_message_line(self, line=None):
        if self.msg_lines <= 0:
            printout("Warning (%d): parsing svn log, unexpected line " + \
                     "in message: %s", (self.n_line, line))
            self.msg_lines = 0
            return

        if line is not None:
            self.commit.message += line

        self.commit.message += '\n'
        self.msg_lines -= 1
Example #24
    def run_extensions(self, repo, uri, db):
        done = []
        for name, extension in [(ext, self.exts[ext]()) for ext in self.exts]:
            if name in done:
                continue
            done.append(name)

            result = True
            result = self.run_extension_deps(extension.deps, repo, uri, db, done)

            if not result:
                printout("Skipping extension %s since one or more of its dependencies failed", (name,))
                continue

            self.run_extension(name, extension, repo, uri, db)
Example #25
    def run_extensions (self, repo, uri, db):
        done = []
        for name, extension in [(ext, self.exts[ext] ()) for ext in self.exts]:
            if name in done:
                continue
            done.append (name)

            result = True
            result = self.run_extension_deps (extension.deps, repo, uri, db, done)

            if not result:
                printout ("Skipping extension %s since one or more of its dependencies failed", (name,))
                continue
                    
            self.run_extension (name, extension, repo, uri, db)
Example #26
    def start_all(self):
        if self.is_running:
            printout(self.flog,
                     'ERROR: cannot start all while DataGen is running!')
            exit(1)

        total_todos = len(self)
        num_todos_per_process = int(np.ceil(total_todos / self.num_processes))
        np.random.shuffle(self.todos)
        for i in range(self.num_processes):
            todos = self.todos[
                i * num_todos_per_process:min(total_todos, (i + 1) *
                                              num_todos_per_process)]
            p = mp.Process(target=self.job_func, args=(i, todos, self.Q))
            p.start()
            self.processes.append(p)

        self.is_running = True
Example #27
    def join_all(self):
        if not self.is_running:
            printout(self.flog,
                     'ERROR: cannot join all while DataGen is idle!')
            exit(1)

        ret = []
        for p in self.processes:
            ret += self.Q.get()

        for p in self.processes:
            p.join()

        self.todos = []
        self.processes = []
        self.Q = mp.Queue()
        self.is_running = False
        return ret
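Taken together, Examples #15, #16, #26 and #27 suggest the DataGen lifecycle: queue jobs while idle, fork the workers, then collect the results. A hedged usage sketch (constructor and job arguments are illustrative, not taken from the source):

# Hypothetical end-to-end usage of the DataGen class shown above; DataGen is assumed
# to be importable from the project's datagen module.
flog = open('datagen_log.txt', 'w')
datagen = DataGen(4, flog)                  # num_processes and log handle, as in Example #41
datagen.add_one_collect_job('./data', '40147', 'StorageFurniture', 0, 'pushing', 0)
datagen.start_all()                         # forks one worker per process over slices of self.todos
results = datagen.join_all()                # drains the queue, joins the workers, resets state
flog.close()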
Example #28
    def load(self):
        import os
        from utils import cvsanaly_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join('/etc', 'cvsanaly2')
        if os.path.isfile(config_file):
            self.__load_from_file(config_file)

        # Then look at $HOME
        config_file = os.path.join(cvsanaly_dot_dir(), 'config')
        if os.path.isfile(config_file):
            self.__load_from_file(config_file)
        else:
            # If there's an old file, migrate it
            old_config = os.path.join(os.environ.get('HOME'), '.cvsanaly')
            if os.path.isfile(old_config):
                printout("Old config file found in %s, moving to %s", (old_config, config_file))
                os.rename(old_config, config_file)
                self.__load_from_file(config_file)
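The Config.load examples rely on a cvsanaly_dot_dir helper (and the Bicho variants on bicho_dot_dir) to locate the per-user configuration directory. A plausible sketch, assuming the conventional dot-directory layout:

import os

# Hypothetical helper assumed by Examples #28/#29; the real utils module may create
# the directory on demand or use a different name.
def cvsanaly_dot_dir():
    return os.path.join(os.path.expanduser('~'), '.cvsanaly2')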
Example #29
File: Config.py Project: linzhp/CVSAnalY
    def load(self):
        import os
        from utils import cvsanaly_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join('/etc', 'cvsanaly2')
        if os.path.isfile(config_file):
            self.__load_from_file(config_file)

        # Then look at $HOME
        config_file = os.path.join(cvsanaly_dot_dir(), 'config')
        if os.path.isfile(config_file):
            self.__load_from_file(config_file)
        else:
            # If there's an old file, migrate it
            old_config = os.path.join(os.environ.get('HOME'), '.cvsanaly')
            if os.path.isfile(old_config):
                printout("Old config file found in %s, moving to %s", (old_config, config_file))
                os.rename(old_config, config_file)
                self.__load_from_file(config_file)
Example #30
    def load ():
        # FIXME: a hack to avoid circular dependencies. 
        from utils import bicho_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join ('/etc', 'bicho')
        if os.path.isfile (config_file):
            Config.load_from_file (config_file)

        # Then look at $HOME
        config_file = os.path.join (bicho_dot_dir (), 'config')
        if os.path.isfile (config_file):
            Config.load_from_file (config_file) 
        else:
            # If there's an old file, migrate it
            old_config = os.path.join (os.environ.get ('HOME'), '.bicho')
            if os.path.isfile (old_config):
                printout ("Old config file found in %s, moving to %s", \
                              (old_config, config_file))
                os.rename (old_config, config_file)
                Config.load_from_file (config_file)
Example #31
File: Config.py Project: iKuba/Bicho
    def load():
        # FIXME: a hack to avoid circular dependencies.
        from utils import bicho_dot_dir, printout

        # First look in /etc
        # FIXME /etc is not portable
        config_file = os.path.join('/etc', 'bicho')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)

        # Then look at $HOME
        config_file = os.path.join(bicho_dot_dir(), 'config')
        if os.path.isfile(config_file):
            Config.load_from_file(config_file)
        else:
            # If there's an old file, migrate it
            old_config = os.path.join(os.environ.get('HOME'), '.bicho')
            if os.path.isfile(old_config):
                printout ("Old config file found in %s, moving to %s", \
                              (old_config, config_file))
                os.rename(old_config, config_file)
                Config.load_from_file(config_file)
Example #32
    def run_extensions (self, repo, uri, db):
        done = []
        for name, extension in [(ext, self.exts[ext] ()) for ext in self.exts]:
            if name in done:
                continue
            done.append (name)

            result = True
            # Run dependencies first
            for dep in extension.deps:
                if dep in done:
                    continue
                result = self.run_extension (dep, self.exts[dep] (), repo, uri, db)
                done.append (dep)
                if not result:
                    break

            if not result:
                printout ("Skipping extension %s since one or more of its dependencies failed", (name,))
                continue
                    
            self.run_extension (name, extension, repo, uri, db)
Example #33
    def _parse_line (self, line):
        if line is None or line == '':
            return

        # Ignore
        for patt in self.patterns['ignore']:
            if patt.match (line):
                return

        # Commit
        match = self.patterns['commit'].match (line)
        if match:
            if self.commit is not None and self.branch.is_remote ():
                if self.branch.tail.svn_tag is None: # Skip commits on svn tags
                    self.handler.commit (self.branch.tail.commit)

            self.commit = Commit ()
            self.commit.revision = match.group (1)

            parents = match.group (3)
            if parents:
                parents = parents.split ()
            git_commit = self.GitCommit (self.commit, parents)

            decorate = match.group (5)
            branch = None
            if decorate:
                # Remote branch
                m = re.search (self.patterns['branch'], decorate)
                if m:
                    branch = self.GitBranch (self.GitBranch.REMOTE, m.group (1), git_commit)
                    printdbg ("Branch '%s' head at acommit %s", (branch.name, self.commit.revision))
                else:
                    # Local Branch
                    m = re.search (self.patterns['local-branch'], decorate)
                    if m:
                        branch = self.GitBranch (self.GitBranch.LOCAL, m.group (1), git_commit)
                        printdbg ("Commit %s on local branch '%s'", (self.commit.revision, branch.name))
                        # If local branch was merged we just ignore this decoration
                        if self.branch and self.branch.is_my_parent (git_commit):
                            printdbg ("Local branch '%s' was merged", (branch.name,))
                            branch = None
                    else:
                        # Stash
                        m = re.search (self.patterns['stash'], decorate)
                        if m:
                            branch = self.GitBranch (self.GitBranch.STASH, "stash", git_commit)
                            printdbg ("Commit %s on stash", (self.commit.revision,))
                # Tag
                m = re.search (self.patterns['tag'], decorate)
                if m:
                    self.commit.tags = [m.group (1)]
                    printdbg ("Commit %s tagged as '%s'", (self.commit.revision, self.commit.tags[0]))

            if branch is not None and self.branch is not None:
                # Detect empty branches. Ideally, the head of a branch
                # can't have children. When this happens it is because the
                # branch is empty, so we just ignore such a branch
                if self.branch.is_my_parent (git_commit):
                    printout ("Warning: Detected empty branch '%s', it'll be ignored", (branch.name,))
                    branch = None

            if len (self.branches) >= 2:
                # If the current commit is the start point of a new branch
                # we have to look at all the current branches since
                # we haven't inserted the new branch yet.
                # If not, look at all other branches excluding the current one
                for i, b in enumerate (self.branches):
                    if i == 0 and branch is None:
                        continue

                    if b.is_my_parent (git_commit):
                        # We assume current branch is always the last one
                        # AFAIK there's no way to make sure this is right
                        printdbg ("Start point of branch '%s' at commit %s", (self.branches[0].name, self.commit.revision))
                        self.branches.pop (0)
                        self.branch = b

            if self.branch and self.branch.tail.svn_tag is not None and self.branch.is_my_parent (git_commit):
                # There's a pending tag in previous commit
                pending_tag = self.branch.tail.svn_tag
                printdbg ("Move pending tag '%s' from previous commit %s to current %s", (pending_tag,
                                                                                          self.branch.tail.commit.revision,
                                                                                          self.commit.revision))
                if self.commit.tags and pending_tag not in self.commit.tags:
                    self.commit.tags.append (pending_tag)
                else:
                    self.commit.tags = [pending_tag]
                self.branch.tail.svn_tag = None

            if branch is not None:
                self.branch = branch

                # Insert master always at the end
                if branch.is_remote () and branch.name == 'master':
                    self.branches.append (self.branch)
                else:
                    self.branches.insert (0, self.branch)
            else:
                self.branch.set_tail (git_commit)

            return

        # Committer
        match = self.patterns['committer'].match (line)
        if match:
            self.commit.committer = Person ()
            self.commit.committer.name = match.group (1)
            self.commit.committer.email = match.group (2)
            self.handler.committer (self.commit.committer)

            return

        # Author
        match = self.patterns['author'].match (line)
        if match:
            self.commit.author = Person ()
            self.commit.author.name = match.group (1)
            self.commit.author.email = match.group (2)
            self.handler.author (self.commit.author)

            return

        # Date
        match = self.patterns['date'].match (line)
        if match:
            self.commit.date = datetime.datetime (* (time.strptime (match.group (1).strip (" "), "%a %b %d %H:%M:%S %Y")[0:6]))
            # datetime.datetime.strptime not supported by Python2.4
            #self.commit.date = datetime.datetime.strptime (match.group (1).strip (" "), "%a %b %d %H:%M:%S %Y")
            
            return

        # File
        match = self.patterns['file'].match (line)
        if match:
            action = Action ()
            action.type = match.group (1)
            action.f1 = match.group (2)

            self.commit.actions.append (action)
            self.handler.file (action.f1)
        
            return

        # File moved/copied
        match = self.patterns['file-moved'].match (line)
        if match:
            action = Action ()
            type = match.group (1)
            if type == 'R':
                action.type = 'V'
            else:
                action.type = type
            action.f1 = match.group (3)
            action.f2 = match.group (2)
            action.rev = self.commit.revision

            self.commit.actions.append (action)
            self.handler.file (action.f1)

            return

        # This is a workaround for a bug in the GNOME Git migration
        # There are commits on tags that are not correctly detected, like this one:
        # http://git.gnome.org/cgit/evolution/commit/?id=b8e52acac2b9fc5414a7795a73c74f7ee4eeb71f
        # We want to ignore commits on tags since they don't make any sense in Git
        if self.is_gnome:
            match = self.patterns['svn-tag'].match (line.strip ())
            if match:
                printout ("Warning: detected a commit on a svn tag: %s", (match.group (0),))
                tag = match.group (1)
                if self.commit.tags and tag in self.commit.tags:
                    # The commit will be ignored, so move the tag
                    # to the next (previous in history) commit
                    self.branch.tail.svn_tag = tag

        # Message
        self.commit.message += line + '\n'

        assert True, "Not match for line %s" % (line)
Example #34
        printerr("Database %s doesn't exist. It must be created before " + \
                 "running MininGit", (config.db_database,))
        return 1
    except DatabaseDriverNotSupported:
        printerr("Database driver %s is not supported by MininGit",
                 (config.db_driver,))
        return 1

    emg = _get_extensions_manager(config.extensions, config.hard_order)

    cnn = db.connect()

    if backout:
        # Run extensions
        #printout(str(get_all_extensions()))
        printout("Backing out all extensions")
        emg.backout_extensions(repo, path or uri, db)
        printout("Backing out repo from database")
        backout_handler = DBDeletionHandler(db, repo, uri, cnn)
        backout_handler.begin()

        # Final commit just in case
        cnn.commit()
        cnn.close()
        return 1

    cursor = cnn.cursor()

    try:
        printdbg("Creating tables")
        db.create_tables(cursor)
Example #35
        return 1
        
    if not db_exists or rep is None:
        # We consider the name of the repo as the last item of the root path
        name = uri.rstrip ("/").split ("/")[-1].strip ()
        cursor = cnn.cursor ()
        rep = DBRepository (None, uri, name, repo.get_type ())
        cursor.execute (statement (DBRepository.__insert__, db.place_holder), (rep.id, rep.uri, rep.name, rep.type))
        cursor.close ()
        cnn.commit ()

    cnn.close ()

    if not config.no_parse:
        # Start the parsing process
        printout ("Parsing log for %s (%s)", (path or uri, repo.get_type ()))
        
        def new_line (line, user_data):
            parser, writer = user_data
        
            parser.feed (line)
            writer and writer.add_line (line)
        
        writer = None
        if config.save_logfile is not None:
            writer = LogWriter (config.save_logfile)
        
        parser.set_content_handler (DBProxyContentHandler (db))
        reader.start (new_line, (parser, writer))
        parser.end ()
        writer and writer.close ()
Example #36
    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # save config
    torch.save(conf, os.path.join(conf.exp_dir, 'conf.pth'))

    # file log
    flog = open(os.path.join(conf.exp_dir, 'train_log.txt'), 'w')
    conf.flog = flog

    # backup command running
    utils.printout(flog, ' '.join(sys.argv) + '\n')
    utils.printout(flog, f'Random Seed: {conf.seed}')

    # backup python files used for this training
    os.system('cp data.py models/%s.py %s %s' % (conf.model_version, __file__, conf.exp_dir))
     
    # set training device
    device = torch.device(conf.device)
    utils.printout(flog, f'Using device: {conf.device}\n')
    conf.device = device

    # set the max num mask to max num part
    conf.max_num_mask = conf.max_num_parts
    conf.ins_dim = conf.max_num_similar_parts

    ### start training
Example #37
        printerr("Database %s doesn't exist. It must be created before " + \
                 "running MininGit", (config.db_database,))
        return 1
    except DatabaseDriverNotSupported:
        printerr("Database driver %s is not supported by MininGit",
                 (config.db_driver, ))
        return 1

    emg = _get_extensions_manager(config.extensions, config.hard_order)

    cnn = db.connect()

    if backout:
        # Run extensions
        #printout(str(get_all_extensions()))
        printout("Backing out all extensions")
        emg.backout_extensions(repo, path or uri, db)
        printout("Backing out repo from database")
        backout_handler = DBDeletionHandler(db, repo, uri, cnn)
        backout_handler.begin()

        # Final commit just in case
        cnn.commit()
        cnn.close()
        return 1

    cursor = cnn.cursor()

    try:
        printdbg("Creating tables")
        db.create_tables(cursor)
Example #38
    def _parse_line(self, line):
        if line is None or line == '':
            return

        # Ignore
        for patt in self.patterns['ignore']:
            if patt.match(line):
                return

        # Commit
        match = self.patterns['commit'].match(line)
        if match:
            if self.commit is not None:
                # Skip commits on svn tags
                if self.branch.tail.svn_tag is None:
                    self.handler.commit(self.branch.tail.commit)

            self.commit = Commit()
            self.commit.revision = match.group(1)

            parents = match.group(3)
            if parents:
                parents = parents.split()
            git_commit = self.GitCommit(self.commit, parents)

            # If a specific branch has been configured, there
            # won't be any decoration, so a branch needs to be
            # created
            if Config().branch is not None:
                self.branch = self.GitBranch(self.GitBranch.LOCAL,
                                             Config().branch,
                                             git_commit)

            decorate = match.group(5)
            branch = None
            if decorate:
                # Remote branch
                m = re.search(self.patterns['branch'], decorate)
                if m:
                    branch = self.GitBranch(self.GitBranch.REMOTE, m.group(1),
                                            git_commit)
                    printdbg("Branch '%s' head at acommit %s",
                             (branch.name, self.commit.revision))
                else:
                    # Local Branch
                    m = re.search(self.patterns['local-branch'], decorate)
                    if m:
                        branch = self.GitBranch(self.GitBranch.LOCAL,
                                                m.group(1), git_commit)
                        printdbg("Commit %s on local branch '%s'",
                                 (self.commit.revision, branch.name))
                        # If local branch was merged we just ignore this
                        # decoration
                        if self.branch and \
                        self.branch.is_my_parent(git_commit):
                            printdbg("Local branch '%s' was merged",
                                     (branch.name,))
                            branch = None
                    else:
                        # Stash
                        m = re.search(self.patterns['stash'], decorate)
                        if m:
                            branch = self.GitBranch(self.GitBranch.STASH,
                                                    "stash", git_commit)
                            printdbg("Commit %s on stash",
                                     (self.commit.revision,))
                # Tag
                m = re.search(self.patterns['tag'], decorate)
                if m:
                    self.commit.tags = [m.group(1)]
                    printdbg("Commit %s tagged as '%s'",
                             (self.commit.revision, self.commit.tags[0]))

            if branch is not None and self.branch is not None:
                # Detect empty branches. Ideally, the head of a branch
                # can't have children. When this happens it is because the
                # branch is empty, so we just ignore such a branch
                if self.branch.is_my_parent(git_commit):
                    printout("Warning: Detected empty branch '%s', " + \
                             "it'll be ignored", (branch.name,))
                    branch = None

            if len(self.branches) >= 2:
                # If the current commit is the start point of a new branch
                # we have to look at all the current branches since
                # we haven't inserted the new branch yet.
                # If not, look at all other branches excluding the current one
                for i, b in enumerate(self.branches):
                    if i == 0 and branch is None:
                        continue

                    if b.is_my_parent(git_commit):
                        # We assume current branch is always the last one
                        # AFAIK there's no way to make sure this is right
                        printdbg("Start point of branch '%s' at commit %s",
                                 (self.branches[0].name, self.commit.revision))
                        self.branches.pop(0)
                        self.branch = b

            if self.branch and self.branch.tail.svn_tag is not None and \
            self.branch.is_my_parent(git_commit):
                # There's a pending tag in previous commit
                pending_tag = self.branch.tail.svn_tag
                printdbg("Move pending tag '%s' from previous commit %s " + \
                         "to current %s", (pending_tag,
                                           self.branch.tail.commit.revision,
                                           self.commit.revision))
                if self.commit.tags and pending_tag not in self.commit.tags:
                    self.commit.tags.append(pending_tag)
                else:
                    self.commit.tags = [pending_tag]
                self.branch.tail.svn_tag = None

            if branch is not None:
                self.branch = branch

                # Insert master always at the end
                if branch.is_remote() and branch.name == 'master':
                    self.branches.append(self.branch)
                else:
                    self.branches.insert(0, self.branch)
            else:
                self.branch.set_tail(git_commit)

            if parents and len(parents) > 1 and not Config().analyze_merges:
                # Skip merge commits
                self.commit = None

            return
        elif self.commit is None:
            return

        # Committer
        match = self.patterns['committer'].match(line)
        if match:
            self.commit.committer = Person()
            self.commit.committer.name = match.group(1)
            self.commit.committer.email = match.group(2)
            self.handler.committer(self.commit.committer)

            return

        # Author
        match = self.patterns['author'].match(line)
        if match:
            self.commit.author = Person()
            self.commit.author.name = match.group(1)
            self.commit.author.email = match.group(2)
            self.handler.author(self.commit.author)

            return

        # Commit Date
        match = self.patterns['commit-date'].match(line)
        if match:
            self.commit.commit_date = datetime.datetime(*(time.strptime(\
                match.group(1).strip(" "), "%a %b %d %H:%M:%S %Y")[0:6]))

            return

        # Author Date
        match = self.patterns['author-date'].match(line)
        if match:
            self.commit.author_date = datetime.datetime(*(time.strptime(\
                match.group(1).strip(" "), "%a %b %d %H:%M:%S %Y")[0:6]))

            return

        # File
        match = self.patterns['file'].match(line)
        if match:
            action = Action()
            action.type = match.group(1)
            action.f1 = match.group(2)

            self.commit.actions.append(action)
            self.handler.file(action.f1)

            return

        # File moved/copied
        match = self.patterns['file-moved'].match(line)
        if match:
            action = Action()
            type = match.group(1)
            if type == 'R':
                action.type = 'V'
            else:
                action.type = type
            action.f1 = match.group(3)
            action.f2 = match.group(2)
            action.rev = self.commit.revision

            self.commit.actions.append(action)
            self.handler.file(action.f1)

            return

        # Message
        self.commit.message += line + '\n'

        assert True, "Not match for line %s" % (line)
Example #39
    def _parse_line(self, line):
        if line is None or line == '':
            return

        # Ignore
        for patt in self.patterns['ignore']:
            if patt.match(line):
                return

        # Commit
        match = self.patterns['commit'].match(line)
        if match:
            if self.commit is not None and self.branch is not None:
                if self.branch.tail.svn_tag is None:  # Skip commits on svn tags
                    self.handler.commit(self.branch.tail.commit)

            self.commit = Commit()
            self.commit.revision = match.group(1)

            parents = match.group(3)
            if parents:
                parents = parents.split()
                self.commit.parents = parents
            git_commit = self.GitCommit(self.commit, parents)

            decorate = match.group(5)
            branch = None
            if decorate:
                # Remote branch
                m = re.search(self.patterns['branch'], decorate)
                if m:
                    branch = self.GitBranch(self.GitBranch.REMOTE, m.group(1), git_commit)
                    printdbg("Branch '%s' head at acommit %s", (branch.name, self.commit.revision))
                else:
                    # Local Branch
                    m = re.search(self.patterns['local-branch'], decorate)
                    if m:
                        branch = self.GitBranch(self.GitBranch.LOCAL, m.group(1), git_commit)
                        printdbg("Commit %s on local branch '%s'", (self.commit.revision, branch.name))
                        # If local branch was merged we just ignore this decoration
                        if self.branch and self.branch.is_my_parent(git_commit):
                            printdbg("Local branch '%s' was merged", (branch.name,))
                            branch = None
                    else:
                        # Stash
                        m = re.search(self.patterns['stash'], decorate)
                        if m:
                            branch = self.GitBranch(self.GitBranch.STASH, "stash", git_commit)
                            printdbg("Commit %s on stash", (self.commit.revision,))
                # Tag
                m = re.search(self.patterns['tag'], decorate)
                if m:
                    self.commit.tags = [m.group(1)]
                    printdbg("Commit %s tagged as '%s'", (self.commit.revision, self.commit.tags[0]))

            if not branch and not self.branch:
                branch = self.GitBranch(self.GitBranch.LOCAL, "(no-branch)", git_commit)
                printdbg("Commit %s on unknown local branch '%s'", (self.commit.revision, branch.name))

            # This part of the code looks weird at first, so here is a short description of what it does:
            #
            # * self.branch is the branch to which the last inspected commit belonged
            # * branch is the branch of the currently parsed commit
            #
            # This check is only there to find branches that are fully merged into an already analyzed branch
            #
            # For more detailed information, see https://github.com/MetricsGrimoire/CVSAnalY/issues/64
            if branch is not None and self.branch is not None:
                # Detect empty branches.
                # Ideally, the head of a branch can't have children.
                # When this happens it is because the branch is empty, so we just ignore such a branch.
                if self.branch.is_my_parent(git_commit):
                    printout(
                        "Info: Branch '%s' will be ignored, because it was already merged in an active one.",
                        (branch.name,)
                    )
                    branch = None

            if len(self.branches) >= 2:
                # If the current commit is the start point of a new branch
                # we have to look at all the current branches since
                # we haven't inserted the new branch yet.
                # If not, look at all other branches excluding the current one
                for i, b in enumerate(self.branches):
                    if i == 0 and branch is None:
                        continue

                    if b.is_my_parent(git_commit):
                        # We assume current branch is always the last one
                        # AFAIK there's no way to make sure this is right
                        printdbg("Start point of branch '%s' at commit %s",
                                 (self.branches[0].name, self.commit.revision))
                        self.branches.pop(0)
                        self.branch = b

            if self.branch and self.branch.tail.svn_tag is not None and self.branch.is_my_parent(git_commit):
                # There's a pending tag in previous commit
                pending_tag = self.branch.tail.svn_tag
                printdbg("Move pending tag '%s' from previous commit %s to current %s", (pending_tag,
                                                                                         self.branch.tail.commit.revision,
                                                                                         self.commit.revision))
                if self.commit.tags and pending_tag not in self.commit.tags:
                    self.commit.tags.append(pending_tag)
                else:
                    self.commit.tags = [pending_tag]
                self.branch.tail.svn_tag = None

            if branch is not None:
                self.branch = branch

                # Insert master always at the end
                if branch.name == 'master':
                    self.branches.append(self.branch)
                else:
                    self.branches.insert(0, self.branch)
            else:
                if self.branch is not None:
                    self.branch.set_tail(git_commit)
            return

        # Committer
        match = self.patterns['committer'].match(line)
        if match:
            self.commit.committer = Person()
            self.commit.committer.name = match.group(1)
            self.commit.committer.email = match.group(2)
            self.handler.committer(self.commit.committer)
            return

        # Author
        match = self.patterns['author'].match(line)
        if match:
            self.commit.author = Person()
            self.commit.author.name = match.group(1)
            self.commit.author.email = match.group(2)
            self.handler.author(self.commit.author)
            return

        # Commit date
        match = self.patterns['date'].match(line)
        if match:
            self.commit.date = datetime.datetime(
                *(time.strptime(match.group(1).strip(" "), "%a %b %d %H:%M:%S %Y")[0:6]))
            # datetime.datetime.strptime not supported by Python2.4
            #self.commit.date = datetime.datetime.strptime (match.group (1).strip (" "), "%a %b %d %H:%M:%S %Y")

            # match.group(2) represents the timezone. E.g. -0300, +0200, +0430 (Afghanistan)
            # This string will be parsed to int and recalculated into seconds (60 * 60)
            self.commit.date_tz = (((int(match.group(2))) * 60 * 60) / 100)
            return

        # Author date
        match = self.patterns['author_date'].match(line)
        if match:
            self.commit.author_date = datetime.datetime(
                *(time.strptime(match.group(1).strip(" "), "%a %b %d %H:%M:%S %Y")[0:6]))
            # datetime.datetime.strptime not supported by Python2.4
            #self.commit.author_date = datetime.datetime.strptime (match.group (1).strip (" "), "%a %b %d %H:%M:%S %Y")

            # match.group(2) represents the timezone. E.g. -0300, +0200, +0430 (Afghanistan)
            # This string will be parsed to int and recalculated into seconds (60 * 60)
            self.commit.author_date_tz = (((int(match.group(2))) * 60 * 60) / 100)
            return

        # File
        match = self.patterns['file'].match(line)
        if match:
            action = Action()
            type = match.group(1)
            if len(type) > 1:
                # merge actions
                if 'M' in type:
                    type = 'M'
                else:
                    # ignore merge actions without 'M'
                    return

            action.type = type
            action.f1 = match.group(2)

            self.commit.actions.append(action)
            self.handler.file(action.f1)
            return

        # File moved/copied
        match = self.patterns['file-moved'].match(line)
        if match:
            action = Action()
            type = match.group(1)
            if type == 'R':
                action.type = 'V'
            else:
                action.type = type
            action.f1 = match.group(3)
            action.f2 = match.group(2)
            action.rev = self.commit.revision

            self.commit.actions.append(action)
            self.handler.file(action.f1)

            return

        # This is a workaround for a bug in the GNOME Git migration
        # There are commits on tags that are not correctly detected, like this one:
        # http://git.gnome.org/cgit/evolution/commit/?id=b8e52acac2b9fc5414a7795a73c74f7ee4eeb71f
        # We want to ignore commits on tags since they don't make any sense in Git
        if self.is_gnome:
            match = self.patterns['svn-tag'].match(line.strip())
            if match:
                printout("Warning: detected a commit on a svn tag: %s", (match.group(0),))
                tag = match.group(1)
                if self.commit.tags and tag in self.commit.tags:
                    # The commit will be ignored, so move the tag
                    # to the next (previous in history) commit
                    self.branch.tail.svn_tag = tag

        # Message
        self.commit.message += line + '\n'

        assert True, "Not match for line %s" % (line)
Example #40
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # save config
    if not conf.resume:
        torch.save(conf, os.path.join(conf.exp_dir, 'conf.pth'))

    # file log
    if conf.resume:
        flog = open(os.path.join(conf.exp_dir, 'train_log.txt'), 'a+')
    else:
        flog = open(os.path.join(conf.exp_dir, 'train_log.txt'), 'w')
    conf.flog = flog

    # backup command running
    utils.printout(flog, ' '.join(sys.argv) + '\n')
    utils.printout(flog, f'Random Seed: {conf.seed}')

    # backup python files used for this training
    if not conf.resume:
        os.system('cp datagen.py data.py models/%s.py %s %s' %
                  (conf.model_version, __file__, conf.exp_dir))

    # set training device
    device = torch.device(conf.device)
    utils.printout(flog, f'Using device: {conf.device}\n')
    conf.device = device

    # parse params
    utils.printout(flog, 'primact_type: %s' % str(conf.primact_type))
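This snippet relies on utils.printout to mirror every message to both the open log file and the console. The helper itself is not shown here; a minimal sketch, assuming that dual-write behaviour (the real utils module may differ):

import sys

def printout(flog, strout):
    # Write the message to the log file (if one is open) and echo it to stdout.
    if flog is not None:
        flog.write(strout + '\n')
        flog.flush()
    sys.stdout.write(strout + '\n')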
예제 #41
0
def train(conf, train_shape_list, train_data_list, val_data_list,
          all_train_data_list):
    # create training and validation datasets and data loaders
    data_features = ['pcs', 'pc_pxids', 'pc_movables', 'gripper_img_target', 'gripper_direction_camera', 'gripper_forward_direction_camera', \
            'result', 'cur_dir', 'shape_id', 'trial_id', 'is_original']

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf.feat_dim)
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(),
                                   lr=conf.lr,
                                   weight_decay=conf.weight_decay)

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR    TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    network.to(conf.device)
    utils.optimizer_to_device(network_opt, conf.device)

    # load dataset
    train_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num, \
            abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres, img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)

    val_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num, \
            abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres, img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)
    val_dataset.load_data(val_data_list)
    utils.printout(conf.flog, str(val_dataset))

    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True, \
            num_workers=0, drop_last=True, collate_fn=utils.collate_feats, worker_init_fn=utils.worker_init_fn)
    val_num_batch = len(val_dataloader)

    # create a data generator
    datagen = DataGen(conf.num_processes_for_datagen, conf.flog)

    # sample succ
    if conf.sample_succ:
        sample_succ_list = []
        sample_succ_dirs = []

    # start training
    start_time = time.time()

    last_train_console_log_step, last_val_console_log_step = None, None

    # if resume
    start_epoch = 0
    if conf.resume:
        # figure out the latest saved epoch to resume from
        for item in os.listdir(os.path.join(conf.exp_dir, 'ckpts')):
            if item.endswith('-train_dataset.pth'):
                start_epoch = max(start_epoch, int(item.split('-')[0]))

        # load states for network, optimizer, lr_scheduler, sample_succ_list
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-network.pth' % start_epoch))
        network.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-optimizer.pth' % start_epoch))
        network_opt.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-lr_scheduler.pth' % start_epoch))
        network_lr_scheduler.load_state_dict(data_to_restore)

        # remove and recreate the previous epoch's sample-succ directory
        old_sample_succ_dir = os.path.join(
            conf.data_dir, 'epoch-%04d_sample-succ' % (start_epoch - 1))
        utils.force_mkdir(old_sample_succ_dir)

    # train for every epoch
    for epoch in range(start_epoch, conf.epochs):
        ### collect data for the current epoch
        if epoch > start_epoch:
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Waiting epoch-{epoch} data ]'
            )
            train_data_list = datagen.join_all()
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Gathered epoch-{epoch} data ]'
            )
            cur_data_folders = []
            for item in train_data_list:
                item = '/'.join(item.split('/')[:-1])
                if item not in cur_data_folders:
                    cur_data_folders.append(item)
            for cur_data_folder in cur_data_folders:
                with open(os.path.join(cur_data_folder, 'data_tuple_list.txt'),
                          'w') as fout:
                    for item in train_data_list:
                        if cur_data_folder == '/'.join(item.split('/')[:-1]):
                            fout.write(item.split('/')[-1] + '\n')

            # load offline-generated sample-random data
            for item in all_train_data_list:
                valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                    epoch - 1)
                valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch
                if valid_id_l <= int(item.split('_')[-1]) < valid_id_r:
                    train_data_list.append(item)

        ### start generating data for the next epoch
        # sample succ
        if conf.sample_succ:
            if conf.resume and epoch == start_epoch:
                sample_succ_list = torch.load(
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % start_epoch))
            else:
                torch.save(
                    sample_succ_list,
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % epoch))
            for item in sample_succ_list:
                datagen.add_one_recollect_job(item[0], item[1], item[2],
                                              item[3], item[4], item[5],
                                              item[6])
            sample_succ_list = []
            sample_succ_dirs = []
            cur_sample_succ_dir = os.path.join(
                conf.data_dir, 'epoch-%04d_sample-succ' % epoch)
            utils.force_mkdir(cur_sample_succ_dir)

        # start all jobs
        datagen.start_all()
        utils.printout(
            conf.flog,
            f'  [ {strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Started generating epoch-{epoch+1} data ]'
        )

        ### load data for the current epoch
        if conf.resume and epoch == start_epoch:
            train_dataset = torch.load(
                os.path.join(conf.exp_dir, 'ckpts',
                             '%d-train_dataset.pth' % start_epoch))
        else:
            train_dataset.load_data(train_data_list)
        utils.printout(conf.flog, str(train_dataset))
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True, \
                num_workers=0, drop_last=True, collate_fn=utils.collate_feats, worker_init_fn=utils.worker_init_fn)
        train_num_batch = len(train_dataloader)

        ### print log
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)

        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        ### train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # save checkpoint
            if train_batch_ind == 0:
                with torch.no_grad():
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    torch.save(
                        network.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-network.pth' % epoch))
                    torch.save(
                        network_opt.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-optimizer.pth' % epoch))
                    torch.save(
                        network_lr_scheduler.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-lr_scheduler.pth' % epoch))
                    torch.save(
                        train_dataset,
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-train_dataset.pth' % epoch))
                    utils.printout(conf.flog, 'DONE')

            # set models to training mode
            network.train()

            # forward pass (including logging)
            total_loss, whole_feats, whole_pcs, whole_pxids, whole_movables = forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False, \
                    step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time, \
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            # optimize one step
            network_opt.zero_grad()
            total_loss.backward()
            network_opt.step()
            network_lr_scheduler.step()

            # sample succ
            if conf.sample_succ:
                network.eval()

                with torch.no_grad():
                    # sample a random EE orientation
                    random_up = torch.randn(conf.batch_size,
                                            3).float().to(conf.device)
                    random_forward = torch.randn(conf.batch_size,
                                                 3).float().to(conf.device)
                    random_left = torch.cross(random_up, random_forward)
                    random_forward = torch.cross(random_left, random_up)
                    random_dirs1 = F.normalize(random_up, dim=1).float()
                    random_dirs2 = F.normalize(random_forward, dim=1).float()

                    # test over the entire image
                    whole_pc_scores1 = network.inference_whole_pc(
                        whole_feats, random_dirs1, random_dirs2)  # B x N
                    whole_pc_scores2 = network.inference_whole_pc(
                        whole_feats, -random_dirs1, random_dirs2)  # B x N

                    # add to the sample_succ_list if wanted
                    ss_cur_dir = batch[data_features.index('cur_dir')]
                    ss_shape_id = batch[data_features.index('shape_id')]
                    ss_trial_id = batch[data_features.index('trial_id')]
                    ss_is_original = batch[data_features.index('is_original')]
                    for i in range(conf.batch_size):
                        valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                            epoch - 1)
                        valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch

                        if ('sample-succ' not in ss_cur_dir[i]) and (ss_is_original[i]) and (ss_cur_dir[i] not in sample_succ_dirs) \
                                and (valid_id_l <= int(ss_trial_id[i]) < valid_id_r):
                            sample_succ_dirs.append(ss_cur_dir[i])

                            # choose one from the two options
                            gt_movable = whole_movables[i].cpu().numpy()

                            whole_pc_score1 = whole_pc_scores1[i].cpu().numpy(
                            ) * gt_movable
                            whole_pc_score1[whole_pc_score1 < 0.5] = 0
                            whole_pc_score_sum1 = np.sum(
                                whole_pc_score1) + 1e-12

                            whole_pc_score2 = whole_pc_scores2[i].cpu().numpy(
                            ) * gt_movable
                            whole_pc_score2[whole_pc_score2 < 0.5] = 0
                            whole_pc_score_sum2 = np.sum(
                                whole_pc_score2) + 1e-12

                            choose1or2_ratio = whole_pc_score_sum1 / (
                                whole_pc_score_sum1 + whole_pc_score_sum2)
                            random_dir1 = random_dirs1[i].cpu().numpy()
                            random_dir2 = random_dirs2[i].cpu().numpy()
                            if np.random.random() < choose1or2_ratio:
                                whole_pc_score = whole_pc_score1
                            else:
                                whole_pc_score = whole_pc_score2
                                random_dir1 = -random_dir1

                            # sample <X, Y> on each img
                            pp = whole_pc_score + 1e-12
                            ptid = np.random.choice(len(whole_pc_score),
                                                    1,
                                                    p=pp / pp.sum())
                            X = whole_pxids[i, ptid, 0].item()
                            Y = whole_pxids[i, ptid, 1].item()

                            # add job to the queue
                            str_cur_dir1 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir1])
                            str_cur_dir2 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir2])
                            sample_succ_list.append((conf.offline_data_dir, str_cur_dir1, str_cur_dir2, \
                                    ss_cur_dir[i].split('/')[-1], cur_sample_succ_dir, X, Y))

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind + 1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                        val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                network.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True, \
                            step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch, start_time=start_time, \
                            log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])
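The loop above interleaves validation with training so that validation progress keeps pace with training progress inside each epoch (the while loop over val_batches). A stripped-down illustration of just that scheduling logic, with plain indices instead of data loaders (the function name interleave is ours, not part of the codebase):

def interleave(train_num_batch, val_num_batch):
    # Yield ('train', i) and ('val', j) events in the order the loop above
    # produces them: a validation batch is consumed whenever validation
    # progress has fallen behind training progress.
    val_ind, val_fraction_done = -1, 0.0
    for train_ind in range(train_num_batch):
        yield 'train', train_ind
        train_fraction_done = (train_ind + 1) / train_num_batch
        while val_fraction_done <= train_fraction_done and val_ind + 1 < val_num_batch:
            val_ind += 1
            val_fraction_done = (val_ind + 1) / val_num_batch
            yield 'val', val_ind

# With 8 training batches and 2 validation batches, a validation batch is
# consumed right after the 1st and the 4th training batch.
print(list(interleave(8, 2)))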
예제 #42
0
def train(conf):
    # create training and validation datasets and data loaders
    data_features = ['img', 'pts', 'ins_one_hot' , 'box_size', 'total_parts_cnt' , 'similar_parts_cnt', 'mask' ,'shape_id', 'view_id']
    
    train_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="train", \
            max_num_mask = conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts, img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(train_dataset))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True, \
            num_workers=conf.num_workers, drop_last=True, collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)
    
    val_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="val", \
            max_num_mask = conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts, img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(val_dataset))
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True, \
            num_workers=0, drop_last=True, collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf, train_dataset.get_part_count())
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    models = [network]
    model_names = ['network']

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)
    optimizers = [network_opt]
    optimizer_names = ['network_opt']

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR    CenterLoss    QuatLoss   TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    for m in models:
        m.to(conf.device)
    for o in optimizers:
        utils.optimizer_to_device(o, conf.device)

    # start training
    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_val_console_log_step = None, None
    train_num_batch = len(train_dataloader)
    val_num_batch = len(val_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)
        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False, \
                    step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time, \
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            if total_loss is not None:
                # optimize one step (step the LR scheduler after the optimizer update)
                network_opt.zero_grad()
                total_loss.backward()
                network_opt.step()
                network_lr_scheduler.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'), \
                            epoch=epoch, prepend_epoch=True, optimizers=optimizers, optimizer_names=optimizer_names)
                    utils.printout(conf.flog, 'DONE')
                    last_checkpoint_step = train_step

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind+1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                        val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True, \
                            step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch, start_time=start_time, \
                            log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])
           
    # save the final models
    utils.printout(conf.flog, 'Saving final checkpoint ...... ')
    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'), \
            epoch=epoch, prepend_epoch=False, optimizers=optimizers, optimizer_names=optimizer_names)
    utils.printout(conf.flog, 'DONE')
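Checkpointing here goes through utils.save_checkpoint, whose implementation is not included in this example. A minimal sketch of what such a helper could look like, assuming it writes one state_dict per named object (the real utility may store additional metadata):

import os
import torch

def save_checkpoint(models, model_names, dirname, epoch=None, prepend_epoch=False,
                    optimizers=None, optimizer_names=None):
    # Save each model's (and optionally each optimizer's) state_dict to
    # '<dirname>/[<epoch>-]<name>.pth', matching the '%d-network.pth' style
    # filenames that the resume logic in the previous example loads.
    os.makedirs(dirname, exist_ok=True)
    prefix = '%d-' % epoch if (prepend_epoch and epoch is not None) else ''
    for model, name in zip(models, model_names):
        torch.save(model.state_dict(), os.path.join(dirname, prefix + name + '.pth'))
    for opt, name in zip(optimizers or [], optimizer_names or []):
        torch.save(opt.state_dict(), os.path.join(dirname, prefix + name + '.pth'))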
예제 #43
0
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    input_pcs = torch.cat(batch[data_features.index('pcs')],
                          dim=0).to(conf.device)  # B x 3N x 3
    input_pxids = torch.cat(batch[data_features.index('pc_pxids')],
                            dim=0).to(conf.device)  # B x 3N x 2
    input_movables = torch.cat(batch[data_features.index('pc_movables')],
                               dim=0).to(conf.device)  # B x 3N
    batch_size = input_pcs.shape[0]

    input_pcid1 = torch.arange(batch_size).unsqueeze(1).repeat(
        1, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcid2 = furthest_point_sample(
        input_pcs, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcs = input_pcs[input_pcid1,
                          input_pcid2, :].reshape(batch_size,
                                                  conf.num_point_per_shape, -1)
    input_pxids = input_pxids[input_pcid1,
                              input_pcid2, :].reshape(batch_size,
                                                      conf.num_point_per_shape,
                                                      -1)
    input_movables = input_movables[input_pcid1, input_pcid2].reshape(
        batch_size, conf.num_point_per_shape)

    input_dirs1 = torch.cat(
        batch[data_features.index('gripper_direction_camera')],
        dim=0).to(conf.device)  # B x 3
    input_dirs2 = torch.cat(
        batch[data_features.index('gripper_forward_direction_camera')],
        dim=0).to(conf.device)  # B x 3

    # forward through the network
    pred_result_logits, pred_whole_feats = network(
        input_pcs, input_dirs1, input_dirs2)  # B x 2, B x F x N

    # prepare gt
    gt_result = torch.Tensor(batch[data_features.index('result')]).long().to(
        conf.device)  # B
    gripper_img_target = torch.cat(
        batch[data_features.index('gripper_img_target')],
        dim=0).to(conf.device)  # B x 3 x H x W

    # for each type of loss, compute losses per data
    result_loss_per_data = network.critic.get_ce_loss(pred_result_logits,
                                                      gt_result)

    # for each type of loss, compute avg loss per batch
    result_loss = result_loss_per_data.mean()

    # compute total loss
    total_loss = result_loss

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (
                not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_pc_dir = os.path.join(out_dir, 'input_pc')
            gripper_img_target_dir = os.path.join(out_dir,
                                                  'gripper_img_target')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_pc_dir)
                os.mkdir(gripper_img_target_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, input_pc_dir, fn),
                                            input_pcs[i].cpu().numpy(),
                                            highlight_id=0)
                    cur_gripper_img_target = (
                        gripper_img_target[i].permute(1, 2, 0).cpu().numpy() *
                        255).astype(np.uint8)
                    Image.fromarray(cur_gripper_img_target).save(
                        os.path.join(gripper_img_target_dir, fn))
                    with open(
                            os.path.join(info_dir, fn.replace('.png', '.txt')),
                            'w') as fout:
                        fout.write('cur_dir: %s\n' %
                                   batch[data_features.index('cur_dir')][i])
                        fout.write('pred: %s\n' % utils.print_true_false(
                            (pred_result_logits[i] > 0).cpu().numpy()))
                        fout.write(
                            'gt: %s\n' %
                            utils.print_true_false(gt_result[i].cpu().numpy()))
                        fout.write('result_loss: %f\n' %
                                   result_loss_per_data[i].item())

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_pc,gripper_img_target,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (
                    out_dir,
                    os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'),
                    sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss, pred_whole_feats.detach(), input_pcs.detach(
    ), input_pxids.detach(), input_movables.detach()
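The point sampling above pairs a repeated batch index (input_pcid1) with the indices returned by furthest_point_sample (input_pcid2) to pick num_point_per_shape points from every shape at once. A tiny self-contained illustration of that indexing pattern, with random indices standing in for the CUDA furthest-point-sampling op:

import torch

B, N, K = 2, 16, 4  # batch size, input points per shape, sampled points per shape
pcs = torch.randn(B, N, 3)

# Stand-in for furthest_point_sample: K point indices per shape (B x K).
sample_ids = torch.stack([torch.randperm(N)[:K] for _ in range(B)])

# Same pairing as in the example: one flat batch index per sampled point,
# plus the flat sample indices, then reshape back to B x K x 3.
pcid1 = torch.arange(B).unsqueeze(1).repeat(1, K).long().reshape(-1)  # B*K
pcid2 = sample_ids.long().reshape(-1)                                 # B*K
sampled = pcs[pcid1, pcid2, :].reshape(B, K, -1)
print(sampled.shape)  # torch.Size([2, 4, 3])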
예제 #44
0
    if not db_exists or rep is None:
        # We consider the name of the repo as the last item of the root path
        name = uri.rstrip("/").split("/")[-1].strip()
        cursor = cnn.cursor()
        rep = DBRepository(None, uri, name, repo.get_type())
        cursor.execute(statement(DBRepository.__insert__, db.place_holder),
                       (rep.id, rep.uri, rep.name, rep.type))
        cursor.close()
        cnn.commit()

    cnn.close()

    if not config.no_parse:
        # Start the parsing process
        printout("Parsing log for %s (%s)", (path or uri, repo.get_type()))

        def new_line(line, user_data):
            parser, writer = user_data

            parser.feed(line)
            writer and writer.add_line(line)

        writer = None
        if config.save_logfile is not None:
            writer = LogWriter(config.save_logfile)

        parser.set_content_handler(DBProxyContentHandler(db))
        reader.start(new_line, (parser, writer))
        parser.end()
        writer and writer.close()
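Every log line is fed to the parser and, when a save-logfile path is configured, also to a LogWriter. The class itself is not part of this snippet; a minimal sketch with the add_line/close interface used above (the real class may behave differently):

class LogWriter:
    """Append raw repository-log lines to a file as they are read."""

    def __init__(self, filename):
        self.fd = open(filename, 'w')

    def add_line(self, line):
        self.fd.write(line if line.endswith('\n') else line + '\n')

    def close(self):
        self.fd.close()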
예제 #45
0
    def _parse_line(self, line):
        if not line:
            if (self.commit is not None and self.state == SVNParser.COMMIT) \
            or self.state == SVNParser.FILES:
                self.state = SVNParser.MESSAGE
            elif self.state == SVNParser.MESSAGE:
                self.__append_message_line()
                
            return

        # Message
        if self.state == SVNParser.MESSAGE and self.msg_lines > 0:
            self.__append_message_line(line)

            return
        
        # Invalid commit. Some svn repos, like asterisk, have commits like this:
        # r176840 | (no author) | (no date) | 1 line
        # without any changed path, so we can just ignore them
        if self.patterns['invalid'].match(line):
            printdbg("SVN Parser: skipping invalid commit: %s", (line,))
            self.state = SVNParser.COMMIT
            self.commit = None
            return
        
        # Separator
        if self.patterns['separator'].match(line):
            if self.commit is None or self.state == SVNParser.COMMIT:
                return
            elif self.state == SVNParser.MESSAGE \
            or self.state == SVNParser.FILES:
                # We can go directly from FILES to COMMIT
                # when there is an empty log message
                if self.msg_lines > 0:
                    printout("Warning (%d): parsing svn log, missing " + \
                             "lines in commit message!", (self.n_line,))
                
                self.__convert_commit_actions(self.commit)
                self.handler.commit(self.commit)
                self.state = SVNParser.COMMIT
                self.commit = None
                self.msg_lines = 0
            else:
                printout("Warning (%d): parsing svn log, unexpected separator", 
                         (self.n_line,))
                
            return

        # Commit
        match = self.patterns['commit'].match(line)
        if match and self.state == SVNParser.COMMIT:
            commit = Commit()
            commit.revision = match.group(1)
            
            commit.committer = Person()
            commit.committer.name = match.group(2)
            
            commit.commit_date = datetime.datetime(int(match.group(3)),
                                                   int(match.group(4)),
                                                   int(match.group(5)),
                                                   int(match.group(6)),
                                                   int(match.group(7)),
                                                   int(match.group(8)))
            self.msg_lines = int(match.group(10))
            self.commit = commit
            self.handler.committer(commit.committer)
            
            return
        elif match and self.state == SVNParser.MESSAGE:
            # It seems a piece of a log message has been copied as
            # part of the commit message
            self.commit.message += line + '\n'
            return
        elif match and self.state != SVNParser.COMMIT:
            printout("Warning (%d): parsing svn log, unexpected line %s", 
                     (self.n_line, line))
            return

        # Files
        if self.state == SVNParser.COMMIT:
            if self.patterns['paths'].match(line):
                self.state = SVNParser.FILES
            else:
                printout("Warning(%d): parsing svn log, unexpected line %s", 
                         (self.n_line, line))

            return
        
        # File moved/copied/replaced
        match = self.patterns['file-moved'].match(line)
        if match:
            if self.state != SVNParser.FILES:
                printout("Warning (%d): parsing svn log, unexpected line %s", 
                         (self.n_line, line))
                return
            
            action = Action()
            action.type = match.group(1)
            action.f1 = match.group(2)
            action.f2 = match.group(3)
            action.rev = match.group(4)

            action.branch_f1 = self.__guess_branch_from_path(action.f1)
            action.branch_f2 = self.__guess_branch_from_path(action.f2)

            self.commit.actions.append(action)
            self.handler.file(action.f1)

            return

        # File
        match = self.patterns['file'].match(line)
        if match:
            if self.state != SVNParser.FILES:
                printout("Warning (%d): parsing svn log, unexpected line %s", 
                         (self.n_line, line))
                return
            
            path = match.group(2)

            if path != '/':
                # path == '/' is probably a properties change in /
                # not interesting for us, ignoring

                action = Action()
                action.type = match.group(1)
                action.f1 = path

                action.branch_f1 = self.__guess_branch_from_path(path)

                self.commit.actions.append(action)
                self.handler.file(path)

            return
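The parser's behaviour hinges on its regular expressions; for instance, the 'commit' pattern must expose the revision as group(1), the committer as group(2), the date fields as groups 3-8 and the message line count as group(10). One plausible pattern with that group layout, shown only as an illustration (the parser's actual regex may differ):

import re

commit_pattern = re.compile(
    r'^r(\d+) \| (.*) \| (\d{4})-(\d{2})-(\d{2}) '
    r'(\d{2}):(\d{2}):(\d{2}) ([+-]\d{4}) \(.*\) \| (\d+) lines?$')

m = commit_pattern.match(
    'r176840 | jdoe | 2009-05-01 12:34:56 +0000 (Fri, 01 May 2009) | 3 lines')
print(m.group(1), m.group(2), m.group(10))  # 176840 jdoe 3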
예제 #46
0
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    # merge several shapes into one batch while keeping the total part count bounded
    batch_index = 1
    if len(batch) == 0:
        return None

    cur_batch_size = len(batch[data_features.index('total_parts_cnt')])
    total_part_cnt = batch[data_features.index('total_parts_cnt')][0]

    if total_part_cnt == 1:
        print('skipping a single-part shape: it does not work with batch norm')
        return None
    input_total_part_cnt = batch[data_features.index('total_parts_cnt')][0]                             # 1
    input_img = batch[data_features.index('img')][0]                                                    # 3 x H x W
    input_img = input_img.repeat(input_total_part_cnt, 1, 1, 1)                            # part_cnt 3 x H x W
    input_pts = batch[data_features.index('pts')][0].squeeze(0)[:input_total_part_cnt]                             # part_cnt x N x 3
    input_ins_one_hot = batch[data_features.index('ins_one_hot')][0].squeeze(0)[:input_total_part_cnt]             # part_cnt x max_similar_parts
    input_similar_part_cnt = batch[data_features.index('similar_parts_cnt')][0].squeeze(0)[:input_total_part_cnt]  # part_cnt x 1    
    input_box_size = batch[data_features.index('box_size')][0].squeeze(0)[:input_total_part_cnt]

    # prepare gt: 
    gt_mask = (batch[data_features.index('mask')][0].squeeze(0)[:input_total_part_cnt].to(conf.device),)  
    input_total_part_cnt = [batch[data_features.index('total_parts_cnt')][0]]
    while total_part_cnt < 32 and batch_index < cur_batch_size:
        cur_input_cnt = batch[data_features.index('total_parts_cnt')][batch_index]
        total_part_cnt += cur_input_cnt
        if total_part_cnt > 40:
            total_part_cnt -= cur_input_cnt
            batch_index += 1
            continue
        cur_batch_img = batch[data_features.index('img')][batch_index].repeat(cur_input_cnt, 1, 1, 1)
        input_img = torch.cat((input_img, cur_batch_img), dim=0)
        cur_box_size = batch[data_features.index('box_size')][batch_index].squeeze(0)[:cur_input_cnt]
        input_box_size = torch.cat( (input_box_size, cur_box_size), dim=0)   
        input_pts = torch.cat((input_pts, batch[data_features.index('pts')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)                            # B x max_parts x N x 3
        input_ins_one_hot = torch.cat((input_ins_one_hot, batch[data_features.index('ins_one_hot')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)    # B x max_parts x max_similar_parts
        input_total_part_cnt.append(batch[data_features.index('total_parts_cnt')][batch_index])                             # 1
        input_similar_part_cnt = torch.cat((input_similar_part_cnt, batch[data_features.index('similar_parts_cnt')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)  # B x max_parts x 2    
        # prepare gt
        gt_mask = gt_mask + (batch[data_features.index('mask')][batch_index].squeeze(0)[:cur_input_cnt].to(conf.device), )
        batch_index += 1

    input_img = input_img.to(conf.device)
    input_pts = input_pts.to(conf.device)
    # input_sem_one_hot = input_sem_one_hot.to(conf.device)
    input_similar_part_cnt = input_similar_part_cnt.to(conf.device)
    input_ins_one_hot = input_ins_one_hot.to(conf.device)
    input_box_size = input_box_size.to(conf.device)
    batch_size = input_img.shape[0]
    num_point = input_pts.shape[1]

    # forward through the network
    pred_masks = network(input_img - 0.5, input_pts, input_ins_one_hot, input_total_part_cnt)
    # perform matching and calculate masks 
    mask_loss_per_data = []
    t = 0
    matched_pred_mask_all = torch.zeros(batch_size, 224, 224)
    matched_gt_mask_all = torch.zeros(batch_size, 224, 224)
    for i in range(len(input_total_part_cnt)):
        total_cnt = input_total_part_cnt[i]
        matched_gt_ids, matched_pred_ids = network.linear_assignment(gt_mask[i], pred_masks[i][:-1, :,:], input_similar_part_cnt[t:t+total_cnt])
        

        # select the matched data
        matched_pred_mask = pred_masks[i][matched_pred_ids]
        matched_gt_mask = gt_mask[i][matched_gt_ids]

        matched_gt_mask_all[t:t+total_cnt, :, :] = matched_gt_mask
        matched_pred_mask_all[t:t+total_cnt, :, :] = matched_pred_mask

        # for computing mask soft iou loss
        matched_mask_loss = network.get_mask_loss(matched_pred_mask, matched_gt_mask)

        mask_loss_per_data.append(matched_mask_loss.mean())
        t+= total_cnt
    mask_loss_per_data = torch.stack(mask_loss_per_data)
    
    # for each type of loss, compute avg loss per batch
    mask_loss = mask_loss_per_data.mean()

    # compute total loss
    total_loss = mask_loss * conf.loss_weight_mask

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{mask_loss.item():>10.5f}'''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('mask_loss', mask_loss.item(), step)
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_img_dir = os.path.join(out_dir, 'input_img')
            input_pts_dir = os.path.join(out_dir, 'input_pts')
            gt_mask_dir = os.path.join(out_dir, 'gt_mask')
            pred_mask_dir = os.path.join(out_dir, 'pred_mask')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_img_dir)
                os.mkdir(input_pts_dir)
                os.mkdir(gt_mask_dir)
                os.mkdir(pred_mask_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')

                t = 0
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    cur_input_img = (input_img[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_input_img).save(os.path.join(input_img_dir, fn))
                    cur_input_pts = input_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, input_pts_dir, fn), cur_input_pts, blender_fn='object_centered.blend')
                    cur_gt_mask = (matched_gt_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_gt_mask).save(os.path.join(gt_mask_dir, fn))
                    cur_pred_mask = (matched_pred_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_pred_mask).save(os.path.join(pred_mask_dir, fn))
                
            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_img,input_pts,gt_mask,pred_mask,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (out_dir, os.path.join(BASE_DIR, '../utils/gen_html_hierachy_local.py'), sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss
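Before the soft-IoU loss, network.linear_assignment matches each predicted mask to a ground-truth mask. A possible way to implement that one-to-one matching with scipy's Hungarian solver, sketched here under the assumption that the cost is negative soft IoU (the real method also takes the similar-parts grouping into account, which this sketch ignores):

import torch
from scipy.optimize import linear_sum_assignment

def match_masks(gt_masks, pred_masks, eps=1e-6):
    # gt_masks, pred_masks: (K, H, W) tensors with values in [0, 1].
    # Returns (matched_gt_ids, matched_pred_ids) as LongTensors.
    gt = gt_masks.flatten(1)      # K x HW
    pred = pred_masks.flatten(1)  # K x HW
    inter = gt @ pred.t()                                        # K x K
    union = gt.sum(1, keepdim=True) + pred.sum(1) - inter + eps  # K x K
    cost = -(inter / union)                  # negate: the solver minimizes cost
    gt_ids, pred_ids = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(gt_ids).long(), torch.as_tensor(pred_ids).long()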
예제 #47
0
    def _parse_line(self, line):
        if not line:
            if (self.commit is not None and self.state == SVNParser.COMMIT) \
            or self.state == SVNParser.FILES:
                self.state = SVNParser.MESSAGE
            elif self.state == SVNParser.MESSAGE:
                self.__append_message_line()

            return

        # Message
        if self.state == SVNParser.MESSAGE and self.msg_lines > 0:
            self.__append_message_line(line)

            return

        # Invalid commit. Some svn repos, like asterisk, have commits like this:
        # r176840 | (no author) | (no date) | 1 line
        # without any changed path, so we can just ignore them
        if self.patterns['invalid'].match(line):
            printdbg("SVN Parser: skipping invalid commit: %s", (line, ))
            self.state = SVNParser.COMMIT
            self.commit = None
            return

        # Separator
        if self.patterns['separator'].match(line):
            if self.commit is None or self.state == SVNParser.COMMIT:
                return
            elif self.state == SVNParser.MESSAGE \
            or self.state == SVNParser.FILES:
                # We can go directly from FILES to COMMIT
                # when there is an empty log message
                if self.msg_lines > 0:
                    printout("Warning (%d): parsing svn log, missing " + \
                             "lines in commit message!", (self.n_line,))

                self.__convert_commit_actions(self.commit)
                self.handler.commit(self.commit)
                self.state = SVNParser.COMMIT
                self.commit = None
                self.msg_lines = 0
            else:
                printout("Warning (%d): parsing svn log, unexpected separator",
                         (self.n_line, ))

            return

        # Commit
        match = self.patterns['commit'].match(line)
        if match and self.state == SVNParser.COMMIT:
            commit = Commit()
            commit.revision = match.group(1)

            commit.committer = Person()
            commit.committer.name = match.group(2)

            commit.date = datetime.datetime(int(match.group(3)),
                                            int(match.group(4)),
                                            int(match.group(5)),
                                            int(match.group(6)),
                                            int(match.group(7)),
                                            int(match.group(8)))
            self.msg_lines = int(match.group(10))
            self.commit = commit
            self.handler.committer(commit.committer)

            return
        elif match and self.state == SVNParser.MESSAGE:
            # It seems a piece of a log message has been copied as
            # part of the commit message
            self.commit.message += line + '\n'
            return
        elif match and self.state != SVNParser.COMMIT:
            printout("Warning (%d): parsing svn log, unexpected line %s",
                     (self.n_line, line))
            return

        # Files
        if self.state == SVNParser.COMMIT:
            if self.patterns['paths'].match(line):
                self.state = SVNParser.FILES
            else:
                printout("Warning(%d): parsing svn log, unexpected line %s",
                         (self.n_line, line))

            return

        # File moved/copied/replaced
        match = self.patterns['file-moved'].match(line)
        if match:
            if self.state != SVNParser.FILES:
                printout("Warning (%d): parsing svn log, unexpected line %s",
                         (self.n_line, line))
                return

            action = Action()
            action.type = match.group(1)
            action.f1 = match.group(2)
            action.f2 = match.group(3)
            action.rev = match.group(4)

            action.branch_f1 = self.__guess_branch_from_path(action.f1)
            action.branch_f2 = self.__guess_branch_from_path(action.f2)

            self.commit.actions.append(action)
            self.handler.file(action.f1)

            return

        # File
        match = self.patterns['file'].match(line)
        if match:
            if self.state != SVNParser.FILES:
                printout("Warning (%d): parsing svn log, unexpected line %s",
                         (self.n_line, line))
                return

            path = match.group(2)

            if path != '/':
                # path == '/' is probably a properties change in /
                # not interesting for us, ignoring

                action = Action()
                action.type = match.group(1)
                action.f1 = path

                action.branch_f1 = self.__guess_branch_from_path(path)

                self.commit.actions.append(action)
                self.handler.file(path)

            return
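Both SVN parser variants rely on a private __guess_branch_from_path helper to tag each action with a branch. Its implementation is not shown here; a rough sketch of what such a guess could look like for the conventional trunk/branches/tags layout (only an approximation of the real helper):

def guess_branch_from_path(path):
    # Infer a branch name from paths such as /trunk/..., /branches/<name>/...
    # or /tags/<name>/...; return None when no convention matches.
    parts = [p for p in path.split('/') if p]
    if 'trunk' in parts:
        return 'trunk'
    for marker in ('branches', 'tags'):
        if marker in parts:
            idx = parts.index(marker)
            if idx + 1 < len(parts):
                return parts[idx + 1]
    return None

print(guess_branch_from_path('/branches/1.6/channels/chan_sip.c'))  # 1.6
print(guess_branch_from_path('/trunk/main/asterisk.c'))             # trunk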