def train():
    """Train the 'version1.0' Grocery model on the global ``trainlist`` and save it."""
    dots = '.' * 30
    print('train start ' + dots)
    model = Grocery('version1.0')
    model.train(trainlist)
    model.save()
    print('train end ' + dots)
示例#2
0
def tGrocery():
    outFile = open('testResult.tmp', 'w')
    [trainingSet, benchmark] = pickle.load(open('SampleSeg.pk'))
    testingSet = []
    correctLabel = []
    for i in xrange(len(benchmark)):
        print '%d out of %d' % (i, len(benchmark))
        testingSet.append(benchmark[i][1])
        correctLabel.append(benchmark[i][0]) 
    grocery = Grocery('test')
    grocery.train(trainingSet)
    grocery.save()
    # load
    new_grocery = Grocery('test')
    new_grocery.load()
    Prediction = []
    for i in xrange(len(testingSet)):
        print '%d out of %d' % (i, len(testingSet))
        prediction = new_grocery.predict(testingSet[i])
        Prediction.append(prediction)
        temp = correctLabel[i] + '<-->' + prediction + '  /x01' + testingSet[i] + '\n'
        outFile.write(temp)
    correct = 0
    for i in xrange(len(Prediction)):
        print Prediction[i], correctLabel[i],
        if Prediction[i] == correctLabel[i]:
            correct += 1
            print 'Correct'
        else:
            print 'False'
    print 'Correct Count:', correct
    print 'Accuracy: %f' % (1.0 * correct / len(Prediction))
示例#3
0
class GroceryModel(object):
    """tgrocery text classifier persisted under the name 'TextClassify'.

    Dataset files are UTF-8 text, one sample per line: the label, a tab, then
    the text. Any further tab-separated fields are re-joined with no
    separator, and the trailing newline stays part of the text.
    """

    def __init__(self):
        self.grocery = Grocery('TextClassify')

    @staticmethod
    def _read_dataset(path):
        """Return ``[(label, text), ...]`` parsed from the UTF-8 file at *path*."""
        import io
        dataset = []
        # newline='\n' splits only on LF and leaves line endings untranslated,
        # matching the former readline() + decode('utf8') loop; the with-block
        # closes the file even if decoding fails.
        with io.open(path, 'r', encoding='utf8', newline='\n') as f:
            for line in f:
                fields = line.split('\t')
                dataset.append((fields[0], ''.join(fields[1:])))
        return dataset

    def train(self, train_file):
        """Train on the dataset file *train_file* and save the model."""
        self.grocery.train(self._read_dataset(train_file))
        self.grocery.save()

    def load_model(self):
        """Load the saved model."""
        self.grocery.load()

    def test(self, test_src):
        """Load the model, evaluate it on the dataset file *test_src* and print the result."""
        self.load_model()
        result = self.grocery.test(self._read_dataset(test_src))
        print(result)

    def predict(self, text):
        """Print the model's prediction for *text*."""
        print(self.grocery.predict(text))
示例#4
0
class AutoGrocery(object):
    """tgrocery classifier that loads its saved model on demand.

    If loading raises ValueError, the model is trained on the data given at
    construction and saved first. The model is stored at
    ``project_dir + '/models/model_data/' + name``.
    """

    def __init__(self, name, train_data):
        self._train_data = train_data
        model_path = project_dir + '/models/model_data/' + name
        self._grocery = Grocery(model_path)

    def train(self):
        """Fit the model on the training data supplied at construction."""
        self._grocery.train(self._train_data)

    def save(self):
        """Persist the model."""
        self._grocery.save()

    def load(self):
        """Load the persisted model."""
        self._grocery.load()

    def _ensure_ready(self):
        """Load the model if not loaded yet; on ValueError train and save instead."""
        if self._grocery.get_load_status():
            return
        try:
            self.load()
        except ValueError:
            self.train()
            self.save()

    def predicate(self, src):
        """Classify *src*; return ``(label, decision value for that label)``."""
        self._ensure_ready()
        outcome = self._grocery.predict(src)
        best = outcome.predicted_y
        return best, outcome.dec_values[best]
示例#5
0
文件: Classify.py 项目: TimePi/Python
 def __train__model__():
     """Train the 'Classify' model from the Excel file and cache it on Classify.

     Uses the u'类型' column as the label and u'释义' as the text.
     """
     columns = [u'类型', u'释义']
     frame = pd.read_excel(Classify.__FILE_PATH__)
     rows = frame[columns].values
     samples = [(row[0], row[1]) for row in rows]

     model = Grocery('Classify')
     model.train(samples)
     model.save()
     Classify.__MODEL__ = model
示例#6
0
    def __train__model__():
        """Build the 'Classify' Grocery model from the spreadsheet and store it on Classify.

        Each sheet row contributes ``(类型, 释义)`` as ``(label, text)``.
        """
        sheet = pd.read_excel(Classify.__FILE_PATH__)
        pairs = []
        for record in sheet[[u'类型', u'释义']].values:
            pairs.append((record[0], record[1]))

        classifier = Grocery('Classify')
        classifier.train(pairs)
        classifier.save()
        Classify.__MODEL__ = classifier
示例#7
0
文件: grocery.py 项目: SwoJa/ruman
def test_grocery():
    """Train 'model_redian' on trdata_4.txt, reload it and report metrics on tedata_4.txt."""
    model_name = 'model_redian'
    trainer = Grocery(model_name)
    trainer.train('trdata_4.txt')
    trainer.save()

    reloaded = Grocery(model_name)
    reloaded.load()
    report = reloaded.test('tedata_4.txt')
    print(report.accuracy_labels)
    print(report.recall_labels)
    report.show_result()
示例#8
0
def test_grocery():
    """Round-trip the 'model_redian' model through disk and show its test metrics."""
    Grocery_name = 'model_redian'
    built = Grocery(Grocery_name)
    built.train('trdata_4.txt')
    built.save()

    restored = Grocery(Grocery_name)
    restored.load()
    outcome = restored.test('tedata_4.txt')
    for metric in (outcome.accuracy_labels, outcome.recall_labels):
        print(metric)
    outcome.show_result()
示例#9
0
 def test_main(self):
     """Train, save and reload a Grocery model, then check a known prediction.

     Assertions run against the reloaded model so the save/load round trip is
     actually exercised; the model directory is removed even if they fail.
     """
     grocery = Grocery(self.grocery_name)
     try:
         grocery.train(self.train_src)
         grocery.save()
         # Reload under the name that was saved (previously hard-coded 'test').
         new_grocery = Grocery(self.grocery_name)
         new_grocery.load()
         assert grocery.get_load_status()
         assert new_grocery.get_load_status()
         # str() so the check holds whether predict returns a label or a result object.
         assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
     finally:
         # cleanup
         if self.grocery_name and os.path.exists(self.grocery_name):
             shutil.rmtree(self.grocery_name)
示例#10
0
 def test_main(self):
     """Check that a trained, saved and reloaded Grocery model predicts 'education'.

     The reloaded instance (same name as the saved one) is what gets asserted
     on; the model directory is always cleaned up, even when an assert fails.
     """
     grocery = Grocery(self.grocery_name)
     try:
         grocery.train(self.train_src)
         grocery.save()
         # Use self.grocery_name: the old hard-coded 'test' could load a different model.
         new_grocery = Grocery(self.grocery_name)
         new_grocery.load()
         assert grocery.get_load_status()
         assert new_grocery.get_load_status()
         # str() makes the comparison independent of predict's return type.
         assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
     finally:
         # cleanup
         if self.grocery_name and os.path.exists(self.grocery_name):
             shutil.rmtree(self.grocery_name)
def sentiment_train(gro_name, train_set):
    """Train a tgrocery sentiment classifier and save it.

    :param gro_name: name of the Grocery model (also used when saving it).
    :param train_set: training samples passed straight to ``Grocery.train``.
    :return: None
    """
    model = Grocery(gro_name)
    model.train(train_set)
    status = model.get_load_status()
    print("Is trained? ", status)
    model.save()
示例#12
0
def sentiment_train(gro_name, train_set):
    """Fit a tgrocery SVM sentiment model named *gro_name* and persist it.

    :param gro_name: Grocery model name.
    :param train_set: samples handed to ``Grocery.train`` unchanged.
    :return: None; prints the model's load status after training.
    """
    classifier = Grocery(gro_name)
    classifier.train(train_set)
    loaded_flag = classifier.get_load_status()
    print("Is trained? ", loaded_flag)
    classifier.save()
示例#13
0
def train(train_origin_path, fold):
    """Train and save the tgrocery model for cross-validation fold *fold*.

    Each non-blank line of *train_origin_path* is ``<label>|text|<text>``; the
    raw label is mapped through ``classify_dict`` and then ``yiji_label``.
    The model is saved under the name ``cv_<fold>_model``.
    """
    grocery = Grocery('cv_' + str(fold) + '_model')  # , custom_tokenize=segment)

    train_src = []
    with open(train_origin_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) used to raise ValueError.
                continue
            # maxsplit=1 keeps any later "|text|" inside the text instead of crashing.
            label, text = line.split("|text|", 1)
            train_src.append((yiji_label[classify_dict[label]], text))

    grocery.train(train_src)
    grocery.save()
示例#14
0
    def train_phrasing_and_save(self, trainsets=all):
        """Train a Grocery model named ``self.model_name`` on *trainsets* and save it.

        :param trainsets: training data passed to ``Grocery.train``.
            NOTE(review): the default ``all`` is the builtin function unless this
            module rebinds that name; passing the builtin would make training
            fail and this method return False — confirm the intended default.
        :return: True on success, False if creating, training or saving raised.
        """
        try:
            grocery = Grocery(self.model_name)
            grocery.train(trainsets)
            grocery.save()
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
            return False
        return True
示例#15
0
 def test_main(self):
     """Train, save and reload a model, print sample predictions and check two labels.

     Predictions come from the reloaded model so the save/load round trip is
     tested; the model directory is removed even if an assertion fails.
     """
     grocery = Grocery(self.grocery_name)
     try:
         grocery.train(self.train_src)
         grocery.save()
         # Reload the model that was just saved (name was hard-coded to 'test').
         new_grocery = Grocery(self.grocery_name)
         new_grocery.load()
         assert grocery.get_load_status()
         assert new_grocery.get_load_status()
         result = new_grocery.predict('just a testing')
         print(result)
         result = new_grocery.predict('考生必读:新托福写作考试评分标准')
         print(result)
         print("type of result is :",type(result))
         assert str(new_grocery.predict('考生必读:新托福写作考试评分标准')) == 'education'
         assert str(new_grocery.predict('法网')) == 'sports'
     finally:
         # cleanup
         if self.grocery_name and os.path.exists(self.grocery_name):
             shutil.rmtree(self.grocery_name)
示例#16
0
class MyGrocery(object):
  def __init__(self, name):
    super(MyGrocery, self).__init__()
    self.grocery = Grocery(name)
    self.loaded = False
    self.correct = 1.0

  def train(self, src):
    lines = []
    for line in csv.reader(open(src)):
      label, s = line[0],line[1]
      text = s.decode('utf8')
      lines.append((label, text))
    self.grocery.train(lines)

  def save_model(self):
    self.grocery.save()

  def train_and_save(self, src):
    self.train(src)
    self.save_model()

  def load_model(self):
    if not self.loaded:
      self.grocery.load()
      self.loaded = True

  def predict(self, text):
    self.load_model()
    return self.grocery.predict(text)

  def test(self, src):
    self.load_model()
    total, wrong_num = 0.0, 0.0
    for line in csv.reader(open(src)):
      total += 1
      if line[0] != self.predict(line[1]):
        wrong_num += 1

    print "load test file from " + src
    correct = (total - wrong_num ) / total
    self.correct = correct
    print "total: %d , wrong_num: %d, success percentage: %f" %(total, wrong_num, correct)
    result = dict(type="test", total=total, wrong_num=wrong_num, correct=correct)
    return json.dumps(result)
示例#17
0
def tgrocery_train(train_data, test_data):
    """Train a 'TextGrocery' model, reload it from disk and predict the test corpus.

    Returns:
        (test_corpus, test_label, predict_label): the two values produced by
        ``test_split(test_data)`` plus ``str()`` of each prediction, in
        test_corpus order.
    """
    print("训练语料总数为: " + str(len(train_data)))
    test_corpus, test_label = test_split(test_data)

    trainer = Grocery('TextGrocery')
    print("start training......")
    trainer.train(train_data)
    trainer.save()

    model = Grocery('TextGrocery')
    model.load()

    predict_label = [str(model.predict(sample)) for sample in test_corpus]
    return test_corpus, test_label, predict_label
示例#18
0
def learn_model(file_name):
    """Train a tgrocery model from the Excel file ``BASE_DIR/learn/<file_name>``.

    Rows with missing values are dropped and ``split_comment`` is applied per
    row. The data is split 3:2 into learn/test files by ``output_file``, and a
    model named ``model_<learn file stem>`` is trained, saved under
    ``BASE_DIR/learn`` and scored with ``test_sample``.

    Returns:
        dict with 'IsErr' (bool) and 'ErrDesc' (user-facing message).
    """
    path = os.path.join(BASE_DIR, 'learn', file_name)
    try:
        df = pd.read_excel(path)
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'找不到文档或者读取文档出错'}
    try:
        # Drop rows that contain missing values.
        df = df.dropna(axis=0)
        df = df.apply(split_comment, axis=1)
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'文档格式有误,应包含Tag(标签名字),Comment(评价内容)'}

    try:
        # Learn/test split 3:2. Floor division keeps this an int on Python 3
        # ('/' yields a float there) and gives the same value as Python 2.
        len_learn = len(df) // 5 * 3
        # Write the learn and test files.
        learn_file_name, test_file_name = output_file(file_name, df, len_learn)
        tmp_learn_name = os.path.join(BASE_DIR, 'learn',
                                      'model_' + learn_file_name.split('.')[0])
        grocery = Grocery(tmp_learn_name.encode('utf-8'))
        path = os.path.join(BASE_DIR, 'learn', learn_file_name)
        grocery.train(path.encode('utf-8'))
        grocery.save()
    except Exception:
        return {'IsErr': True, 'ErrDesc': u'学习不成功,没有生产新的模型,请再次尝试。'}

    # Evaluate on the held-out test file.
    res = test_sample(tmp_learn_name, test_file_name)

    return {
        'IsErr':
        False,
        'ErrDesc':
        u'成功生产新的模型,测试验证的正确率为%s, 模型保存为:%s' %
        (res, os.path.split(tmp_learn_name)[1])
    }
示例#19
0
from tgrocery import Grocery
# Open a new "grocery" (text classifier) -- it must be given a name.
grocery = Grocery('sample')
# Training samples can be passed as a list of (label, text) tuples.
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与')
]
grocery.train(train_src)
# A file path also works (tab-separated by default; a custom delimiter is supported):
# grocery.train('train_ch.txt')
# Save the model.
grocery.save()
# Load the model (using the same name it was saved under).
new_grocery = Grocery('sample')
new_grocery.load()
# Predict; print() so the result is visible when run as a script, not only in a REPL.
print(new_grocery.predict('考生必读:新托福写作考试评分标准'))
# expected output: education

# Evaluate on labelled samples.
test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]
# Prints the test accuracy, e.g. 0.5.
print(new_grocery.test(test_src))
示例#20
0
# coding=utf-8
from tgrocery import Grocery

# Labelled samples used to train and then evaluate the 'sample' model.
train_src = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
]
test_src = [
    ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
    ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
]

# Train and persist the model.
grocery = Grocery('sample')
grocery.train(train_src)
grocery.save()

# Reload it from disk and classify an English headline.
new_grocery = Grocery('sample')
new_grocery.load()
english_headline = 'Abbott government spends $8 million on higher education media blitz'
print(new_grocery.predict(english_headline))

# Evaluate the reloaded model on the labelled test samples.
print("start test..................")
print(new_grocery.test(test_src))
示例#21
0
def demo_flask(image_file):
    """OCR an image and return the text of the lines classified as an address.

    Loads the 'Addrss_NLP' tgrocery model from disk if a path with its name
    exists; otherwise it trains the model from pkl_data/address1.pkl plus
    pkl_data/others1.pkl and saves it. Each OCR line is classified, with
    keyword/digit heuristics forcing 'others'. Lines labelled 'address' have
    their brackets normalised and are accumulated. Trailing annotations are
    then stripped by regex.

    Args:
        image_file: path of the input image.

    Returns:
        (output_file, ret_total): path of the framed image written under
        /data/share/nginx/html/bbox, and the extracted address ('' if none).
    """
    grocery = Grocery('Addrss_NLP')
    model_name = grocery.name
    text_converter = None
    if os.path.exists(model_name):
        # Attach the saved model directly instead of retraining.
        tgM = GroceryTextModel(text_converter, model_name)
        tgM.load(model_name)
        grocery.model = tgM
        print('load!!!!!')
    else:
        # NOTE(review): pickle.load can execute arbitrary code; these files
        # must come from a trusted source.
        with open('pkl_data/address1.pkl', 'rb') as add_file:
            add_list = pickle.load(add_file)
        with open('pkl_data/others1.pkl', 'rb') as other_file:
            other_list = pickle.load(other_file)
        add_list.extend(other_list)
        grocery.train(add_list)
        print(grocery.get_load_status())
        grocery.save()
    t = time.time()
    result_dir = '/data/share/nginx/html/bbox'
    image = np.array(Image.open(image_file).convert('RGB'))
    result, image_framed = ocr_whole.model(image)
    output_file = os.path.join(result_dir, image_file.split('/')[-1])
    Image.fromarray(image_framed).save(output_file)
    ret_total = ''
    for key in result:
        string1 = result[key][1]
        # Remove whitespace and ASCII/CJK punctuation before classifying
        # (raw string: same pattern, without invalid-escape warnings).
        string2 = re.sub(r"[\s+\.\!\/_,$%^*(+\"\']+|[+——!,。?、~@#¥%……&*{}[]+", "", string1)
        no_digit = len(list(filter(str.isdigit, string2)))
        no_alpha = len(list(filter(is_alphabet, string2)))
        if ('注册' in string2 or '洼册' in string2 or '洼·册' in string2
                or '洼.册' in string2 or '汪·册' in string2
                or len(set('登记机关') & set(string2)) >= 3
                or '电话' in string2
                # no_digit > 5 is tested first, so an empty string2 never reaches
                # the division (it used to raise ZeroDivisionError).
                or (no_digit > 5 and no_digit / len(string2) > 0.7)):
            predict_result = 'others'
        elif (no_alpha > 5 or len(set('经营范围化学品') & set(string2)) >= 3
              or len(set('年月日') & set(string2)) >= 2):
            predict_result = 'others'
        else:
            predict_result = grocery.predict(string2)
        if str(predict_result) == 'address':
            string1 = string1.replace('《', '(')
            string1 = string1.replace('》', ')')
            string1 = string1.replace('(', '(')
            string1 = string1.replace(')', ')')
            string1 = string1.replace('((', '(')
            # A line longer than the accumulated text replaces it; otherwise append.
            if not ret_total or len(string1) > len(ret_total):
                ret_total = string1
            else:
                ret_total += string1

    if ')' in ret_total and '(' not in ret_total:
        # Presumably OCR misreads '(' as 'C' when only ')' survived -- verify.
        ret_total = ret_total.replace('C', '(')
    # Strip trailing annotations that start with '(' plus a keyword. The '(' is
    # escaped: unescaped, these patterns were unbalanced and re.sub raised re.error.
    for pattern in (r'\((\w)住所(.*)', r'\((\w)住房(.*)', r'\(不作为(.*)',
                    r'\(有效期(.*)', r'\(仅限(.*)', r'\(临时经营(.*)',
                    r'\(仅限办公(.*)', r'\(经营场所(.*)'):
        ret_total = re.sub(pattern, '', ret_total)
    ret_total = re.sub(r"^[经]*[营]*[场/住]*[所]*", "", ret_total)
    ret_total = stupid_revise(ret_total)
    print("Mission complete, it took {:.3f}s".format(time.time() - t))
    print('\nRecongition Result:\n')
    print(ret_total)
    return output_file, ret_total
示例#22
0
class TagPredictor(object):
    """Train, evaluate and persist a tgrocery ``Grocery`` tag classifier.

    Constructor keyword arguments:
        grocery_name: model name passed to ``Grocery``; also the stem of the
            ``<name>_train.svm`` file that save/loadTrainModel move around.
        method: "normal" (keyExt tokens), "processed" (comma-separated tokens)
            or "jieba" (Grocery's default tokenizer).
        train_src: path of the training file, one ``label<TAB>text`` per line.
    """

    def _custom_tokenize(self, line, **kwargs):
        """Split *line* into tokens; ``kwargs["method"]`` overrides the constructor's.

        Raises:
            ValueError: for a method other than "normal" or "processed"
                (previously an UnboundLocalError).
        """
        if "method" in kwargs:
            method = str(kwargs["method"])
        else:
            method = str(self.kwargs["method"])
        if method == "normal":
            return self.key_ext.calculateTokens(line,
                                                doc_len_lower_bound=5,
                                                doc_len_upper_bound=500,
                                                method="normal")
        if method == "processed":
            return line.split(',')
        raise ValueError("unsupported tokenize method: %r" % (method,))

    def __init__(self, *args, **kwargs):
        self.grocery_name = str(kwargs["grocery_name"])
        method = str(kwargs["method"])
        if "train_src" not in kwargs:
            # train_src is read later through self.kwargs; fail fast as before.
            raise KeyError("train_src")
        settings = conf.load("predict_label")
        self.PREFIX = settings["prefix"]
        self.MODEL_DIR = settings["model_dir"]

        self.kwargs = kwargs
        if method == "normal":
            self.key_ext = keyExt()
        if method in ("normal", "processed"):
            self.grocery = Grocery(self.grocery_name,
                                   custom_tokenize=self._custom_tokenize)
        elif method == "jieba":
            self.grocery = Grocery(self.grocery_name)
        else:
            # Previously left self.grocery unset, failing later with AttributeError.
            raise ValueError("unsupported method: %r" % (method,))

    def _load_train_data(self):
        """Read ``(label, text)`` pairs from the ``train_src`` file.

        Lines are split on the first tab; lines without a tab are skipped and
        the text is cut at the first newline.
        """
        train_data = []
        with open(self.kwargs["train_src"], 'rb') as f:
            for line in f:
                parts = line.split('\t', 1)
                if len(parts) < 2:
                    continue
                train_data.append((parts[0], parts[1].split('\n', 1)[0]))
        return train_data

    def _model_paths(self):
        """Return (working path, archived path) of the ``<name>_train.svm`` file."""
        file_name = self.grocery_name + '_train.svm'
        return self.PREFIX + file_name, self.PREFIX + self.MODEL_DIR + file_name

    def trainFromDocs(self, *args, **kwargs):
        model = self.grocery.train(self.kwargs["train_src"])
        return model

    def autoEvaluation(self, *args, **kwargs):
        """Train and test, drop labels with recall below ``threshold``, then retrain.

        Keyword args: threshold (float), excluded_labels, excluded_docs.
        Returns ``(self.grocery, test result after pruning)``.

        NOTE(review): both splits take the first n * 10 // 10 == n shuffled rows,
        so the model is evaluated on its own training data.
        """
        prune_threshold = float(kwargs["threshold"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]

        train_data = self._load_train_data()

        print("#items before filtering: %d" % len(train_data))
        print("-- Now we filter out the excluded docs --")
        train_data = [i for i in train_data if i[1] not in excluded_docs]
        print("#items after filtering: %d" % len(train_data))
        print("-- Now we filter out the excluded labels --")
        train_data = [i for i in train_data if i[0] not in excluded_labels]
        print("#items after filtering: %d" % len(train_data))

        n = len(train_data)  # number of rows in the dataset
        # NOTE(review): needs a shuffle that *returns* the shuffled sequence
        # (random.shuffle returns None) -- confirm which shuffle is imported.
        indices = shuffle(range(n))
        split = n * 10 // 10
        # Lists, not map(): the sets are iterated more than once below.
        train_set = [train_data[x] for x in indices[:split]]
        test_set = [train_data[x] for x in indices[:split]]

        self.grocery.train(train_set)
        test_result = self.grocery.test(test_set)
        print('-- Accuracy after training --')
        print('Accuracy, A-0: %s' % (test_result,))

        low_recall_label = set(label for label, recall in test_result.recall_labels.items()
                               if recall < prune_threshold)
        new_train_set = [item for item in train_set if item[0] not in low_recall_label]
        new_test_set = [item for item in test_set if item[0] not in low_recall_label]

        self.grocery.train(new_train_set)
        new_test_result = self.grocery.test(new_test_set)

        print('-- Accuracy after training, with low-recall labels (less than %s %%) pruned --'
              % (prune_threshold * 100,))
        print('Accuracy, A-1: %s' % (new_test_result,))

        return self.grocery, new_test_result

    def manualEvaluation(self, *args, **kwargs):
        """Test the stored model on ``n_docs`` random rows of the training file.

        Keyword args: n_docs, excluded_labels, excluded_docs.
        Returns ``(test_set, test_result)``.
        """
        n_docs = int(kwargs["n_docs"])
        excluded_labels = kwargs["excluded_labels"]
        excluded_docs = kwargs["excluded_docs"]

        train_data = self._load_train_data()
        train_data = [item for item in train_data if item[0] not in excluded_labels]
        train_data = [item for item in train_data if item[1] not in excluded_docs]

        n = len(train_data)  # number of rows in the dataset
        indices = shuffle(range(n))
        test_set = [train_data[x] for x in indices[0:n_docs]]
        g = self.loadTrainModel()
        test_result = g.test(test_set)
        return test_set, test_result

    def saveTrainModel(self, *args, **kwargs):
        """Save the model, then move its _train.svm file into MODEL_DIR."""
        self.grocery.save()
        local_path, stored_path = self._model_paths()
        os.rename(local_path, stored_path)

    def loadTrainModel(self, *args, **kwargs):
        """Load the model whose _train.svm file is archived in MODEL_DIR.

        The file is moved back to PREFIX for ``load()`` and returned to
        MODEL_DIR afterwards, even when ``load()`` raises.
        """
        local_path, stored_path = self._model_paths()
        os.rename(stored_path, local_path)
        try:
            self.grocery.load()
        finally:
            os.rename(local_path, stored_path)
        return self.grocery

    def predict(self, line, **kwargs):
        tag = self.grocery.predict(line)
        return tag

    def test(self, *args, **kwargs):
        """Evaluate on ``kwargs["test_src"]``, print and return the result."""
        test_src = str(kwargs["test_src"])
        test_result = self.grocery.test(test_src)
        print("Total Accuracy %s" % (test_result,))

        return test_result
示例#23
0
class Solution:
    def __init__(self):
        """Create the 'sample' classifier and open the raw JSON-lines corpus for reading."""
        model_name = "sample"
        raw_corpus_path = "../data/train_raw.json"
        self.grocery = Grocery(model_name)
        self.input = open(raw_corpus_path, "r")

    def get_train_test(self):
        """Shuffle the raw corpus and split it 70/30 into JSON-lines files.

        Reads one JSON object per line from ``self.input`` (keys 'label',
        'title', 'url', 'content'). Writes objects with keys url, label, info
        (the title) and content to train_file.json (first 70% of a random
        permutation) and test_file.json (the rest).
        """
        info_list = []
        for line in self.input:
            info_dict = json.loads(line.strip())
            label = info_dict['label']
            info = info_dict['title']
            url = info_dict['url']
            content = info_dict['content']
            info_list.append((label, info, url, content))

        n = len(info_list)
        # Permute the Python records directly: np.array(info_list) coerced every
        # field (e.g. numeric labels) to a numpy string before serialisation.
        shuffled = [info_list[i] for i in np.random.permutation(n)]
        cut = int(n * 0.7)

        def dump(path, rows):
            # One JSON object per line, same output as ``print >> file, ...``.
            with open(path, "w") as out:
                out_dict = {}
                for label, info, url, content in rows:
                    out_dict['url'] = url
                    out_dict['label'] = label
                    out_dict['info'] = info
                    out_dict['content'] = content
                    out.write(json.dumps(out_dict) + "\n")

        dump("train_file.json", shuffled[:cut])
        dump("test_file.json", shuffled[cut:])

    def train(self):
        """Train on ../THUCTC/train_7_3.json, save, and print the result on test_7_3.json.

        Each input line is a JSON object; ``(label, title)`` pairs are used and
        records missing either field are skipped. Prints the number of
        training samples, then the test result.
        """
        def read_samples(path):
            samples = []
            with open(path, "r") as src:
                for line in src:
                    in_dict = json.loads(line.strip())
                    try:
                        samples.append((in_dict['label'], in_dict['title']))
                    except (KeyError, TypeError):
                        # Skip records without the expected fields.
                        pass
            return samples

        # Both files are read up front so a missing test file fails before training.
        train_src = read_samples("../THUCTC/train_7_3.json")
        test_src = read_samples("../THUCTC/test_7_3.json")
        self.grocery.train(train_src)
        print(len(train_src))
        self.grocery.save()
        print(self.grocery.test(test_src))
        """
示例#24
0
# -*- coding: utf-8 -*-
#
# 训练样本
# Author: Alex
# Created Time: 2016年12月29日 星期四 11时19分06秒

from tgrocery import Grocery

# Train the 'test' model from a CSV training file and save it.
train_file = "./data/train.csv"
gr = Grocery('test')
gr.train(train_src=train_file)
gr.save()
示例#25
0
def main():
    """Train tgrocery tweet/snippet classifiers and write regression features.

    1. Builds ``targetDict`` (token -> market_sentiment) from the NTUSD-Fin
       hashtag and word lexicons; each entry also becomes a snippet sample.
    2. Trains and saves the "tweet" and "snippet" Grocery models from
       training_set.json and writes WordScore/group features to TG_train.txt.
    3. Reloads both models, predicts group values for test_set.json and
       writes the same features to TG_test.txt.
    """
    # Matches http(s) URLs so they can be blanked out of tweets (raw string:
    # same pattern, without invalid-escape warnings).
    url_re = re.compile(r"https?://[\w\-]+(\.[\w\-]+)+\S*")

    # Get market_sentiment of word from NTUSD-Fin
    train_t = []
    train_s = []
    targetDict = dict()
    # Hashtag tokens are keyed with a leading '#'; plain words are not.
    for lexicon_path, prefix in (('NTUSD-Fin/NTUSD_Fin_hashtag_v1.0.json', "#"),
                                 ('NTUSD-Fin/NTUSD_Fin_word_v1.0.json', "")):
        with open(lexicon_path, 'r', encoding='utf-8') as f:
            targetIn = json.load(f)
        N = len(targetIn)
        for i in range(N):
            word = prefix + targetIn[i]['token']
            targetDict[word] = targetIn[i]['market_sentiment']
            sg = str(GroupValue_s(str(targetDict[word] / 3.5)))
            train_s.append((sg, word))

    # Training File: Load data & Use tgrocery to train classification model
    with open('training_set.json', 'r') as TrainingFile:
        TrainingData = json.load(TrainingFile)
    DataList = []
    grocery_t = Grocery("tweet")
    grocery_s = Grocery("snippet")
    for DataElement in TrainingData:
        tempt = DataManager()
        tempt.insertData(DataElement)
        tempt.group_t = GroupValue_t(tempt.sentiment)
        tempt.group_s = GroupValue_s(tempt.sentiment)
        line = url_re.sub(" ", DataElement["tweet"])
        train_t.append((str(tempt.group_t), line))
        if isinstance(DataElement["snippet"], list):
            for line in DataElement["snippet"]:
                train_s.append((str(tempt.group_s), line))
        elif DataElement["snippet"] != "":
            train_s.append((str(tempt.group_s), DataElement["snippet"]))
        else:
            tempt.group_s = 0.0
        DataList.append(tempt)
    # The tweet model also learns from the snippet/lexicon samples.
    grocery_t.train(train_t + train_s)
    grocery_t.save()
    grocery_s.train(train_s)
    grocery_s.save()

    # Save training data created by WordScore() and GroupValue_*()
    # Data will be used for LinearRegression() in BOTH.py
    dataScore = []
    dataSentiment = []
    with open('TG_train.txt', 'w', encoding='utf-8') as outfile:
        for row in DataList:
            dataSentiment.append([float(row.sentiment)])
            a = WordScore(row.tweet, targetDict)
            b = WordScore(row.snippet, targetDict)
            c = row.group_t
            d = row.group_s
            dataScore.append([a, b, c, d])
            print(a, b, c, d, file=outfile)
    # Disabled: train a linear regression model on the training features.
    # model = LinearRegression()
    # model.fit(dataScore, dataSentiment)
    # print('(train)R-squared: %.3f' % model.score(dataScore, dataSentiment)) #0.915
    # predictions = model.predict(dataScore)
    # rms = mean_squared_error(dataSentiment,predictions)
    # print('RMSE: %.3f' % sqrt(rms)) #0.110
    # print('MSE: %.3f' % rms) #0.012

    # Testing File: Load data & Use tgrocery classification model to predict
    with open('test_set.json', 'r') as TestingFile:
        TestingData = json.load(TestingFile)
    DataList = []
    new_grocery_t = Grocery('tweet')
    new_grocery_t.load()
    new_grocery_s = Grocery('snippet')
    new_grocery_s.load()
    for DataElement in TestingData:
        tempt = DataManager()
        tempt.insertData(DataElement)
        line = url_re.sub(" ", DataElement["tweet"])
        tempt.group_t = float('{0}'.format(new_grocery_t.predict(line)))
        value = 0.0
        snippet = DataElement["snippet"]
        if isinstance(snippet, list):
            for line in snippet:
                value = value + float('{0}'.format(
                    new_grocery_s.predict(line)))
            # Guard: an empty snippet list used to raise ZeroDivisionError.
            if snippet:
                value = value / len(snippet)
        elif snippet != "":
            value = float('{0}'.format(new_grocery_s.predict(snippet)))
        tempt.group_s = value
        DataList.append(tempt)

    # Save testing data created by WordScore() and classification prediction
    # Data will be used for LinearRegression() in BOTH.py
    dataScore = []
    dataSentiment = []
    with open('TG_test.txt', 'w', encoding='utf-8') as outfile:
        for row in DataList:
            dataSentiment.append([float(row.sentiment)])
            a = WordScore(row.tweet, targetDict)
            b = WordScore(row.snippet, targetDict)
            c = row.group_t
            d = row.group_s
            dataScore.append([a, b, c, d])
            print(a, b, c, d, file=outfile)
    '''
示例#26
0
def train(path,name):
    """Train a tgrocery model named *name* on *path* and save it.

    :param path: training source passed unchanged to ``Grocery.train``.
    :param name: Grocery model name (the name it is saved under).
    """
    grocery = Grocery(name)   
    grocery.train(path)
    grocery.save()