Example #1
        def wrapper(db, feature_type, data):
            print("in func train_deco")
            # Query that selects the training samples
            selector = data['selector']
            # Parameters passed through to the training function
            option = data['option']
            # If a class-remapping pattern is given, use it
            class_remap = data['class_remap']

            # Record the inputs before processing
            record = {}
            #if bool(selector):
            #    record['selector'] = copy.deepcopy(selector)
            #if bool(option):
            #    record['option'] = copy.deepcopy(option)
            #if bool(class_remap):
            #    record['class_remap'] = copy.deepcopy(class_remap)


            clf_id = generate_clf_id(algorithm, feature_type, data)

            prev_clf = db["classifiers"].find({"_id": clf_id})
            overwrite = False
            if "overwrite" in data and data["overwrite"] in ["true", 1, True, "True", "TRUE"]:
                overwrite = True

            if prev_clf.count() > 0 and not overwrite:
                return error_json("Classifier already exists. To overwrite it, set the overwrite option to true.")

            pca_components = 0
            if "pca" in data:
                pca_components = int(data["pca"])

            min_sample_count = 1
            if 'min_sample_count' in option:
                min_sample_count = int(option['min_sample_count'])
                del option['min_sample_count']

            # Sort samples into classes
            samples = []
            sample_count = 0
            if not class_remap:
                samples = mongointerface.get_training_samples(db, feature_type, False, selector)
                sample_count = samples.count()
                samples = list(samples)
            else:
                # Collect samples per class_remap entry
                record["ignore(min_sample)"] = ["threshold:%d" % min_sample_count]
                for gt, pat in class_remap.items():
                    # Does pat need parsing, or is a "/hoge/"-style string
                    # enough? Needs verification. If $regex is available,
                    # deep-merging pat into selector might be sufficient.
                    selector['ground_truth'] = re.compile(pat)
                    _samples = mongointerface.get_training_samples(db, feature_type, False, selector)
                    if min_sample_count >= _samples.count():
                        # Too few samples: drop the class and exclude it from recognition
                        record["ignore(min_sample)"].append("%s: %d" % (gt, _samples.count()))
                        continue
                        #return error_json('No samples are hit by regular expression "%s"' % pat)
                    for s in _samples:
                        s['ground_truth'] = gt
                        samples.append(s)
                    sample_count += _samples.count()

            if min_sample_count >= sample_count:
                return error_json('Only %d samples were hit as training samples.' % sample_count)

            class_count = collections.defaultdict(int)
            for s in samples:
                class_count[s['ground_truth']] += 1

            class_list = sorted(class_count.keys())

            # Thin out classes that have too many samples
            if 'max_class_samples_by_median' in option:
                max_class_sample_num = int(option['max_class_samples_by_median'] * numpy.median(list(class_count.values())))
                del option['max_class_samples_by_median']

                _samples = []
                for cls in class_list:
                    cls_samples = [s for s in samples if s['ground_truth'] == cls]
                    if class_count[cls] <= max_class_sample_num:
                        _samples.extend(cls_samples)
                    else:
                        # Randomly subsample the class down to the cap
                        _samples.extend(random.sample(cls_samples, max_class_sample_num))
                        class_count[cls] = max_class_sample_num
                random.shuffle(_samples)
                samples = _samples
                print(sample_count)
                sample_count = len(samples)
                print(sample_count)

            x = [[]] * sample_count
            y = [0] * sample_count
            for i, s in enumerate(samples):
                x[i] = s['ft']
                y[i] = s['ground_truth']

            if 'sparse' in option:
                # Convert the feature matrix to CSR sparse format
                x = lil_matrix(x).tocsr()
                del option['sparse']
            # Class weighting
            #z = 0
            #for i,cls in enumerate(class_list):
            #    z += math.exp(class_count[cls])

            class_map = {}
            class_weight = {}
            for i, cls in enumerate(class_list):
                class_map[cls] = i
                # Under-represented classes receive proportionally larger weights
                class_weight[i] = float(len(class_list) * (sample_count - class_count[cls])) / float(sample_count)

            print(class_map)
            # Replace class names in y with their integer ids
            for i in range(len(y)):
                y[i] = class_map[y[i]]

            if pca_components > 0:
                pca = PCA(n_components=pca_components, copy=False)
                print("calc. PCA...")
                x = pca.fit_transform(x)
                print("done.")
                record['pca'] = save_model(get_trained_model_filename(db.name, clf_id + "::pca"), pca)

            if isinstance(x, csr_matrix):
                record['sparse'] = True

            # Run the algorithm-specific training function (func)
            print("train...")
            clf = func(x, y, class_weight, option)
            print("done")


            # Save the result
            ## algorithm-dependent part
            record['_id'] = clf_id
            event = {'_id': "train::" + record['_id']}

            print("pickle classifier...")
            record['clf'] = save_model(get_trained_model_filename(db.name, clf_id), clf)
            print("done.")
            record['class_name2id'] = class_map
            class_map_inv = {str(v): k for k, v in class_map.items()}
            record['class_id2name'] = class_map_inv
            record['class_count'] = class_count
            try:
                db["classifiers"].replace_one({"_id": clf_id}, record, upsert=True)
            except Exception:
                print(sys.exc_info())
                return error_json(sys.exc_info()[1])

            result = success_json()
            result['event'] = event
            print("return result successfully")
            return result
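
The function above is a closure: `algorithm` and `func` are free variables supplied by an enclosing decorator (the first print mentions `train_deco`). Below is a minimal sketch of the likely enclosing structure, under that assumption; the actual names and signatures in the source project may differ.

def train_deco(algorithm):
    # Hypothetical reconstruction: wrap an algorithm-specific fitting
    # function func(x, y, class_weight, option) -> clf into a train()
    # entry point with the (db, feature_type, data) signature used above.
    def deco(func):
        def wrapper(db, feature_type, data):
            ...  # the body shown in Example #1
        return wrapper
    return deco

# Hypothetical usage in an algorithm package, e.g. an svm/classifier.py:
# @train_deco('svm')
# def train(x, y, class_weight, option):
#     ...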
Example #2
def cross_validation(db, json_data_s, feature_type, algorithm, fold_num):
    print("function: cross_validation")
    data = json.loads(json_data_s)
    init_data(data)

    cv_group_head = "__cross_validation"
    # Disband all previously tagged cross-validation groups
    for i in range(0, fold_num):
        group_name = generate_group_name(cv_group_head, i)
        mongointerface.disband(db, feature_type, {'group': group_name})
    mongointerface.disband(db, feature_type, {'group': cv_group_head})

    # Note: renamed from "collections" to avoid shadowing the stdlib module
    collection = db[feature_type]
    selector = data['selector']
    selector['ground_truth'] = {"$exists": True}
    samples = collection.find(selector)

    # Split the samples into N nearly equal-sized groups at random
    samples_count = samples.count()
    if samples_count == 0:
        return error_json("ERROR: no samples are hit.")

    group_assignment = []
    remainder = samples_count % fold_num
    quotient = int(samples_count / fold_num)
    for i in range(0, fold_num):
        n = quotient
        if i < remainder:
            n = n + 1
        print("group_count[%02d] = %d" % (i, n))
        group_assignment += [generate_group_name(cv_group_head, i)] * n
    random.shuffle(group_assignment)

    # Tag each sample with its assigned group
    for i in range(samples_count):
        s = samples[i]
        group_name = group_assignment[i]
        groups = s['group']
        if group_name not in groups:
            groups = mongointerface.ensure_list(groups)
            groups.append(group_name)
            groups.append(cv_group_head)
            _id = s['_id']
            collection.update_one({"_id": _id}, {"$set": {'group': groups}})

    # Load the classifier module for the requested algorithm
    mod = __import__(algorithm + '.classifier', fromlist=['.'])

    # Train, predict, and evaluate one classifier per fold, holding out one group each time
    confusion_matrices = []
    for i in range(0, fold_num):
        ## train ##
        exclude_group = generate_group_name(cv_group_head, i)
        _data = copy.deepcopy(data)
        _data['selector'] = {'group': {'$not': {'$all': [exclude_group]}, '$all': [cv_group_head]}, 'ground_truth': {"$exists": True}}
        _data['overwrite'] = True
        _data['name'] = exclude_group
        result = mod.train(db, feature_type, _data)
        if result['status'] != 'success':
            return result

        ## predict ##
        selector = {'group': {'$all': [exclude_group]}}
        group_samples = mongointerface.get_training_samples(db, feature_type, False, selector)
        for s in group_samples:
            result = mod.predict(db, feature_type, Sample(s), _data)
            if result['status'] != 'success':
                return result
        _data['selector'] = selector

        ## evaluate ##
        result = mongointerface.evaluate(db, feature_type, _data, algorithm)
        if result['status'] != 'success':
            return result
        confusion_matrices.append(result['confusion_matrix'])
    
    # Merge the per-fold confusion matrices into one
    cmat = None
    for m in confusion_matrices:
        if cmat is not None:
            cmat = merge_confusion_matrix(cmat, json.loads(m))
        else:
            cmat = json.loads(m)
    result = success_json()
    result['confusion_matrix'] = cmat
    clf_id = generate_clf_id(algorithm, feature_type, data)
    result['event'] = {'_id': "cross_validation::" + clf_id}
    return result
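
For context, a sketch of how `cross_validation` might be invoked. Everything here is illustrative: the database name, the 'bow' feature type, and the 'svm' algorithm are assumptions, not values taken from the source project.

import json
from pymongo import MongoClient

db = MongoClient()['features']  # hypothetical feature database
request = json.dumps({
    'selector': {},        # restrict the training set; empty = all samples
    'option': {},          # forwarded to the algorithm's fit function
    'class_remap': None,   # or a {class_name: regex_pattern} map, as in Example #1
})
# 5-fold cross-validation over the 'bow' features with the 'svm' algorithm
result = cross_validation(db, request, 'bow', 'svm', fold_num=5)
print(result['confusion_matrix'])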