forked from joferkington/houston_street_flooding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
37 lines (28 loc) · 1.03 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sklearn
import sklearn.pipeline
from sklearn import preprocessing, ensemble
import pandas as pd
import geopandas
# Train a random-forest flood classifier on labelled street features and
# apply it to a set of prediction points.
#
# Inputs (read from the working directory):
#   features_and_class.csv       -- training table; '--' cells are read as NaN
#   prediction_features.geojson  -- features for the points to classify
# Outputs (written to the working directory):
#   prediction.geojson           -- the prediction features plus a 'prediction' column
#   prediction.png               -- the predictions rendered as a 78 x 69 image

# `preprocessing.Imputer` was deprecated in scikit-learn 0.20 and removed in
# 0.22; `sklearn.impute.SimpleImputer` is its replacement. Fall back to the
# old class only on scikit-learn releases that predate `sklearn.impute`.
try:
    from sklearn.impute import SimpleImputer
except ImportError:  # scikit-learn < 0.20
    SimpleImputer = preprocessing.Imputer

data = pd.read_csv('features_and_class.csv', na_values=['--'])

# Every column except the first and the last is used as a feature.
# NOTE(review): this assumes the first column is an identifier and the last is
# the 'flooded' label -- confirm against the CSV, otherwise the label could
# leak into the feature set.
features = list(data.columns)[1:-1]
obs_class = data['flooded']

pipeline = sklearn.pipeline.Pipeline([
    # Fill missing values (including the '--' cells) with the column mean.
    ('Replace NaNs', SimpleImputer(strategy='mean')),
    ('Scale data', preprocessing.StandardScaler()),
    ('Classification', ensemble.RandomForestClassifier(
        n_estimators=100,
        n_jobs=-1,  # use all available cores
    )),
])
pipeline.fit(data[features].values, obs_class.values)

df = geopandas.read_file('prediction_features.geojson', driver='GeoJSON')
pred = pipeline.predict(df[features].values)
df['prediction'] = pred
df.to_file('prediction.geojson', driver='GeoJSON')

# NOTE(review): assumes prediction_features.geojson holds exactly 78 * 69
# points in row-major grid order -- reshape raises ValueError on a size
# mismatch; confirm the grid dimensions against the data source.
grid = pred.reshape(78, 69)

# Select the non-interactive Agg backend before pyplot is imported so the
# figure can be saved without a display.
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(grid)
fig.savefig('prediction.png')