def main():
    dpark_ctx = dpark.DparkContext('mesos')
    if SET == 'train':
        print 'for training data:'
        assert os.path.isdir(DATA_PATH), 'Data dir does not exist.'
        fn_list = filter(lambda x: not x.startswith('.'),
                         os.listdir(DATA_PATH))
        fn_list = map(lambda x: os.path.join(DATA_PATH, x), fn_list)
        # alternative mapper: tr_data_len_rewrt
        n_data = dpark_ctx.makeRDD(fn_list, 100) \
                          .map(tr_data_len) \
                          .reduce(lambda x, y: x + y)
        print 'Total number: %i' % n_data

    if SET == 'test':
        print 'for testing data:'
        n_data = dpark_ctx.accumulator(0)

        def exam_each(data):
            assert len(data) == SIZE and all(len(dat) == DIM for dat in data)
            n_data.add(1)

        dpark_ctx.beansdb(DATA_PATH).map(
            lambda (key, (val, _, __)): val).foreach(exam_each)
        print "Total number: %i" % n_data.value
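tr_data_len (and the commented-out tr_data_len_rewrt), DATA_PATH, SET, SIZE and DIM are defined elsewhere in the project. A hypothetical sketch of the mapper's contract (one filename in, a sample count out), assuming each training file stores one sample per line:

def tr_data_len(fn):
    # hypothetical sketch: count the samples in one training file,
    # assuming one sample per line
    with open(fn) as fid:
        return sum(1 for _ in fid)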
Example #2
File: wc.py  Project: swinsey/DparkIntro
def main():
    dc = dpark.DparkContext()
    options, args = dpark.optParser.parse_args()
    file_path = args[0]

    data = dc.textFile(file_path)
    wc = data.flatMap(parse_words)\
             .reduceByKey(lambda x, y: x + y)\
             .top(10, key=lambda x: x[1])
    print wc
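parse_words is not shown; since reduceByKey expects (key, value) pairs, it has to turn a line into (word, 1) pairs. A minimal sketch:

def parse_words(line):
    # emit (word, 1) pairs so reduceByKey can sum the counts per word
    return [(w, 1) for w in line.split()]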
def main():
    dpark_ctx = dpark.DparkContext('mesos')

    # QA for overwrite case
    flag = False
    if len(os.listdir(DST_PATH)):  # not empty
        while True:
            print "Are you sure you want to overwrite the spectrum data in %s? [Y/n]" % DST_PATH
            key = raw_input()
            if key == 'Y':
                flag = True
                print 'Will overwrite immediately.'
                os.system('rm -rf ' + DST_PATH + '/.[^.]*')  # TODO: this only clears hidden entries
                break
            elif key == 'n':
                flag = False
                print 'Will not overwrite it.'
                break
            else:
                print 'Please type [Y/n].'
    else:
        flag = True

    # Start extracting 
    if flag:
        if SET == 'train':
            # for paracel IO, provide text files
            sids_with_label = get_song_ids_with_label_train()
            dpark_ctx.makeRDD(
                sids_with_label, 100
            ).map(
                map_song_with_label_tr
            ).filter(
                lambda x: x is not None
            ).saveAsTextFile(
                DST_PATH
            )
        elif SET == 'test':
            # for the python test script, provide beansdb
            sids_with_label = get_song_ids_with_label_test(40000)
            dpark_ctx.makeRDD(
                sids_with_label, 50
            ).map(
                map_song_with_label_te
            ).filter(
                lambda x: x is not None
            ).saveAsBeansdb(
                DST_PATH
            )
        else:
            print 'Unsupported data set.'
            sys.exit(-1)

        print 'spec extracting done'
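map_song_with_label_tr and map_song_with_label_te are project helpers; the filter(lambda x: x is not None) step implies they return None when extraction fails. A hypothetical sketch of that contract (extract_spectrum is an assumed helper, not part of the original):

def map_song_with_label_tr(sid_with_label):
    # hypothetical: build one text line for paracel IO, or None on failure
    sid, label = sid_with_label
    spec = extract_spectrum(sid)  # assumed helper; returns None if unreadable
    if spec is None:
        return None
    return '%s %i %s' % (sid, label, ' '.join('%f' % v for v in spec))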
Example #4
def main():
    dpark_ctx = dpark.DparkContext()
    vec_rdd = dpark_ctx.beansdb(
        BDB_DIR
    ).map(
        # unpack each beansdb record, keeping only the stored value
        lambda (key, (val, _, __)): val
    )

    data = vec_rdd.collect()
    print len(data)
    print len(data[0])
def main(argv):
    assert len(argv) in (2, 3), 'Usage: ./check_nan_te.py dirname [rewrt_flag]'
    assert os.path.isdir(argv[1])
    print 'Check nan in testing dir: %s' % argv[1]
    # bool(argv[2]) is True for any non-empty string, so parse the flag explicitly
    rewrt_flag = (argv[2].lower() in ('1', 'true', 'y')) if len(argv) == 3 else False
    dpark_ctx = dpark.DparkContext('mesos')
    
    if rewrt_flag:
        print 'Will rewrite data, purging nan.'
        def exam_nan_purge(data):
            # data is a (key, (val, ...)) record; NaN != NaN, so this counts NaNs
            np_dat = np.array(data[1][0])
            n_nan = len(np.where(np_dat != np_dat)[0])
            if n_nan != 0:
                print "Warning: %i nan values" % n_nan
                return None
            else:
                return data[0], data[1][0]
        dpark_ctx.beansdb(
            argv[1]
        ).map(
            exam_nan_purge
        ).filter(
            lambda x: x is not None
        ).saveAsBeansdb(
            argv[1]
        )

    else:
        print 'Will NOT rewrite data.'
        def exam_nan(data):
            # NaN != NaN, so this counts NaNs without rewriting anything
            np_dat = np.array(data)
            n_nan = len(np.where(np_dat != np_dat)[0])
            if n_nan != 0:
                print "Warning: %i nan values" % n_nan
        dpark_ctx.beansdb(
            argv[1]
        ).map(
            lambda (key, (val, _, __)): val
        ).foreach(
            exam_nan
        )
    print 'done'
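The np_dat != np_dat test above works because NaN is the only float that compares unequal to itself; np.isnan is the explicit equivalent:

import numpy as np

a = np.array([1.0, float('nan'), 3.0])
print len(np.where(a != a)[0])  # 1, via the NaN != NaN property
print int(np.isnan(a).sum())    # 1, the idiomatic equivalent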
def main():
    dpark_ctx = dpark.DparkContext('mesos')
    if SET == 'train':
        print 'for training data:'
        assert os.path.isdir(DATA_PATH), 'Data dir does not exist.'
        fn_list = filter(lambda x: not x.startswith('.'),
                         os.listdir(DATA_PATH))
        fn_list = map(lambda x: os.path.join(DATA_PATH, x), fn_list)
        # alternative mapper: tr_data_len_rewrt
        n_data = dpark_ctx.makeRDD(fn_list, 100) \
                          .map(tr_data_len) \
                          .reduce(lambda x, y: x + y)
        print 'Total number: %i' % n_data

    if SET == 'test':
        print 'for testing data:'
        vec_rdd = dpark_ctx.beansdb(DATA_PATH).map(
            lambda (key, (val, _, __)): val)  # optionally subsample: .sample(1/100.)
        data = vec_rdd.collect()
        assert all(len(dat) == DIM * SIZE for dat in data), \
                'error of data dimension.'
        print "Total number: %i" % len(data)
Example #7
def main():
    dpark_ctx = dpark.DparkContext('mesos')
    assert os.path.isdir(BASE_PATH) and os.path.isdir(MODEL_PATH)

    # Read the weights and bias of SDAE from MODEL_PATH
    W, b = load_ae(MODEL_PATH)

    # SVM
    print 'Will use layer No. %i' % FEA_LAYER
    # keep the sorted layer keys up to and including FEA_LAYER
    lyr = sorted(W.keys())[:FEA_LAYER + 1]

    # SVM training and validating data
    svm_data_tr, svm_label_tr, svm_data_va, svm_label_va = \
            load_svm_tr_va_data(TRAIN_DATA_PATH, dpark_ctx, (lyr, W, b))
    # SVM testing data
    svm_data_te, svm_label_te = load_svm_te_data(TEST_DATA_PATH, dpark_ctx,
                                                 (lyr, W, b))

    # Process data, view the GID distribution in tr or te sets
    print 'Processing data here.'
    GID_adjust = range(len(GID))  # GID adjust to 0-13
    print '=' * 100
    print 'Training data distributions:'
    tr_hist = data_dist(GID_adjust, svm_label_tr)
    print '=' * 100
    print 'Testing data distributions:'
    te_hist = data_dist(GID_adjust, svm_label_te)
    print '=' * 100
    print 'Validation data distributions:'
    va_hist = data_dist(GID_adjust, svm_label_va)

    # Binary Test on two classes in validation set
    if BINARY_TEST:
        print 'Doing a binary class test using validation set'
        lbl1 = 7
        lbl2 = 11
        svm_data_va_bin = []
        svm_label_va_bin = []
        for lbl_elem, data_elem in zip(svm_label_va, svm_data_va):
            if lbl_elem == lbl1:
                svm_label_va_bin.append(-1)
                svm_data_va_bin.append(data_elem)
            elif lbl_elem == lbl2:
                svm_label_va_bin.append(1)
                svm_data_va_bin.append(data_elem)
        print 'Binary-class data prepared.'
        for svm_c in [
                0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000, 100000, 1e6, 1e7,
                1e8
        ]:
            svm_opt = '-c ' + str(svm_c) + ' -w-1 3 -w1 2 -v 5 -q'
            svm_model = svm.svm_train(svm_label_va_bin, svm_data_va_bin,
                                      svm_opt)

    # Cross Validation on whole validation set
    elif CROSS_VALIDATION:
        print 'SVM model starts cross validating.'
        for svm_c in [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]:
            svm_opt = '-c ' + str(svm_c) + ' '
            for gid_elem, va_hist_elem in zip(GID_adjust, va_hist):
                wgt_tmp = max(va_hist) / float(va_hist_elem)
                # an earlier weighting scheme, kept for reference:
                # if wgt_tmp < 3.0:
                #     wgt = 1
                # elif wgt_tmp < 10:
                #     wgt = 4
                # elif wgt_tmp < 40:
                #     wgt = 16
                # else:
                #     wgt = 32
                if wgt_tmp < 10.0:
                    wgt = int(wgt_tmp)
                elif wgt_tmp < 40:
                    wgt = 16
                else:
                    wgt = 32
                svm_opt += ('-w' + str(gid_elem) + ' ' + str(wgt) + ' ')
            svm_opt += '-v 5 -q'
            print svm_opt
            svm_model = svm.svm_train(svm_label_va, svm_data_va, svm_opt)

    # SVM running on whole Training / Testing sets
    else:
        fn_svm = 'svm_model_c1_wgt'
        if SAVE_OR_LOAD:  # True
            print 'SVM model starts training.'
            svm_opt = '-c 1 '
            for gid_elem, tr_hist_elem in zip(GID_adjust, tr_hist):
                wgt_tmp = max(tr_hist) / float(tr_hist_elem)
                if wgt_tmp < 3.0:
                    wgt = 1
                elif wgt_tmp < 10:
                    wgt = 2
                elif wgt_tmp < 40:
                    wgt = 4
                else:
                    wgt = 8
                svm_opt += ('-w' + str(gid_elem) + ' ' + str(wgt) + ' ')
            print svm_opt
            svm_model = svm.svm_train(svm_label_tr, svm_data_tr, svm_opt)
            # save SVM model
            svm.svm_save_model(fn_svm, svm_model)
        else:  # False
            print 'SVM model loading.'
            # load SVM model
            svm_model = svm.svm_load_model(fn_svm)
        print 'SVM model training or loading done'
        p_label, p_acc, p_val = svm.svm_predict(svm_label_te, svm_data_te,
                                                svm_model)
        with open('res_tmp.pkl', 'wb') as fid:
            pickle.dump((p_label, p_acc, p_val), fid)
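data_dist is defined elsewhere; from its use above (adjusted GIDs and a label list in, per-class counts out, feeding the max(hist) / count weight heuristic), a plausible sketch:

def data_dist(gids, labels):
    # hypothetical sketch: print and return the per-class sample counts
    hist = [labels.count(g) for g in gids]
    for g, n in zip(gids, hist):
        print 'class %2i: %i samples' % (g, n)
    return hist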