コード例 #1
0
ファイル: test_sampling.py プロジェクト: DanielY1783/te_ml
def test_stratified_sample_5():
    """Check that ``stratified_sample`` raises when called with ``fraction=2``.

    The fixture is a 12-row frame whose ``label`` column holds four 0s
    followed by eight 1s.
    """
    frame = pd.DataFrame(
        data={
            "col 1": list(range(12)),
            "label": [0] * 4 + [1] * 8,
        }
    )
    # The return value is irrelevant here; only the raised exception matters.
    with pytest.raises(Exception):
        stratified_sample(frame, fraction=2)
コード例 #2
0
ファイル: test_sampling.py プロジェクト: DanielY1783/te_ml
def test_stratified_sample_3():
    """Check that ``fraction=1`` yields a result with the input's shape.

    The fixture is a 12-row frame whose ``label`` column holds four 0s
    followed by eight 1s.
    """
    frame = pd.DataFrame(
        data={
            "col 1": list(range(12)),
            "label": [0] * 4 + [1] * 8,
        }
    )
    result = stratified_sample(frame, fraction=1)
    assert result.shape == frame.shape
コード例 #3
0
ファイル: test_sampling.py プロジェクト: DanielY1783/te_ml
def test_stratified_sample_4():
    """Check that ``fraction=0.0001`` leaves no rows for either label.

    The fixture is a 12-row frame whose ``label`` column holds four 0s
    followed by eight 1s.
    """
    frame = pd.DataFrame(
        data={
            "col 1": list(range(12)),
            "label": [0] * 4 + [1] * 8,
        }
    )
    result = stratified_sample(frame, fraction=0.0001)
    # Same two checks as before, label 0 first and then label 1.
    for label_value in (0, 1):
        matching_rows = result.loc[result.loc[:, "label"] == label_value]
        assert len(matching_rows) == 0
コード例 #4
0
def _load_rows_and_geo_stats(csv_path):
    """Read the CSV at *csv_path* and tally per-key statistics in one pass.

    The first line is treated as the header and is printed. Every later row
    is stored without its last column, and counters are accumulated under the
    value found in column 20 (presumably a geographic identifier, judging by
    the ``geo_vals`` name -- confirm against the data set).

    Returns:
        A ``(header, rows, geo_vals)`` tuple where ``header`` is the first CSV
        line, ``rows`` is the list of data rows minus their last column, and
        ``geo_vals`` maps each column-20 value to a dict with the keys
        ``"male"``, ``"female"``, ``"a1"``, ``"a2"``, ``"a3"`` (ints) and
        ``"day"`` (a dict counting occurrences of each column-12 value).

    Raises:
        StopIteration: if the file is empty (no header line).
        IndexError: if a row lacks one of the columns read here.
        ValueError: if column 3 of a row is not an integer.
    """
    rows = []
    geo_vals = {}

    with open(csv_path, 'rt') as f:
        reader = csv.reader(f)
        header = next(reader)

        print(header)

        for row in reader:
            rows.append(row[:-1])

            stats = geo_vals.get(row[20])
            if stats is None:
                stats = {
                    "male": 0,
                    "female": 0,
                    "a1": 0,
                    "a2": 0,
                    "a3": 0,
                    "day": {},
                }
                geo_vals[row[20]] = stats

            # Column 2: '1' is counted as male, '2' as female; any other
            # value is ignored.
            if row[2] == '1':
                stats["male"] += 1
            elif row[2] == '2':
                stats["female"] += 1

            # Column 3 (presumably an age -- confirm) is parsed once and
            # bucketed. The third bucket starts at 41; it previously started
            # at 42, which left the value 41 uncounted between the second and
            # third buckets.
            bucket_value = int(row[3])
            if 18 <= bucket_value <= 30:
                stats["a1"] += 1
            elif 31 <= bucket_value <= 40:
                stats["a2"] += 1
            elif 41 <= bucket_value <= 50:
                stats["a3"] += 1

            # Column 12 feeds the per-key "day" histogram.
            day_counts = stats["day"]
            day_counts[row[12]] = day_counts.get(row[12], 0) + 1

    return header, rows, geo_vals


def get_samples():
    """Return the scaled random and stratified samples, using memcached.

    Both sample sets are looked up in the memcached instance at
    ``127.0.0.1:11211``. If either one is missing, everything is recomputed
    from ``files.csv_file``: the rows are sampled with ``random_sample`` and
    ``stratified_sample``, each sample is passed through
    ``preprocessing.scale``, and the results are stored back under the keys
    ``first_row``, ``random_samples``, ``stratified_samples``, ``elbow_vals``
    and ``geo_vals``.

    Returns:
        A ``(random_samples, stratified_samples)`` tuple.
    """
    mc = memcache.Client(['127.0.0.1:11211'], debug=0)

    # Fetch each cached value exactly once. Treat the cache as cold when
    # either entry is missing, so a partially evicted cache cannot make this
    # function return None for the stratified samples.
    random_samples = mc.get('random_samples')
    stratified_samples = mc.get('stratified_samples')

    if random_samples is None or stratified_samples is None:
        first_row, row_list, geo_vals = _load_rows_and_geo_stats(files.csv_file)

        # Perform random sampling
        random_samples = random_sample(
            row_list, constants.sample_fraction * len(row_list)
        )
        print("Random sampling - Num_sampled : " + str(len(random_samples)))

        print(random_samples[0])
        # Normalize data
        random_samples = preprocessing.scale(random_samples)

        print(random_samples[0])
        # Per the original comment this finds an optimal k via the elbow
        # method; note it is run on the full row list, not on a sample.
        elbow_vals = plot_elbow(row_list)

        # Perform stratified sampling
        stratified_samples = stratified_sample(row_list)

        # Normalize data
        print("STTTTTTT")
        stratified_samples = preprocessing.scale(stratified_samples)

        print("STTTTTTT2")

        mc.set('first_row', first_row)
        mc.set('random_samples', random_samples)
        mc.set('stratified_samples', stratified_samples)
        mc.set('elbow_vals', elbow_vals)
        mc.set('geo_vals', geo_vals)

    print(random_samples[0])
    return random_samples, stratified_samples