def plot_components(x):
  """Plot the first two columns of ``x`` as eleven labelled row groups.

  The rows of ``x`` are cut at fixed, hard-coded boundaries; each group is
  passed to ``draw_components`` together with its legend label.

  Args:
    x: object indexed as ``x[a:b, :2]``; assumed to be a 2-D NumPy-style
      array with at least 21 rows -- TODO confirm against the caller.
  """
  # Consecutive pairs of edges give the row range of each group.
  row_edges = (0, 3, 6, 9, 12, 13, 14, 17, 18, 19, 20, 21)
  legend = [
      'Ag:0.237 v:0.0525',
      'Ag:0.237 v:0.0593',
      'Ag:0.237 v:0.0773',
      'Ag:0.237 v:0.0844',
      'Ag:0.239 v:0.0659',
      'Ag:0.243 v:0.0659',
      'Ag:0.239 v:0.0791',
      'Ag:0.239 v:0.0525',
      'Ag:0.243 v:0.0525',
      'Ag:0.237 v:0.0914',
      'Ag:0.237 v:0.0512',
  ]
  groups = [x[start:stop, :2]
            for start, stop in zip(row_edges[:-1], row_edges[1:])]
  draw_components(groups, legend)
def plot_components(x, y, n_comps, linker_model, verbose=2):
  """Fit MKS homogenization models and plot their leading reduced components.

  Steps performed, in order:

  1. Fit an ``MKSHomogenizationModel`` (primitive basis with 3 states on the
     domain [0, 2], 5 components, periodic along axes 0 and 1) to ``(x, y)``
     using ``linker_model`` as the property linker, and print the linker's
     ``coef_``.
  2. Plot the first two columns of ``model.reduced_fit_data`` in eight
     hard-coded row groups.
  3. Build a second model with ``compute_correlations=False`` and run a
     leave-one-out ``GridSearchCV`` over ``degree`` in 1..3 and
     ``n_components`` in 1..7.
  4. Plot four row groups of ``model.reduced_fit_data`` and print the best
     degree, the best number of components and ``gs.score(X_test, y_test)``.

  Args:
    x: training input passed to ``model.fit``.
    y: training target passed to ``model.fit`` and to the grid search.
    n_comps: accepted but not used; the component count is fixed at 5.
    linker_model: estimator used as ``property_linker`` of the first model.
    verbose: accepted but not used.

  Note:
    ``x_corr_flat``, ``samples``, ``X_test`` and ``y_test`` are not
    parameters; they have to exist in the enclosing module scope.
  """
  prim_basis = PrimitiveBasis(n_states=3, domain=[0, 2])
  model = MKSHomogenizationModel(basis=prim_basis,
                                 property_linker=linker_model)
  # NOTE(review): hard-coded; the ``n_comps`` argument is ignored here.
  model.n_components = 5
  model.fit(x, y, periodic_axes=[0, 1])

  # Single-argument call form prints the same text on Python 2 and 3.
  print(model.property_linker.coef_)
  draw_components([model.reduced_fit_data[0:3, :2],
                   model.reduced_fit_data[3:6, :2],
                   model.reduced_fit_data[6:9, :2],
                   model.reduced_fit_data[9:11, :2],
                   model.reduced_fit_data[11:14, :2],
                   model.reduced_fit_data[14:16, :2],
                   model.reduced_fit_data[16:17, :2],
                   model.reduced_fit_data[17:18, :2]],
                  ['Ag:0.237\tCu:0.141\tv:0.0525',
                   'Ag:0.237\tCu:0.141\tv:0.0593',
                   'Ag:0.237\tCu:0.141\tv:0.0773',
                   'Ag:0.237\tCu:0.141\tv:0.0844',
                   'Ag:0.239\tCu:0.138\tv:0.0791',
                   'Ag:0.239\tCu:0.138\tv:0.0525',
                   'Ag:0.237\tCu:0.141\tv:0.0914',
                   'Ag:0.237\tCu:0.141\tv:0.0512'])

  # Second model: correlations are supplied pre-computed (x_corr_flat).
  model = MKSHomogenizationModel(basis=prim_basis,
                                 compute_correlations=False)

  # set up parameters to optimize
  params_to_tune = {'degree': np.arange(1, 4), 'n_components': np.arange(1, 8)}
  fit_params = {'size': x_corr_flat.shape, 'periodic_axes': [0, 1]}
  loo_cv = LeaveOneOut(samples)
  gs = GridSearchCV(model, params_to_tune, cv=loo_cv, n_jobs=6,
                    fit_params=fit_params).fit(x_corr_flat, y)

  # NOTE(review): ``model.fit`` is never called on this second instance in
  # this function; the fitted estimator is presumably
  # ``gs.best_estimator_`` -- confirm which one should be plotted.
  draw_components([model.reduced_fit_data[0:3, :2],
                   model.reduced_fit_data[3:6, :2],
                   model.reduced_fit_data[6:9, :2],
                   model.reduced_fit_data[9:11, :2]],
                  ['0.0525', '0.0593', '0.0773', '0.0844'])

  # ``print('label'), (value)`` only shows the value on Python 2; format the
  # pair into one string so the output is "label value" on both versions.
  print('Order of Polynomial %s' % (gs.best_estimator_.degree,))
  print('Number of Components %s' % (gs.best_estimator_.n_components,))
  print('R-squared Value %s' % (gs.score(X_test, y_test),))
gs = GridSearchCV(model, params_to_tune, cv=3, n_jobs=3,
                  fit_params=fit_params).fit(data_train, stress_train)'''

# model = gs.best_estimator_
# Hyper-parameters are set by hand on the existing ``model`` object.
model.n_components = 4
model.degree = 2
# ``print('label'), (value)`` drops the value on Python 3; format the pair
# into one string so the output is "label value" on Python 2 and 3 alike.
print('Components %s' % (model.n_components,))
print('Polynomial Order %s' % (model.degree,))

# Fit data to model
model.fit(dataset, stresses, periodic_axes=[0, 1])

# Reshape the test samples so that every axis after the first matches the
# training set: (number of test samples,) + per-sample shape of ``dataset``.
shapes = (data_test.shape[0],) + (dataset.shape[1:])
print(shapes)
data_test = data_test.reshape(shapes)

stress_predict = model.predict(data_test, periodic_axes=[0, 1])
labels = 'Long X', 'Short X', 'Long Y', 'Short Y'

# Draw PCA plot: first two reduced components of the training data and of
# the test data (read after ``model.predict`` has run on ``data_test``).
draw_components([model.reduced_fit_data[:, :2],
                 model.reduced_predict_data[:, :2]],
                ['Training Data', 'Testing Data'])


# Draw goodness of fit: each array stacks (actual, predicted) stresses.
fit_data = np.array([stresses, model.predict(dataset, periodic_axes=[0, 1])])
pred_data = np.array([stress_test, stress_predict])
draw_goodness_of_fit(fit_data, pred_data, ['Training Data', 'Testing Data'])