Exemplo n.º 1
0
    def predict(self, data):
        inmsg = data.split('\n')
        #inmsg = ['エージェントが可能な変更を行った後動作を開始しました']

        log.debug('predict: predict data=%s', inmsg)
        dictfile = self.model_path + '/dict.dat'
        configFile = self.model_path + '/analog.ini'
        dict = MyDict(dictfile)
        dict.loadict()
        print 'dict count: %d' % dict.countDictWords()

        ## 採用メッセージID
        acceptid = self.fload(self.model_path + '/accept.dat')
        print '**accept id:'
        print acceptid

        ## エラー定義の読み込み
        with open(self.model_path + '/error.json', 'r') as f:
            errjson = json.load(f, "utf-8")

        ## RNNメッセージID化
        vocab = 10000
        dim = 100
        y = 94
        train = Trainer(vocab, dim, y)
        train.load(self.model_path + '/train')

        labellist = []  ## ラベル一覧(出力)
        #msglist = []    ## メッセージ一覧(無視リスト作成のため使用メッセージを保存)
        analizeRnnInputList = []
        analizeRnnLabelList = []
        msgidlist = []
        for msgtxt in inmsg:
            ## wakachiで処理するため明示的にunicodeにする
            ## ex) cmapeerdが停止したためプロセスを再起動します
            #w = u'' + msgtxt
            #words = dict.wakachi(w.encode('utf-8'))
            words = dict.wakachi(msgtxt)
            ids = dict.words2ids(words)
            ## 改行のみはスキップ
            if ids == []:
                continue
            ## ex) [113, 63, 11, 9, 13, 104, 6, 7, 66, 8, 9, 10]
            y = train.predict(ids[::-1])
            wordid = y.data.argmax(1)[0]  ## ex) 84

            ## エラーリストにあるメッセージ以外は処理にしない
            if wordid in acceptid:
                msgidlist.append(wordid)
            else:
                print 'not accept id:%d, msg:%s' % (wordid, msgtxt)
                continue
            print '** msg prdict %d, %2f' % (wordid, y.data[0][wordid]
                                             )  ## 選択されたidの確率表示
        pp(inmsg)
        pp(msgidlist)

        ## 対象メッセージが無い
        if msgidlist == []:
            result = [{'score': 100, 'id': None, 'label': None}]
            return result

        ## エラーケースの予測
        config = ConfigParser.SafeConfigParser()
        config.read(configFile)

        dim_in = config.getint('analize', 'dim_in')
        dim_mid = config.getint('analize', 'dim_mid')
        dim_out = config.getint('analize', 'dim_out')
        """
        train = AnazlizeTrainer(dim_in, dim_mid, dim_out)
        train.load(self.model_path + '/train_analize')
        y = train.predict(msgidlist[::-1])
        print y.data.argmax(1)[0]
        rank = y.data.argsort()[0]
        uprank = map(int, rank[::-1])
        print uprank
        #print y.data[0]
        """

        train = AnazlizeTrainer(dim_in, dim_mid, dim_out)
        train.load(self.model_path + '/train_analize')
        #y = train.predict(msgidlist[::-1])

        targetlist = msgidlist
        resultlist = np.zeros(dim_out)  ## 確率結果の最大値を格納する配列
        print resultlist
        for i in range(len(targetlist)):
            target = targetlist[i:]
            y = train.predict(target[::-1])
            print target
            print y.data[0]
            for i in range(len(y.data[0])):
                if y.data[0][i] > resultlist[i]:
                    resultlist[i] = y.data[0][i]

        print resultlist
        #print y.data.argmax(1)[0]
        #rank = y.data.argsort()[0]
        rank = resultlist.argsort()
        uprank = map(int, rank[::-1])
        print uprank
        #print y.data[0]

        result = []
        for i in uprank:
            print '%d, %2f' % (i, resultlist[i])
            item = {
                'score': round(float(resultlist[i]) * 100, 2),
                'id': i,
                'label': errjson[i]['label']
            }
            result.append(item)

        return result
Exemplo n.º 2
0
def fsave(filename, data):
    f = open(filename, "w")
    pickle.dump(data, f)
    f.close()
    return


## Test Main
if __name__ == "__main__":
    dictfile = 'dict.dat'
    configFile = 'analog.ini'
    acceptFile = 'accept.dat'   ## 採用するメッセージid一覧
    dict = MyDict(dictfile)
    dict.loadict()
    print 'dict count: %d' % dict.countDictWords()
    
    ## RNNメッセージID化
    vocab=10000
    dim=100
    y=94
    train = Trainer(vocab, dim, y)
    train.load('train')
    
    labellist = []  ## ラベル一覧(出力)
    #msglist = []    ## メッセージ一覧(無視リスト作成のため使用メッセージを保存)
    analizeRnnInputList = []
    analizeRnnLabelList = []
    
    arg1 = sys.argv[1] if len(sys.argv) == 2 else None
    ## 学習 ##