Example No. 1
    def parse_hw_cfg(self, cfg_name, cfg_path):
        self.hw_configurations[cfg_name] = {}
        self.hw_configurations[cfg_name]["path"] = cfg_path

        cfg = utils.load_status(cfg_path)
        for k, v in cfg.items():
            self.hw_configurations[cfg_name][k] = v
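
All of the examples on this page revolve around the `utils.load_status` / `utils.store_status` pair: a small status dictionary is loaded from disk, inspected or updated, and written back. The helpers themselves are not shown on this page; the following is only a minimal sketch of what they could look like, assuming the status is a plain dict serialized as JSON (the storage format is an assumption, not taken from the examples).

import json
import os


def load_status(path):
    # Return the persisted status dict, or an empty dict if nothing was saved yet.
    if not os.path.exists(path):
        return {}
    with open(path, 'r') as f:
        return json.load(f)


def store_status(path, status):
    # Write the status dict back to disk for the next processing stage.
    with open(path, 'w') as f:
        json.dump(status, f)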
Example No. 2
def segmentOneText(infile, outfile, reportfile, fast):
    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if utils.check_epoch(infilestatus, 'Segment'):
        return

    #begin processing
    if fast:
        cmdline = ['../utils/segment/spseg', \
                       '-o', outfile, infile]
    else:
        cmdline = ['../utils/segment/ngseg', \
                       '-o', outfile, infile]

    subprocess = Popen(cmdline, shell=False, stderr=PIPE, \
                           close_fds=True)

    lines = subprocess.stderr.readlines()
    if lines:
        print('found error report')
        with open(reportfile, 'wb') as f:
            f.writelines(lines)

    os.waitpid(subprocess.pid, 0)
    #end processing

    utils.sign_epoch(infilestatus, 'Segment')
    utils.store_status(infilestatuspath, infilestatus)
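
Most of the processing steps below share the same epoch guard: a step returns early when its epoch is already signed in the status file, raises an EpochError when a prerequisite epoch is missing, and signs its own epoch once the work is done. A minimal sketch of what `check_epoch` and `sign_epoch` might look like, under the assumption that epochs are simple boolean flags inside the status dict (the key naming is hypothetical):

def check_epoch(status, epoch):
    # An epoch counts as completed when its flag is present and truthy.
    return bool(status.get(epoch + 'Epoch', False))


def sign_epoch(status, epoch):
    # Mark the epoch as completed; callers persist it with store_status().
    status[epoch + 'Epoch'] = True


class EpochError(Exception):
    # Raised when a step runs before its prerequisite epoch was signed.
    pass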
Example No. 3
def mergeOneText(infile, outfile, reportfile):
    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if not utils.check_epoch(infilestatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(infilestatus, 'MergeSequence'):
        return

    infile = infile + config.getSegmentPostfix()

    #begin processing
    cmdline = ['../utils/segment/mergeseq', \
                   '-o', outfile, infile]

    subprocess = Popen(cmdline, shell=False, stderr=PIPE, \
                           close_fds=True)

    lines = subprocess.stderr.readlines()
    if lines:
        print('found error report')
        with open(reportfile, 'wb') as f:
            f.writelines(lines)

    os.waitpid(subprocess.pid, 0)
    #end processing

    utils.sign_epoch(infilestatus, 'MergeSequence')
    utils.store_status(infilestatuspath, infilestatus)
Example No. 4
def generateOneText(infile, modelfile, reportfile):
    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if not utils.check_epoch(infilestatus, 'MergeSequence'):
        raise utils.EpochError('Please mergeseq first.\n')
    if utils.check_epoch(infilestatus, 'Generate'):
        return False

    #begin processing
    cmdline = ['../utils/training/gen_k_mixture_model', \
                   '--maximum-occurs-allowed', \
                   str(config.getMaximumOccursAllowed()), \
                   '--maximum-increase-rates-allowed', \
                   str(config.getMaximumIncreaseRatesAllowed()), \
                   '--k-mixture-model-file', \
                   modelfile, infile + \
                   config.getMergedPostfix()]
    subprocess = Popen(cmdline, shell=False, stderr=PIPE, \
                           close_fds=True)

    lines = subprocess.stderr.readlines()
    if lines:
        print('found error report')
        with open(reportfile, 'ab') as f:
            f.writelines(lines)

    (pid, status) = os.waitpid(subprocess.pid, 0)
    if status != 0:
        sys.exit('gen_k_mixture_model encounters error.')
    #end processing

    utils.sign_epoch(infilestatus, 'Generate')
    utils.store_status(infilestatuspath, infilestatus)
    return True
Example No. 5
def sortModels(indexname, sortedindexname):
    sortedindexfilestatuspath = sortedindexname + config.getStatusPostfix()
    sortedindexfilestatus = utils.load_status(sortedindexfilestatuspath)
    if utils.check_epoch(sortedindexfilestatus, 'Estimate'):
        return

    #begin processing
    records = []
    indexfile = open(indexname, 'r')
    for line in indexfile.readlines():
        #remove the trailing '\n'
        line = line.rstrip(os.linesep)
        (subdir, modelname, score) = line.split('#', 2)
        score = float(score)
        records.append((subdir, modelname, score))
    indexfile.close()

    records.sort(key=itemgetter(2), reverse=True)

    sortedindexfile = open(sortedindexname, 'w')
    for record in records:
        (subdir, modelname, score) = record
        line = subdir + '#' + modelname + '#' + str(score)
        sortedindexfile.writelines([line, os.linesep])
    sortedindexfile.close()
    #end processing

    utils.sign_epoch(sortedindexfilestatus, 'Estimate')
    utils.store_status(sortedindexfilestatuspath, sortedindexfilestatus)
Example No. 6
def handleOneIndex(indexpath, subdir, indexname, fast):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Prepare'):
        raise utils.EpochError('Please prepare first.\n')
    if utils.check_epoch(indexstatus, 'Populate'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    shmdir = config.getInMemoryFileSystem()

    for i in range(1, N + 1):
        if fast:
            #copy file
            filename = config.getNgramFileName(i)
            filepath = workdir + os.sep + filename
            shmfilepath = shmdir + os.sep + filename
            utils.copyfile(filepath, shmfilepath)
            handleOnePass(indexpath, shmdir, i)
            pruneNgramTable(indexpath, shmdir, i)
            utils.copyfile(shmfilepath, filepath)
            os.unlink(shmfilepath)
        else:
            handleOnePass(indexpath, workdir, i)
            pruneNgramTable(indexpath, workdir, i)

    #sign epoch
    utils.sign_epoch(indexstatus, 'Populate')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 7
def mergeOneText(infile, outfile, reportfile):
    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if not utils.check_epoch(infilestatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(infilestatus, 'MergeSequence'):
        return

    infile = infile + config.getSegmentPostfix()

    #begin processing
    cmdline = ['../utils/segment/mergeseq', \
                   '-o', outfile, infile]

    subprocess = Popen(cmdline, shell=False, stderr=PIPE, \
                           close_fds=True)

    lines = subprocess.stderr.readlines()
    if lines:
        print('found error report')
        with open(reportfile, 'wb') as f:
            f.writelines(lines)

    os.waitpid(subprocess.pid, 0)
    #end processing

    utils.sign_epoch(infilestatus, 'MergeSequence')
    utils.store_status(infilestatuspath, infilestatus)
Example No. 8
def handleOneIndex(indexpath):
    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(indexstatus, 'MergeSequence'):
        return

    #begin processing
    indexfile = open(indexpath, 'r')
    for oneline in indexfile.readlines():
        #remove trailing '\n'
        oneline = oneline.rstrip(os.linesep)
        (title, textpath) = oneline.split('#')

        infile = config.getTextDir() + textpath
        outfile = config.getTextDir() + textpath + config.getMergedPostfix()
        reportfile = config.getTextDir() + textpath + \
            config.getMergedReportPostfix()

        print("Processing " + title + '#' + textpath)
        mergeOneText(infile, outfile, reportfile)
        print("Processed " + title + '#' + textpath)

    indexfile.close()
    #end processing

    utils.sign_epoch(indexstatus, 'MergeSequence')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 9
def segmentOneText(infile, outfile, reportfile, fast):
    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if utils.check_epoch(infilestatus, 'Segment'):
        return

    #begin processing
    if fast:
        cmdline = ['../utils/segment/spseg', \
                       '-o', outfile, infile]
    else:
        cmdline = ['../utils/segment/ngseg', \
                       '-o', outfile, infile]

    subprocess = Popen(cmdline, shell=False, stderr=PIPE, \
                           close_fds=True)

    lines = subprocess.stderr.readlines()
    if lines:
        print('found error report')
        with open(reportfile, 'wb') as f:
            f.writelines(lines)

    os.waitpid(subprocess.pid, 0)
    #end processing

    utils.sign_epoch(infilestatus, 'Segment')
    utils.store_status(infilestatuspath, infilestatus)
Example No. 10
def handleOneIndex(indexpath, subdir, indexname, fast):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Prepare'):
        raise utils.EpochError('Please prepare first.\n')
    if utils.check_epoch(indexstatus, 'Populate'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    shmdir = config.getInMemoryFileSystem()

    for i in range(1, N + 1):
        if fast:
            #copy file
            filename = config.getNgramFileName(i)
            filepath = workdir + os.sep + filename
            shmfilepath = shmdir + os.sep + filename
            utils.copyfile(filepath, shmfilepath)
            handleOnePass(indexpath, shmdir, i)
            pruneNgramTable(indexpath, shmdir, i)
            utils.copyfile(shmfilepath, filepath)
            os.unlink(shmfilepath)
        else:
            handleOnePass(indexpath, workdir, i)
            pruneNgramTable(indexpath, workdir, i)

    #sign epoch
    utils.sign_epoch(indexstatus, 'Populate')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 11
def handleOneIndex(indexpath):
    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(indexstatus, 'MergeSequence'):
        return

    #begin processing
    indexfile = open(indexpath, 'r')
    for oneline in indexfile.readlines():
        #remove trailing '\n'
        oneline = oneline.rstrip(os.linesep)
        (title, textpath) = oneline.split('#')

        infile = config.getTextDir() + textpath
        outfile = config.getTextDir() + textpath + config.getMergedPostfix()
        reportfile = config.getTextDir() + textpath + \
            config.getMergedReportPostfix()

        print("Processing " + title + '#' + textpath)
        mergeOneText(infile, outfile, reportfile)
        print("Processed " + title + '#' + textpath)

    indexfile.close()
    #end processing

    utils.sign_epoch(indexstatus, 'MergeSequence')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 12
def sortModels(indexname, sortedindexname):
    sortedindexfilestatuspath = sortedindexname + config.getStatusPostfix()
    sortedindexfilestatus = utils.load_status(sortedindexfilestatuspath)
    if utils.check_epoch(sortedindexfilestatus, 'Estimate'):
        return

    #begin processing
    records = []
    indexfile = open(indexname, 'r')
    for line in indexfile.readlines():
        #remove the trailing '\n'
        line = line.rstrip(os.linesep)
        (subdir, modelname, score) = line.split('#', 2)
        score = float(score)
        records.append((subdir, modelname, score))
    indexfile.close()

    records.sort(key=itemgetter(2), reverse=True)

    sortedindexfile = open(sortedindexname, 'w')
    for record in records:
        (subdir, modelname, score) = record
        line = subdir + '#' + modelname + '#' + str(score)
        sortedindexfile.writelines([line, os.linesep])
    sortedindexfile.close()
    #end processing

    utils.sign_epoch(sortedindexfilestatus, 'Estimate')
    utils.store_status(sortedindexfilestatuspath, sortedindexfilestatus)
Example No. 13
def handleOneDocument(infile, cur, length):
    print(infile, length)

    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if not utils.check_epoch(infilestatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(infilestatus, 'Populate'):
        return False

    sep = config.getWordSep()

    #train
    docfile = open(infile + config.getSegmentPostfix(), 'r')
    words = []

    for oneline in docfile.readlines():
        oneline = oneline.rstrip(os.linesep)

        if len(oneline) == 0:
            continue

        (token, word) = oneline.split(" ", 1)
        token = int(token)

        if 0 == token:
            words = []
        else:
            words.append(word)

        if len(words) < length:
            continue

        if len(words) > length:
            words.pop(0)

        assert len(words) == length

        #do sqlite training
        words_str = sep + sep.join(words) + sep
        #print(words_str)

        rowcount = cur.execute(UPDATE_NGRAM_DML, (words_str, )).rowcount
        #print(rowcount)
        assert rowcount <= 1

        if 0 == rowcount:
            cur.execute(INSERT_NGRAM_DML, (words_str, ))

    docfile.close()

    #sign epoch only after last pass
    if N == length:
        utils.sign_epoch(infilestatus, 'Populate')
        utils.store_status(infilestatuspath, infilestatus)

    return True
Example No. 14
def handleOneDocument(infile, cur, length):
    print(infile, length)

    infilestatuspath = infile + config.getStatusPostfix()
    infilestatus = utils.load_status(infilestatuspath)
    if not utils.check_epoch(infilestatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(infilestatus, 'Populate'):
        return False

    sep = config.getWordSep()

    #train
    docfile = open(infile + config.getSegmentPostfix(), 'r')
    words = []

    for oneline in docfile.readlines():
        oneline = oneline.rstrip(os.linesep)

        if len(oneline) == 0:
            continue

        (token, word) = oneline.split(" ", 1)
        token = int(token)

        if 0 == token:
            words = []
        else:
            words.append(word)

        if len(words) < length:
            continue

        if len(words) > length:
            words.pop(0)

        assert len(words) == length

        #do sqlite training
        words_str = sep + sep.join(words) + sep
        #print(words_str)

        rowcount = cur.execute(UPDATE_NGRAM_DML, (words_str,)).rowcount
        #print(rowcount)
        assert rowcount <= 1

        if 0 == rowcount:
            cur.execute(INSERT_NGRAM_DML, (words_str,))

    docfile.close()

    #sign epoch only after last pass
    if N == length:
        utils.sign_epoch(infilestatus, 'Populate')
        utils.store_status(infilestatuspath, infilestatus)

    return True
Example No. 15
def gatherModels(path, indexname):
    indexfilestatuspath = indexname + config.getStatusPostfix()
    indexfilestatus = utils.load_status(indexfilestatuspath)
    if utils.check_epoch(indexfilestatus, 'Estimate'):
        return

    #begin processing
    indexfile = open(indexname, "w")
    for root, dirs, files in os.walk(path, topdown=True, onerror=handleError):
        for onefile in files:
            filepath = os.path.join(root, onefile)
            if onefile.endswith(config.getModelPostfix()):
                #append one record to index file
                subdir = os.path.relpath(root, path)
                statusfilepath = filepath + config.getStatusPostfix()
                status = utils.load_status(statusfilepath)
                if not (utils.check_epoch(status, 'Estimate') and \
                        'EstimateScore' in status):
                    raise utils.EpochError('Unknown Error:\n' + \
                                               'Try re-run estimate.\n')
                avg_lambda = status['EstimateScore']
                line = subdir + '#' + onefile + '#' + str(avg_lambda)
                indexfile.writelines([line, os.linesep])
                #record written
            elif onefile.endswith(config.getStatusPostfix()):
                pass
            elif onefile.endswith(config.getIndexPostfix()):
                pass
            elif onefile.endswith(config.getReportPostfix()):
                pass
            else:
                print('Unexpected file:' + filepath)
    indexfile.close()
    #end processing

    utils.sign_epoch(indexfilestatus, 'Estimate')
    utils.store_status(indexfilestatuspath, indexfilestatus)
Example No. 16
def gatherModels(path, indexname):
    indexfilestatuspath = indexname + config.getStatusPostfix()
    indexfilestatus = utils.load_status(indexfilestatuspath)
    if utils.check_epoch(indexfilestatus, 'Estimate'):
        return

    #begin processing
    indexfile = open(indexname, "w")
    for root, dirs, files in os.walk(path, topdown=True, onerror=handleError):
        for onefile in files:
            filepath = os.path.join(root, onefile)
            if onefile.endswith(config.getModelPostfix()):
                #append one record to index file
                subdir = os.path.relpath(root, path)
                statusfilepath = filepath + config.getStatusPostfix()
                status = utils.load_status(statusfilepath)
                if not (utils.check_epoch(status, 'Estimate') and \
                        'EstimateScore' in status):
                    raise utils.EpochError('Unknown Error:\n' + \
                                               'Try re-run estimate.\n')
                avg_lambda = status['EstimateScore']
                line = subdir + '#' + onefile + '#' + str(avg_lambda)
                indexfile.writelines([line, os.linesep])
                #record written
            elif onefile.endswith(config.getStatusPostfix()):
                pass
            elif onefile.endswith(config.getIndexPostfix()):
                pass
            elif onefile.endswith(config.getReportPostfix()):
                pass
            else:
                print('Unexpected file:' + filepath)
    indexfile.close()
    #end processing

    utils.sign_epoch(indexfilestatus, 'Estimate')
    utils.store_status(indexfilestatuspath, indexfilestatus)
Example No. 17
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'NewWord'):
        raise utils.EpochError('Please new word first.\n')
    if utils.check_epoch(indexstatus, 'MarkPinyin'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    markPinyins(workdir)

    #sign epoch
    utils.sign_epoch(indexstatus, 'MarkPinyin')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 18
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'NewWord'):
        raise utils.EpochError('Please new word first.\n')
    if utils.check_epoch(indexstatus, 'MarkPinyin'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    markPinyins(workdir)

    #sign epoch
    utils.sign_epoch(indexstatus, 'MarkPinyin')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 19
def handleOneModel(modelfile, reportfile):
    modelfilestatuspath = modelfile + config.getStatusPostfix()
    modelfilestatus = utils.load_status(modelfilestatuspath)
    if not utils.check_epoch(modelfilestatus, 'Generate'):
        raise utils.EpochError('Please generate first.\n')
    if utils.check_epoch(modelfilestatus, 'Estimate'):
        return

    reporthandle = open(reportfile, 'wb')

    result_line_prefix = "average lambda:"
    avg_lambda = 0.

    #begin processing
    cmdline = ['../utils/training/estimate_k_mixture_model', \
                   '--deleted-bigram-file', \
                   config.getEstimatesModel(), \
                   '--bigram-file', \
                   modelfile]

    subprocess = Popen(cmdline, shell=False, stdout=PIPE, \
                           close_fds=True)

    for line in subprocess.stdout.readlines():
        reporthandle.writelines([line])
        #remove trailing '\n'
        line = line.decode('utf-8')
        line = line.rstrip(os.linesep)
        if line.startswith(result_line_prefix):
            avg_lambda = float(line[len(result_line_prefix):])

    reporthandle.close()

    (pid, status) = os.waitpid(subprocess.pid, 0)
    if status != 0:
        sys.exit('estimate k mixture model returns error.')
    #end processing

    print('average lambda:', avg_lambda)
    modelfilestatus['EstimateScore'] = avg_lambda
    utils.sign_epoch(modelfilestatus, 'Estimate')
    utils.store_status(modelfilestatuspath, modelfilestatus)
Example No. 20
def handleOneModel(modelfile, reportfile):
    modelfilestatuspath = modelfile + config.getStatusPostfix()
    modelfilestatus = utils.load_status(modelfilestatuspath)
    if not utils.check_epoch(modelfilestatus, 'Generate'):
        raise utils.EpochError('Please generate first.\n')
    if utils.check_epoch(modelfilestatus, 'Estimate'):
        return

    reporthandle = open(reportfile, 'wb')

    result_line_prefix = "average lambda:"
    avg_lambda = 0.

    #begin processing
    cmdline = ['../utils/training/estimate_k_mixture_model', \
                   '--deleted-bigram-file', \
                   config.getEstimatesModel(), \
                   '--bigram-file', \
                   modelfile]

    subprocess = Popen(cmdline, shell=False, stdout=PIPE, \
                           close_fds=True)

    for line in subprocess.stdout.readlines():
        reporthandle.writelines([line])
        #remove trailing '\n'
        line = line.decode('utf-8')
        line = line.rstrip(os.linesep)
        if line.startswith(result_line_prefix):
            avg_lambda = float(line[len(result_line_prefix):])

    reporthandle.close()

    (pid, status) = os.waitpid(subprocess.pid, 0)
    if status != 0:
        sys.exit('estimate k mixture model returns error.')
    #end processing

    print('average lambda:', avg_lambda)
    modelfilestatus['EstimateScore'] = avg_lambda
    utils.sign_epoch(modelfilestatus, 'Estimate')
    utils.store_status(modelfilestatuspath, modelfilestatus)
Example No. 21
def mergeOneModel(mergedmodel, onemodel, score):

    onemodelstatuspath = onemodel + config.getStatusPostfix()
    onemodelstatus = utils.load_status(onemodelstatuspath)
    if not utils.check_epoch(onemodelstatus, 'Estimate'):
        raise utils.EpochError('Please estimate first.\n')
    if score != onemodelstatus['EstimateScore']:
        raise AssertionError('estimate scores mis-match.\n')

    #begin processing
    cmdline = ['../utils/training/merge_k_mixture_model', \
                   '--result-file', \
                   mergedmodel, \
                   onemodel]

    subprocess = Popen(cmdline, shell=False, close_fds=True)

    (pid, status) = os.waitpid(subprocess.pid, 0)
    if status != 0:
        sys.exit('Corrupted model found when merging:' + onemodel)
Example No. 22
def mergeOneModel(mergedmodel, onemodel, score):

    onemodelstatuspath = onemodel + config.getStatusPostfix()
    onemodelstatus = utils.load_status(onemodelstatuspath)
    if not utils.check_epoch(onemodelstatus, 'Estimate'):
        raise utils.EpochError('Please estimate first.\n')
    if score != onemodelstatus['EstimateScore']:
        raise AssertionError('estimate scores mis-match.\n')

    #begin processing
    cmdline = ['../utils/training/merge_k_mixture_model', \
                   '--result-file', \
                   mergedmodel, \
                   onemodel]

    subprocess = Popen(cmdline, shell=False, close_fds=True)

    (pid, status) = os.waitpid(subprocess.pid, 0)
    if status != 0:
        sys.exit('Corrupted model found when merging:' + onemodel)
Example No. 23
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(indexstatus, 'Prepare'):
        return

    #create directory
    onedir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    os.path.exists(onedir) or os.makedirs(onedir)

    #create sqlite databases
    createSqliteDatabases(onedir)

    #sign epoch
    utils.sign_epoch(indexstatus, 'Prepare')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 24
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Segment'):
        raise utils.EpochError('Please segment first.\n')
    if utils.check_epoch(indexstatus, 'Prepare'):
        return

    #create directory
    onedir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    os.path.exists(onedir) or os.makedirs(onedir)

    #create sqlite databases
    createSqliteDatabases(onedir)

    #sign epoch
    utils.sign_epoch(indexstatus, 'Prepare')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 25
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'PartialWord'):
        raise utils.EpochError('Please partial word first.\n')
    if utils.check_epoch(indexstatus, 'NewWord'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    createBigramSqlite(workdir)
    populateBigramSqlite(workdir)

    filename = config.getBigramFileName()
    filepath = workdir + os.sep + filename

    conn = sqlite3.connect(filepath)

    prethres = computeThreshold(conn, "prefix")
    indexstatus['NewWordPrefixThreshold'] = prethres
    postthres = computeThreshold(conn, "postfix")
    indexstatus['NewWordPostfixThreshold'] = postthres

    utils.store_status(indexstatuspath, indexstatus)

    filterPartialWord(workdir, conn, prethres, postthres)

    conn.commit()
    if conn:
        conn.close()

    #sign epoch
    utils.sign_epoch(indexstatus, 'NewWord')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 26
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'PartialWord'):
        raise utils.EpochError('Please partial word first.\n')
    if utils.check_epoch(indexstatus, 'NewWord'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    createBigramSqlite(workdir)
    populateBigramSqlite(workdir)

    filename = config.getBigramFileName()
    filepath = workdir + os.sep + filename

    conn = sqlite3.connect(filepath)

    prethres = computeThreshold(conn, "prefix")
    indexstatus['NewWordPrefixThreshold'] = prethres
    postthres = computeThreshold(conn, "postfix")
    indexstatus['NewWordPostfixThreshold'] = postthres

    utils.store_status(indexstatuspath, indexstatus)

    filterPartialWord(workdir, conn, prethres, postthres)

    conn.commit()
    if conn:
        conn.close()

    #sign epoch
    utils.sign_epoch(indexstatus, 'NewWord')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 27
def handleOneIndex(indexpath, subdir, indexname):
    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'Populate'):
        raise utils.EpochError('Please populate first.\n')
    if utils.check_epoch(indexstatus, 'PartialWord'):
        return

    workdir = config.getWordRecognizerDir() + os.sep + \
        subdir + os.sep + indexname
    print(workdir)

    threshold = getThreshold(workdir)
    indexstatus['PartialWordThreshold'] = threshold
    utils.store_status(indexstatuspath, indexstatus)

    recognizePartialWord(workdir, threshold)

    #sign epoch
    utils.sign_epoch(indexstatus, 'PartialWord')
    utils.store_status(indexstatuspath, indexstatus)
Example No. 28
    parser.add_argument('--finaldir', action='store', \
                            help='final directory', \
                            default=config.getFinalModelDir())
    parser.add_argument('tryname', action='store', \
                            help='the storage directory')

    args = parser.parse_args()
    print(args)
    tryname = 'try' + args.tryname

    trydir = os.path.join(args.finaldir, tryname)
    if not os.access(trydir, os.F_OK):
        sys.exit(tryname + " doesn't exist.")

    cwdstatuspath = os.path.join(trydir, config.getFinalStatusFileName())
    cwdstatus = utils.load_status(cwdstatuspath)
    if not utils.check_epoch(cwdstatus, 'Prune'):
        raise utils.EpochError('Please tryprune first.')

    if utils.check_epoch(cwdstatus, 'Evaluate'):
        sys.exit('already evaluated.')

    print('checking')
    checkData()

    modelfile = os.path.join(trydir, config.getFinalModelFileName())
    destfile = os.path.join(libpinyin_dir, 'data', \
                                config.getFinalModelFileName())

    utils.copyfile(modelfile, destfile)
Example No. 29
envs = []
for i in range(args.procs):
    env = gym.make(args.env)
    env.seed(args.seed + 10000 * i)
    envs.append(env)

# Define obss preprocessor

preprocess_obss = utils.ObssPreprocessor(save_dir, envs[0].observation_space)

# Define actor-critic model

if utils.model_exists(save_dir):
    acmodel = utils.load_model(save_dir)
    status = utils.load_status(save_dir)
    logger.info("Model successfully loaded\n")
else:
    acmodel = ACModel(preprocess_obss.obs_space, envs[0].action_space,
                      not args.no_instr, not args.no_mem)
    status = {"num_frames": 0, "update": 0}
    logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

if torch.cuda.is_available():
    acmodel.cuda()
logger.info("CUDA available: {}\n".format(torch.cuda.is_available()))

# Define actor-critic algo

if args.algo == "a2c":
Example No. 30
def run(full_args: Namespace) -> None:

    args = full_args.main
    agent_args = full_args.agent
    model_args = full_args.model

    if args.seed == 0:
        args.seed = full_args.run_id + 1
    max_eprews = args.max_eprews

    post_process_args(agent_args)
    post_process_args(model_args)

    model_dir = full_args.cfg_dir
    print(model_dir)

    # ==============================================================================================
    # Set seed for all randomness sources
    utils.seed(args.seed)

    # ==============================================================================================
    # Generate environment

    env = gym.make(args.env)
    env.max_steps = full_args.env_cfg.max_episode_steps
    env.seed(args.seed + 10000 * 0)

    env = gym_wrappers.RecordingBehaviour(env)

    # Define obss preprocessor
    max_image_value = full_args.env_cfg.max_image_value
    normalize_img = full_args.env_cfg.normalize
    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        args.env,
        env.observation_space,
        model_dir,
        max_image_value=max_image_value,
        normalize=normalize_img)
    # ==============================================================================================
    # Load training status
    try:
        status = utils.load_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}

    saver = utils.SaveData(model_dir,
                           save_best=args.save_best,
                           save_all=args.save_all)
    model, agent_data, other_data = None, dict(), None
    try:
        # Continue from last point
        model, agent_data, other_data = saver.load_training_data(best=False)
        print("Training data exists & loaded successfully\n")
    except OSError:
        print("Could not load training data\n")

    if torch.cuda.is_available():
        model.cuda()
        device = torch.device("cuda")
    else:
        model.cpu()
        device = torch.device("cpu")

    # ==============================================================================================
    # Test model

    done = False
    model.eval()

    initial_image = None

    if agent_args.name == 'PPORND':
        model = model.policy

    import argparse
    n_cfg = argparse.Namespace()
    viz = visualize_episode.VisualizeEpisode(n_cfg)

    obs = env.reset()
    memory = torch.zeros(1, model.memory_size, device=device)
    while True:
        if done:
            agent_behaviour = env.get_behaviour()
            nr_steps = agent_behaviour['step_count']
            map_shape = np.array((agent_behaviour['full_states'].shape[1],
                                  agent_behaviour['full_states'].shape[2]))
            new_img = viz.draw_single_episode(
                initial_image,
                agent_behaviour['positions'][:nr_steps].astype(np.uint8),
                map_shape,
                agent_behaviour['actions'][:nr_steps].astype(np.uint8))

            cv2.imshow("Map", new_img)
            cv2.waitKey(0)

            obs = env.reset()
            memory = torch.zeros(1, model.memory_size, device=device)

        time.sleep(0.1)
        renderer = env.render()
        if initial_image is None:
            initial_image = renderer.getArray()

        preprocessed_obs = preprocess_obss([obs], device=device)
        if model.recurrent:
            dist, _, memory = model(preprocessed_obs, memory)
        else:
            dist, value = model(preprocessed_obs)

        #action = dist.probs.argmax()
        action = dist.sample()
        obs, reward, done, _ = env.step(action.cpu().numpy())
        if renderer.window is None:
            break
Example No. 31
def run(full_args: Namespace) -> None:
    # import torch.multiprocessing as mp
    # mp.set_start_method('spawn')

    args = full_args.main
    agent_args = full_args.agent
    model_args = full_args.model
    env_args = full_args.env_cfg
    extra_logs = getattr(full_args, "extra_logs", None)

    if args.seed == 0:
        args.seed = full_args.run_id + 1
    max_eprews = args.max_eprews

    post_process_args(agent_args)
    post_process_args(model_args)

    model_dir = getattr(args, "model_dir", full_args.out_dir)
    print(model_dir)

    # ==============================================================================================
    # @ torc_rl repo original

    # Define logger, CSV writer and Tensorboard writer

    logger = utils.get_logger(model_dir)
    csv_file, csv_writer = utils.get_csv_writer(model_dir)
    tb_writer = None
    if args.tb:
        from tensorboardX import SummaryWriter
        tb_writer = SummaryWriter(model_dir)

    # Log command and all script arguments

    logger.info("{}\n".format(" ".join(sys.argv)))
    logger.info("{}\n".format(args))

    # ==============================================================================================
    # Set seed for all randomness sources
    utils.seed(args.seed)

    # ==============================================================================================
    # Generate environments

    envs = []

    # Get environment wrapper
    wrapper_method = getattr(full_args.env_cfg, "wrapper", None)
    if wrapper_method is None:

        def idem(x):
            return x

        env_wrapper = idem
    else:
        env_wrappers = [getattr(environment, w_p) for w_p in wrapper_method]

        def env_wrapp(w_env):
            for wrapper in env_wrappers[::-1]:
                w_env = wrapper(w_env)
            return w_env

        env_wrapper = env_wrapp

    actual_procs = getattr(args, "actual_procs", None)
    master_make_envs = getattr(full_args.env_cfg, "master_make_envs", False)

    if actual_procs:
        # Split envs in chunks
        no_envs = args.procs
        envs, chunk_size = get_envs(full_args,
                                    env_wrapper,
                                    no_envs,
                                    master_make=master_make_envs)
        first_env = envs[0][0]
        print(
            f"NO of envs / proc: {chunk_size}; No of processes {len(envs[1:])} + Master"
        )
    else:
        for i in range(args.procs):
            env = env_wrapper(gym.make(args.env))
            env.max_steps = full_args.env_cfg.max_episode_steps
            env.no_stacked_frames = full_args.env_cfg.no_stacked_frames

            env.seed(args.seed + 10000 * i)
            envs.append(env)
        first_env = envs[0]

    # Generate evaluation envs
    eval_envs = []
    if full_args.env_cfg.no_eval_envs > 0:
        no_envs = full_args.env_cfg.no_eval_envs
        eval_envs, chunk_size = get_envs(full_args,
                                         env_wrapper,
                                         no_envs,
                                         master_make=master_make_envs)

    # Define obss preprocessor
    max_image_value = full_args.env_cfg.max_image_value
    normalize_img = full_args.env_cfg.normalize
    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        args.env,
        first_env.observation_space,
        model_dir,
        max_image_value=max_image_value,
        normalize=normalize_img)

    # ==============================================================================================
    # Load training status
    try:
        status = utils.load_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}

    saver = utils.SaveData(model_dir,
                           save_best=args.save_best,
                           save_all=args.save_all)
    model, agent_data, other_data = None, dict(), None
    try:
        # Continue from last point
        model, agent_data, other_data = saver.load_training_data(best=False)
        logger.info("Training data exists & loaded successfully\n")
    except OSError:
        logger.info("Could not load training data\n")

    # ==============================================================================================
    # Load Model

    if model is None:
        model = get_model(model_args,
                          obs_space,
                          first_env.action_space,
                          use_memory=model_args.use_memory,
                          no_stacked_frames=env_args.no_stacked_frames)
        logger.info(f"Model [{model_args.name}] successfully created\n")

        # Print Model info
        logger.info("{}\n".format(model))

    if torch.cuda.is_available():
        model.cuda()
    logger.info("CUDA available: {}\n".format(torch.cuda.is_available()))

    # ==============================================================================================
    # Load Agent

    algo = get_agent(full_args.agent,
                     envs,
                     model,
                     agent_data,
                     preprocess_obss=preprocess_obss,
                     reshape_reward=None,
                     eval_envs=eval_envs)

    has_evaluator = hasattr(algo,
                            "evaluate") and full_args.env_cfg.no_eval_envs > 0

    # ==============================================================================================
    # Train model

    crt_eprew = 0
    if "eprew" in other_data:
        crt_eprew = other_data["eprew"]
    num_frames = status["num_frames"]
    total_start_time = time.time()
    update = status["update"]
    update_start_time = time.time()

    while num_frames < args.frames:
        # Update model parameters

        logs = algo.update_parameters()

        num_frames += logs["num_frames"]
        update += 1

        if has_evaluator:
            if update % args.eval_interval == 0:
                algo.evaluate()

        prev_start_time = update_start_time
        update_start_time = time.time()

        # Print logs
        if update % args.log_interval == 0:
            fps = logs["num_frames"] / (update_start_time - prev_start_time)
            duration = int(time.time() - total_start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(
                logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(
                logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += [
                "num_frames_" + key for key in num_frames_per_episode.keys()
            ]
            data += num_frames_per_episode.values()
            header += ["entropy", "value", "policy_loss", "value_loss"]
            data += [
                logs["entropy"], logs["value"], logs["policy_loss"],
                logs["value_loss"]
            ]
            header += ["grad_norm"]
            data += [logs["grad_norm"]]

            # add log fields that are not in the standard log format (for example value_int)
            extra_fields = extra_log_fields(header, list(logs.keys()))
            header.extend(extra_fields)
            data += [logs[field] for field in extra_fields]

            # print to stdout the standard log fields + fields required in config
            keys_format, printable_data = print_keys(header, data, extra_logs)
            logger.info(keys_format.format(*printable_data))

            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            if status["num_frames"] == 0:
                csv_writer.writerow(header)
            csv_writer.writerow(data)
            csv_file.flush()

            if args.tb:
                for field, value in zip(header, data):
                    tb_writer.add_scalar(field, value, num_frames)

            status = {"num_frames": num_frames, "update": update}

            crt_eprew = list(rreturn_per_episode.values())[0]

        # -- Save vocabulary and model

        if args.save_interval > 0 and update % args.save_interval == 0:
            # preprocess_obss.vocab.save()

            saver.save_training_data(model, algo.get_save_data(), crt_eprew)

            logger.info("Model successfully saved")

            utils.save_status(status, model_dir)

        if crt_eprew > max_eprews != 0:
            print("Reached max return 0.93")
            exit()
Example No. 32
# Generate environments

envs = []
for i in range(args.procs):
    env = gym.make(args.env)
    env.seed(args.seed + 10000*i)
    envs.append(env)

# Define obss preprocessor

obs_space, preprocess_obss = utils.get_obss_preprocessor(args.env, envs[0].observation_space, model_dir)

# Load training status

try:
    status = utils.load_status(model_dir)
except OSError:
    status = {"num_frames": 0, "update": 0}

# Define actor-critic model

try:
    acmodel = utils.load_model(model_dir)
    logger.info("Model successfully loaded\n")
except OSError:
    acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
    logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

if torch.cuda.is_available():
    acmodel.cuda()
Example No. 33
    parser.add_argument('--finaldir', action='store', \
                            help='final directory', \
                            default=config.getFinalModelDir())
    parser.add_argument('tryname', action='store', \
                            help='the storage directory')

    args = parser.parse_args()
    print(args)
    tryname = 'try' + args.tryname

    trydir = os.path.join(args.finaldir, tryname)
    if not os.access(trydir, os.F_OK):
        sys.exit(tryname + " doesn't exist.")

    cwdstatuspath = os.path.join(trydir, config.getFinalStatusFileName())
    cwdstatus = utils.load_status(cwdstatuspath)
    if not utils.check_epoch(cwdstatus, 'Prune'):
        raise utils.EpochError('Please tryprune first.')

    if utils.check_epoch(cwdstatus, 'Evaluate'):
        sys.exit('already evaluated.')

    print('checking')
    checkData()

    modelfile = os.path.join(trydir, config.getFinalModelFileName())
    destfile = os.path.join(libpinyin_dir, 'data', \
                                config.getFinalModelFileName())

    utils.copyfile(modelfile, destfile)
Example No. 34
#from slicer_cura import CuraPrintFile
#from slicer_kisslicer import KissPrintFile
from slicer_simplify3d import Simplify3dGCodeFile
from slicer_prusa_slic3r import PrusaSlic3rCodeFile

from logger import Logger
from switch_tower import PEEK, PTFE, E3DV6, HW_CONFIGS
from switch_tower import AUTO, LEFT, RIGHT, TOP, BOTTOM, TOWER_POSITIONS
from switch_tower import LINES, LINE_COUNT_DEFAULT

import utils

prog_dir = os.path.dirname(os.path.realpath(__file__))

status_file = os.path.join(prog_dir, '.status')
status = utils.load_status(status_file)

version = "0.13"


def detect_file_type(gcode_file, log):
    with open(gcode_file, 'r') as gf:
        line1 = gf.readline()
        if line1.startswith('; G-Code generated by Simplify3D(R)'):
            log.info("Detected Simplify3D format")
            return Simplify3dGCodeFile
        #elif line1.startswith('; KISSlicer'):
        #    log.info("Detected KISSlicer format")
        #    return KissPrintFile
        #elif line1.startswith('; CURA'):
        #    log.info("Detected Cura format")
Example No. 35
envs = []
for i in range(args.procs):
    env = gym.make(args.env)
    env.seed(args.seed + 10000 * i)
    envs.append(env)

# Define obss preprocessor

obss_preprocessor = utils.ObssPreprocessor(run_dir, envs[0].observation_space)

# Define actor-critic model

if utils.model_exists(run_dir):
    acmodel = utils.load_model(run_dir)
    status = utils.load_status(run_dir)
    logger.info("Model successfully loaded\n")
else:
    acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
                      not args.no_instr, not args.no_mem)
    status = {"num_frames": 0, "i": 0}
    logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

if torch.cuda.is_available():
    acmodel.cuda()
logger.info("CUDA available: {}\n".format(torch.cuda.is_available()))

# Define actor-critic algo

if args.algo == "a2c":
Example No. 36
def train(args):
    # make dataset for train and validation
    assert args.lr_train_path is not None
    assert args.hr_train_path is not None
    assert args.lr_val_path is not None
    assert args.hr_val_path is not None
    # patch the train data for training
    train_dataset = SRDataset(lr_path=args.lr_train_path,
                              hr_path=args.hr_train_path,
                              patch_size=args.patch_size,
                              scale=args.scale,
                              aug=args.augment,
                              normalization=args.normalization,
                              need_patch=True,
                              suffix=args.suffix)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.n_threads)

    val_dataset = SRDataset(lr_path=args.lr_val_path,
                            hr_path=args.hr_val_path,
                            patch_size=args.patch_size,
                            scale=args.scale,
                            normalization=args.normalization,
                            need_patch=True,
                            suffix=args.suffix)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.n_threads)

    # check log
    check_logs(args)
    writer = SummaryWriter(log_dir=args.tblog)
    # check for gpu
    device = check_hardware(args)
    # check the model
    module = import_module('model.' + args.model.lower())
    model = module.wrapper(args)

    # continue train or not
    start_epoch = 0
    best_val_psnr = -1.0
    best_val_loss = 1e8
    if args.continue_train:
        status_ = load_status(args.status_logger)
        args.lr = status_['lr']
        start_epoch = status_['epoch']
        best_val_loss = status_['best_val_loss']

        pretrained_dict = torch.load(status_['last_weight_pth'])
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        logger.info(
            f"Load model from {status_['last_weight_pth']} for continuing train."
        )

    if not args.cpu:
        model = model.to(device)
    # check the optimizer
    optimizer = check_optimizer_(args, model)
    # check the lr schedule
    lr_schedule = StepLR(optimizer, args.decay_step, args.gamma)
    # check the loss
    criterion = check_loss_(args)

    # for iteration to train the model and validation for every epoch
    for epoch in range(start_epoch, args.epochs):
        torch.cuda.empty_cache()

        train_loss = 0.0
        model.train()
        for batch, data in enumerate(train_dataloader):
            x = data['lr']
            y = data['hr']
            x = x.to(device)
            y = y.to(device)

            # perform forward calculation
            y_hat = model(x)
            loss_ = criterion(y_hat, y)
            train_loss += loss_.item()
            logger.info("Epoch-%d-Batch-%d, train loss: %.4f" %
                        (epoch, batch, loss_.item()))
            writer.add_scalar(f'Train/Batchloss',
                              loss_.item(),
                              global_step=epoch * (len(train_dataloader)) +
                              batch)

            # perform backward calculation
            optimizer.zero_grad()
            loss_.backward()
            # perform gradient clipping
            if args.gclip > 0:
                nn.utils.clip_grad_value_(model.parameters(), args.gclip)
            optimizer.step()
        train_loss = train_loss / (batch + 1)
        logger.info("Epoch-%d, train loss: %.4f" % (epoch, train_loss))
        writer.add_scalar(f'Train/Epochloss', train_loss, global_step=epoch)

        # validation
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_psnr = 0.0
            for batch, data in enumerate(val_dataloader):
                x = data['lr']
                y = data['hr']
                x = x.to(device)
                y = y.to(device)

                y_hat = model(x)
                loss_ = criterion(y_hat, y)
                val_loss += loss_.item()

                # save the intermedia result for visualization
                y = y[0].detach().cpu().numpy()
                y_hat = y_hat[0].detach().cpu().numpy()
                y = np.transpose(y, (1, 2, 0))
                y_hat = np.transpose(y_hat, (1, 2, 0))
                # if args.normalization == 1:
                #     y = y * 255.0
                #     y_hat = y_hat * 255.0
                y = denormalize_(y, args.normalization)
                y_hat = denormalize_(y_hat, args.normalization)
                # clip is really important, otherwise the anomaly rgb noise data exists
                y = np.clip(y, 0.0, 255.0)
                y_hat = np.clip(y_hat, 0.0, 255.0)

                _res = np.concatenate([y_hat, y], axis=1).astype(np.uint8)
                cv2.imwrite(
                    os.path.join(args.log_img_root, f'{epoch}_{batch}.png'),
                    _res)

            val_loss = val_loss / (batch + 1)
            logger.info("Epoch-%d, validation loss: %.4f" % (epoch, val_loss))
            writer.add_scalar(f'Val/loss', val_loss, global_step=epoch)

        # adjust the learning rate
        lr_schedule.step(epoch=epoch)
        writer.add_scalar(f'Train/lr',
                          lr_schedule.get_lr()[0],
                          global_step=epoch)

        # save the best validation psnr model parameters
        if best_val_loss > val_loss:
            best_val_loss = val_loss
            model.eval().cpu()
            torch.save(model.state_dict(), args.weight_pth)
            logger.info(f"Save {args.weight_pth}")
            model.to(device).train()

        # log the training status
        model.eval().cpu()
        torch.save(model.state_dict(), args.status_pth)
        model.to(device).train()
        status_ = {
            'epoch': epoch,
            'lr': lr_schedule.get_lr()[0],
            'best_val_loss': best_val_loss,
            'last_weight_pth': args.status_pth,
        }
        log_status(args.status_logger, **status_)
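
The training loop above makes runs resumable by writing the current epoch, learning rate, best validation loss and checkpoint path through `log_status` and reading them back through `load_status` when `continue_train` is set. A minimal sketch of such a status logger, assuming a single JSON file per run (the function names come from the example; the implementation is an assumption):

import json


def log_status(status_path, **fields):
    # Overwrite the status file with the latest training state.
    with open(status_path, 'w') as f:
        json.dump(fields, f, indent=2)


def load_status(status_path):
    # Read back whatever log_status() wrote for this run.
    with open(status_path, 'r') as f:
        return json.load(f)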
Example No. 37
def train_i2a_model(environment_class,  # name of the environment to train on (REQUIRED)
                    environment_model_name,  # class
                    algorithm,
                    imagination_steps,
                    seed=1,  # random seed (default: 1)
                    procs=16,  # number of processes (default: 16)
                    frames=10 ** 7,  # number of frames of training (default: 10e7)
                    log_interval=1,  # number of updates between two logs (default: 1)
                    save_interval=10,  # number of updates between two saves (default: 0, 0 means no saving)
                    frames_per_proc=None,  # number of frames per process before update (default: 5 for A2C and 128 for PPO)
                    discount=0.99,  # discount factor (default: 0.99)
                    lr=7e-4,  # learning rate for optimizers (default: 7e-4)
                    gae_lambda=0.95,  # lambda coefficient in GAE formula (default: 0.95, 1 means no gae)
                    entropy_coef=0.01,  # entropy term coefficient (default: 0.01)
                    value_loss_coef=0.5,  # value loss term coefficient (default: 0.5)
                    max_grad_norm=0.5,  # maximum norm of gradient (default: 0.5)
                    recurrence=1,  # number of steps the gradient is propagated back in time (default: 1)
                    optim_eps=1e-5,  # Adam and RMSprop optimizer epsilon (default: 1e-5)
                    optim_alpha=0.99,  # RMSprop optimizer apha (default: 0.99)
                    clip_eps=0.2,  # clipping epsilon for PPO (default: 0.2)
                    epochs=4,  # number of epochs for PPO (default: 4)
                    batch_size=256,  # batch size for PPO (default: 256)
                    no_instr=False,  # don't use instructions in the model
                    no_mem=False,  # don't use memory in the model
                    note=None,  # name suffix
                    tensorboard=True):
    saved_arguments = locals()

    date_suffix = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    note = note + "_" if note else ""

    model_name = "I2A-{}_{}{}_s{}_{}".format(imagination_steps, note, environment_name(environment_class), seed, date_suffix)
    model_dir = utils.get_model_dir(model_name)

    # Define logger, CSV writer and Tensorboard writer
    logger = utils.get_logger(model_dir)
    csv_file, csv_writer = utils.get_csv_writer(model_dir)

    if tensorboard:
        from tensorboardX import SummaryWriter
        tb_writer = SummaryWriter(model_dir)

    # Log command and all script arguments
    logger.info("{}\n".format(saved_arguments))

    # Set seed for all randomness sources
    utils.seed(seed)

    # Load training status
    try:
        status = utils.load_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}

    # Restore training counters from the loaded status

    num_frames = status["num_frames"]
    total_start_time = time.time()
    update = status["update"]

    # Define actor-critic model: wrap the pretrained environment model in an I2A agent
    environment_model = utils.load_model(utils.get_model_dir(environment_model_name))
    i2a_model = I2AModel(environment_class, environment_model, imagination_steps)

    algorithm.load_acmodel(i2a_model)

    logger.info("Using environment model: {}\n".format(environment_model_name))
    logger.info("{}\n".format(environment_model))

    logger.info("Agent architecture:\n")
    logger.info("{}\n".format(i2a_model))

    while num_frames < frames:
        # Update model parameters

        update_start_time = time.time()
        logs = algorithm.update_parameters()
        update_end_time = time.time()

        num_frames += logs["num_frames"]
        update += 1

        # Print logs

        if update % log_interval == 0:
            fps = logs["num_frames"] / (update_end_time - update_start_time)
            duration = int(time.time() - total_start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
            data += num_frames_per_episode.values()
            header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm", "distillation_loss"]
            data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"], logs["distillation_loss"]]

            logger.info(
                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f} | dL {:.3f}"
                    .format(*data))

            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            if status["num_frames"] == 0:
                csv_writer.writerow(header)
            csv_writer.writerow(data)
            csv_file.flush()

            if tensorboard:
                for field, value in zip(header, data):
                    tb_writer.add_scalar(field, value, num_frames)

            status = {"num_frames": num_frames, "update": update}
            utils.save_status(status, model_dir)

        # Save vocabulary and model

        if save_interval > 0 and update % save_interval == 0:
            utils.save_model(algorithm.acmodel, model_dir)
            logger.info("Model successfully saved")
Exemplo n.º 38
0
def run(full_args: Namespace, return_models: bool = False):
    if sys.argv[0].startswith("train"):
        import os
        full_args.out_dir = os.path.dirname(sys.argv[1])

    args = full_args.main
    agent_args = full_args.agent
    model_args = full_args.model
    extra_logs = getattr(full_args, "extra_logs", None)
    main_r_key = getattr(full_args, "main_r_key", None)

    if args.seed == 0:
        args.seed = full_args.run_id + 1
    max_eprews = args.max_eprews
    max_eprews_window = getattr(args, "max_eprews_window", 1)

    post_process_args(agent_args)
    post_process_args(model_args)

    model_dir = getattr(args, "model_dir", full_args.out_dir)
    print(model_dir)

    # ==============================================================================================
    # @ torc_rl repo original

    # Define logger, CSV writer and Tensorboard writer

    logger = utils.get_logger(model_dir)
    csv_file, csv_writer = utils.get_csv_writer(model_dir)
    tb_writer = None
    if args.tb:
        from tensorboardX import SummaryWriter
        tb_writer = SummaryWriter(model_dir)

    # Log command and all script arguments

    logger.info("{}\n".format(" ".join(sys.argv)))
    logger.info("{}\n".format(args))

    # ==============================================================================================
    # Set seed for all randomness sources
    utils.seed(args.seed)

    # ==============================================================================================
    # Generate environments

    envs = []

    # Get env wrappers - must be a list of elements
    wrapper_method = getattr(full_args.env_cfg, "wrapper", None)
    if wrapper_method is None:

        def idem(x):
            return x

        env_wrapper = idem
    else:
        env_wrappers = [getattr(gym_wrappers, w_p) for w_p in wrapper_method]

        def env_wrapp(w_env):
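            # Apply the configured wrappers in reverse order, so the first
            # wrapper in the config list ends up outermost around the env.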
            for wrapper in env_wrappers[::-1]:
                w_env = wrapper(w_env)
            return w_env

        env_wrapper = env_wrapp

    actual_procs = getattr(args, "actual_procs", None)
    no_actions = getattr(full_args.env_cfg, "no_actions", 6)

    if actual_procs:
        # Split envs in chunks
        no_envs = args.procs
        envs, chunk_size = get_envs(full_args,
                                    env_wrapper,
                                    no_envs,
                                    n_actions=no_actions)
        first_env = envs[0][0]
        print(
            f"No. of envs per proc: {chunk_size}; No. of processes: {len(envs[1:])} + Master"
        )
    else:
        for i in range(args.procs):
            env = env_wrapper(gym.make(args.env))
            env.max_steps = full_args.env_cfg.max_episode_steps

            env.seed(args.seed + 10000 * i)
            envs.append(env)
        first_env = envs[0]

    # Generate evaluation envs
    eval_envs = []
    eval_episodes = getattr(full_args.env_cfg, "eval_episodes", 0)
    if full_args.env_cfg.no_eval_envs > 0:
        no_envs = full_args.env_cfg.no_eval_envs
        eval_envs, chunk_size = get_envs(full_args,
                                         env_wrapper,
                                         no_envs,
                                         n_actions=no_actions)

    # Define obss preprocessor
    max_image_value = full_args.env_cfg.max_image_value
    normalize_img = full_args.env_cfg.normalize
    permute = getattr(full_args.env_cfg, "permute", False)
    obss_preprocessor = getattr(full_args.env_cfg, "obss_preprocessor", None)
    obs_space, preprocess_obss = utils.get_obss_preprocessor(
        args.env,
        first_env.observation_space,
        model_dir,
        max_image_value=max_image_value,
        normalize=normalize_img,
        permute=permute,
        type=obss_preprocessor)

    first_obs = first_env.reset()
    if "state" in first_obs:
        full_state_size = first_obs["state"].shape

        # Add full size shape
        add_to_cfg(full_args, MAIN_CFG_ARGS, "full_state_size",
                   full_state_size)

    if "position" in first_obs:
        position_size = first_obs["position"].shape

        # Add full size shape
        add_to_cfg(full_args, MAIN_CFG_ARGS, "position_size", position_size)

    # Add the width and height of environment for position estimation
    model_args.width = first_env.unwrapped.width
    model_args.height = first_env.unwrapped.height

    # ==============================================================================================
    # Load training status
    try:
        status = utils.load_status(model_dir)
    except OSError:
        status = {"num_frames": 0, "update": 0}

    saver = utils.SaveData(model_dir,
                           save_best=args.save_best,
                           save_all=args.save_all)
    model, agent_data, other_data = None, dict(), None
    try:
        # Continue from last point
        model, agent_data, other_data = saver.load_training_data(best=False)
        logger.info("Training data exists & loaded successfully\n")
    except OSError:
        logger.info("Could not load training data\n")

    # ==============================================================================================
    # Load Model

    if model is None:
        model = get_model(model_args,
                          obs_space,
                          first_env.action_space,
                          use_memory=model_args.mem)
        logger.info(f"Model [{model_args.name}] successfully created\n")

        # Print Model info
        logger.info("{}\n".format(model))

    if torch.cuda.is_available():
        model.cuda()
    logger.info("CUDA available: {}\n".format(torch.cuda.is_available()))

    # ==============================================================================================
    # Load Agent

    algo = get_agent(full_args.agent,
                     envs,
                     model,
                     agent_data,
                     preprocess_obss=preprocess_obss,
                     reshape_reward=None,
                     eval_envs=eval_envs,
                     eval_episodes=eval_episodes)

    has_evaluator = hasattr(algo,
                            "evaluate") and full_args.env_cfg.no_eval_envs > 0

    if return_models:
        return algo, model, envs, saver

    # ==============================================================================================
    # Train model

    prev_rewards = []
    crt_eprew = 0
    if "eprew" in other_data:
        crt_eprew = other_data["eprew"]
    num_frames = status["num_frames"]
    total_start_time = time.time()
    update = status["update"]
    update_start_time = time.time()

    while num_frames < args.frames:
        # Update model parameters

        logs = algo.update_parameters()

        num_frames += logs["num_frames"]
        update += 1

        if update % args.eval_interval == 0 and has_evaluator:
            eval_logs = algo.evaluate(eval_key=main_r_key)
            logs.update(eval_logs)

        prev_start_time = update_start_time
        update_start_time = time.time()

        # Print logs
        if update % args.log_interval == 0:
            fps = logs["num_frames"] / (update_start_time - prev_start_time)
            duration = int(time.time() - total_start_time)
            return_per_episode = utils.synthesize(logs["return_per_episode"])
            rreturn_per_episode = utils.synthesize(
                logs["reshaped_return_per_episode"])
            num_frames_per_episode = utils.synthesize(
                logs["num_frames_per_episode"])

            header = ["update", "frames", "FPS", "duration"]
            data = [update, num_frames, fps, duration]
            header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
            data += rreturn_per_episode.values()
            header += [
                "num_frames_" + key for key in num_frames_per_episode.keys()
            ]
            data += num_frames_per_episode.values()
            header += ["entropy", "value", "policy_loss", "value_loss"]
            data += [
                logs["entropy"], logs["value"], logs["policy_loss"],
                logs["value_loss"]
            ]
            header += ["grad_norm"]
            data += [logs["grad_norm"]]

            # add log fields that are not in the standard log format (for example value_int)
            extra_fields = extra_log_fields(header, list(logs.keys()))
            header.extend(extra_fields)
            data += [logs[field] for field in extra_fields]

            # print to stdout the standard log fields + fields required in config
            keys_format, printable_data = print_keys(header, data, extra_logs)
            logger.info(keys_format.format(*printable_data))

            header += ["return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            if status["num_frames"] == 0:
                csv_writer.writerow(header)
            csv_writer.writerow(data)
            csv_file.flush()

            if args.tb:
                for field, value in zip(header, data):
                    tb_writer.add_scalar(field, value, num_frames)

            status = {"num_frames": num_frames, "update": update}

            if main_r_key is None:
                crt_eprew = list(rreturn_per_episode.values())[0]
                prev_rewards.append(crt_eprew)
            else:
                crt_eprew = logs[main_r_key]
                prev_rewards.append(logs[main_r_key])

        # -- Save vocabulary and model

        if args.save_interval > 0 and update % args.save_interval == 0:
            preprocess_obss.vocab.save()

            saver.save_training_data(model, algo.get_save_data(), crt_eprew)

            logger.info("Model successfully saved")

            utils.save_status(status, model_dir)

        # Early-stop once the mean return over the last window exceeds the threshold
        if len(prev_rewards) > max_eprews_window:
            check_rew = np.mean(prev_rewards[-max_eprews_window:])
            if check_rew > max_eprews:
                print(
                    f"Reached mean return {max_eprews} for a window of {max_eprews_window} steps"
                )
                exit()
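
Both this example and the previous one resume training by reading a {"num_frames", "update"} dictionary from utils.load_status(model_dir) and writing it back with utils.save_status after each logged update. As a rough sketch, assuming the status is kept as a small JSON file inside the model directory (the file name and serialization format are assumptions), the pair could be implemented as:

import json
import os

# Hypothetical sketch of directory-based status helpers; the real utils module
# may use a different file name or format. A missing file raises
# FileNotFoundError, a subclass of OSError, matching the except clauses above.
def load_status(model_dir, filename="status.json"):
    with open(os.path.join(model_dir, filename)) as f:
        return json.load(f)

def save_status(status, model_dir, filename="status.json"):
    with open(os.path.join(model_dir, filename), "w") as f:
        json.dump(status, f)
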
Exemplo n.º 39
0
def handleOneIndex(indexpath, subdir, indexname, fast):
    inMemoryFile = "model.db"

    modeldir = os.path.join(config.getModelDir(), subdir, indexname)
    os.makedirs(modeldir, exist_ok=True)


    def cleanupInMemoryFile():
        modelfile = os.path.join(config.getInMemoryFileSystem(), inMemoryFile)
        reportfile = modelfile + config.getReportPostfix()
        if os.access(modelfile, os.F_OK):
            os.unlink(modelfile)
        if os.access(reportfile, os.F_OK):
            os.unlink(reportfile)

    def copyoutInMemoryFile(modelfile):
        inmemoryfile = os.path.join\
            (config.getInMemoryFileSystem(), inMemoryFile)
        inmemoryreportfile = inmemoryfile + config.getReportPostfix()
        reportfile = modelfile + config.getReportPostfix()

        if os.access(inmemoryfile, os.F_OK):
            utils.copyfile(inmemoryfile, modelfile)
        if os.access(inmemoryreportfile, os.F_OK):
            utils.copyfile(inmemoryreportfile, reportfile)

    def cleanupFiles(modelnum):
        modeldir = os.path.join(config.getModelDir(), subdir, indexname)
        modelfile = os.path.join( \
            modeldir, config.getCandidateModelName(modelnum))
        reportfile = modelfile + config.getReportPostfix()
        if os.access(modelfile, os.F_OK):
            os.unlink(modelfile)
        if os.access(reportfile, os.F_OK):
            os.unlink(reportfile)

    def storeModelStatus(modelfile, textnum, nexttextnum):
        #store model info in status file
        modelstatuspath = modelfile + config.getStatusPostfix()
        #create an empty status
        modelstatus = {}
        modelstatus['GenerateStart'] = textnum
        modelstatus['GenerateEnd'] = nexttextnum
        utils.sign_epoch(modelstatus, 'Generate')
        utils.store_status(modelstatuspath, modelstatus)

    print(indexpath, subdir, indexname)

    indexstatuspath = indexpath + config.getStatusPostfix()
    indexstatus = utils.load_status(indexstatuspath)
    if not utils.check_epoch(indexstatus, 'MergeSequence'):
        raise utils.EpochError('Please mergeseq first.\n')
    if utils.check_epoch(indexstatus, 'Generate'):
        return

    #continue generating
    textnum, modelnum, aggmodelsize = 0, 0, 0
    if 'GenerateTextEnd' in indexstatus:
        textnum = indexstatus['GenerateTextEnd']
    if 'GenerateModelEnd' in indexstatus:
        modelnum = indexstatus['GenerateModelEnd']

    #clean up previous file
    if fast:
        cleanupInMemoryFile()

    cleanupFiles(modelnum)

    #begin processing
    indexfile = open(indexpath, 'r')
    for i, oneline in enumerate(indexfile.readlines()):
        #continue last generating
        if i < textnum:
            continue

        #remove trailing '\n'
        oneline = oneline.rstrip(os.linesep)
        (title, textpath) = oneline.split('#')
        infile = config.getTextDir() + textpath
        infilesize = utils.get_file_length(infile + config.getMergedPostfix())
        if infilesize < config.getMinimumFileSize():
            print("Skipping " + title + '#' + textpath)
            continue

        if fast:
            modelfile = os.path.join(config.getInMemoryFileSystem(), \
                                         inMemoryFile)
        else:
            modelfile = os.path.join(modeldir, \
                                         config.getCandidateModelName(modelnum))

        reportfile = modelfile + config.getReportPostfix()
        print("Proccessing " + title + '#' + textpath)
        if generateOneText(infile, modelfile, reportfile):
            aggmodelsize += infilesize
        print("Processed " + title + '#' + textpath)
        if aggmodelsize > config.getCandidateModelSize():
            #copy out in memory file
            if fast:
                modelfile = os.path.join\
                    (modeldir, config.getCandidateModelName(modelnum))
                copyoutInMemoryFile(modelfile)
                cleanupInMemoryFile()

            #the model file is on disk now
            nexttextnum = i + 1
            storeModelStatus(modelfile, textnum, nexttextnum)

            #new model candidate
            aggmodelsize = 0
            textnum = nexttextnum
            modelnum += 1

            #clean up next file
            cleanupFiles(modelnum)

            #save current progress in status file
            indexstatus['GenerateTextEnd'] = nexttextnum
            indexstatus['GenerateModelEnd'] = modelnum
            utils.store_status(indexstatuspath, indexstatus)


    #copy out in memory file
    if fast:
        modelfile = os.path.join\
            (modeldir, config.getCandidateModelName(modelnum))
        copyoutInMemoryFile(modelfile)
        cleanupInMemoryFile()

    #the model file is on disk now
    nexttextnum = i + 1
    storeModelStatus(modelfile, textnum, nexttextnum)

    indexfile.close()
    #end processing

    #save current progress in status file
    modelnum += 1
    indexstatus['GenerateTextEnd'] = nexttextnum
    indexstatus['GenerateModelEnd'] = modelnum

    utils.sign_epoch(indexstatus, 'Generate')
    utils.store_status(indexstatuspath, indexstatus)
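
The example above drives its restart logic entirely through status files: check_epoch asks whether a stage such as 'MergeSequence' or 'Generate' has already been signed off, and sign_epoch records that a stage finished before store_status writes the dictionary back to disk. A minimal sketch of those two markers, assuming the status dict simply stores a per-stage flag (the real key layout may differ), is:

# Hypothetical sketch of the epoch markers used above. The key naming is an
# assumption; the real helpers may store more detail alongside the flag.
def check_epoch(status, epoch):
    return bool(status.get(epoch + 'Epoch', False))

def sign_epoch(status, epoch):
    status[epoch + 'Epoch'] = True
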