def earnn(
    experiment="one_month_forecast",
    include_pred_month=True,
    surrounding_pixels=None,
    pretrained=True,
    ignore_vars=None,
    *,
    hidden_size=128,
    num_epochs=50,
    early_stopping=5,
    test_file_name="2018_3",
):
    """Train (or load) an EA-LSTM model, then compute explanations for one test file.

    When ``pretrained`` is False a new ``EARecurrentNetwork`` is built, trained,
    evaluated (saving predictions) and saved. Otherwise the model is loaded from
    ``<data_path>/models/{experiment}/ealstm/model.pt``. In both cases
    ``all_explanations_for_file`` is then run on
    ``<data_path>/features/{experiment}/test/{test_file_name}`` with
    ``batch_size=100``.

    Args:
        experiment: Experiment name; selects the ``features/{experiment}`` and
            ``models/{experiment}`` sub-directories under the data path.
        include_pred_month: Forwarded to ``EARecurrentNetwork`` (training only).
        surrounding_pixels: Forwarded to ``EARecurrentNetwork`` (training only).
        pretrained: If True, load the saved model instead of training one.
        ignore_vars: Forwarded to ``EARecurrentNetwork`` (training only).
        hidden_size: ``hidden_size`` for ``EARecurrentNetwork`` (training only).
        num_epochs: ``num_epochs`` for ``predictor.train`` (training only).
        early_stopping: ``early_stopping`` for ``predictor.train`` (training
            only); presumably a patience in epochs -- confirm.
        test_file_name: Entry under ``features/{experiment}/test/`` to explain;
            presumably a ``YEAR_MONTH`` folder name -- confirm.

    Raises:
        FileNotFoundError: If the test file to explain does not exist.
    """
    data_path = get_data_path()

    if not pretrained:
        predictor = EARecurrentNetwork(
            hidden_size=hidden_size,
            data_folder=data_path,
            experiment=experiment,
            include_pred_month=include_pred_month,
            surrounding_pixels=surrounding_pixels,
            ignore_vars=ignore_vars,
        )
        predictor.train(num_epochs=num_epochs, early_stopping=early_stopping)
        predictor.evaluate(save_preds=True)
        predictor.save_model()
    else:
        predictor = load_model(data_path / f"models/{experiment}/ealstm/model.pt")

    test_file = data_path / f"features/{experiment}/test/{test_file_name}"
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not test_file.exists():
        raise FileNotFoundError(f"Test file not found: {test_file}")
    all_explanations_for_file(test_file, predictor, batch_size=100)
def earnn(
    experiment="one_month_forecast",
    include_pred_month=True,
    surrounding_pixels=None,
    pretrained=False,
    explain=False,
    static="features",
    ignore_vars=None,
    num_epochs=50,
    early_stopping=5,
    static_embedding_size=10,
    hidden_size=128,
    predict_delta=False,
    spatial_mask=None,
    include_latlons=False,
    normalize_y=True,
    include_prev_y=True,
    include_yearly_aggs=True,
    clear_nans=True,
    weight_observations=False,
    pred_month_static=False,
    *,
    test_file_name="2018_3",
):
    """Train (or load) an EA-LSTM model and optionally compute explanations.

    When ``pretrained`` is False a new ``EARecurrentNetwork`` is built from the
    model options below, trained, evaluated (saving predictions) and saved.
    Otherwise the model is loaded from
    ``<data_path>/models/{experiment}/ealstm/model.pt``. In that case all model
    construction/training options are ignored.

    If ``explain`` is True, ``all_explanations_for_file`` is then run on
    ``<data_path>/features/{experiment}/test/{test_file_name}`` with
    ``batch_size=100``.

    Args:
        experiment: Experiment name; selects the ``features/{experiment}`` and
            ``models/{experiment}`` sub-directories under the data path.
        pretrained: If True, load the saved model instead of training one.
        explain: If True, run explanations on the chosen test file.
        num_epochs: ``num_epochs`` for ``predictor.train`` (training only).
        early_stopping: ``early_stopping`` for ``predictor.train`` (training
            only); presumably a patience in epochs -- confirm.
        test_file_name: Entry under ``features/{experiment}/test/`` to explain;
            presumably a ``YEAR_MONTH`` folder name -- confirm.

        The following options are forwarded unchanged as keyword arguments of
        the same name to ``EARecurrentNetwork`` and are used only when
        training: ``include_pred_month``, ``surrounding_pixels``, ``static``,
        ``static_embedding_size``, ``hidden_size``, ``ignore_vars``,
        ``predict_delta``, ``spatial_mask``, ``include_latlons``,
        ``normalize_y``, ``include_prev_y``, ``include_yearly_aggs``,
        ``clear_nans``, ``weight_observations``, ``pred_month_static``.

    Raises:
        FileNotFoundError: If ``explain`` is True and the test file does not
            exist.
    """
    data_path = get_data_path()

    if not pretrained:
        predictor = EARecurrentNetwork(
            hidden_size=hidden_size,
            data_folder=data_path,
            experiment=experiment,
            include_pred_month=include_pred_month,
            surrounding_pixels=surrounding_pixels,
            static=static,
            static_embedding_size=static_embedding_size,
            ignore_vars=ignore_vars,
            predict_delta=predict_delta,
            spatial_mask=spatial_mask,
            include_latlons=include_latlons,
            normalize_y=normalize_y,
            include_prev_y=include_prev_y,
            include_yearly_aggs=include_yearly_aggs,
            clear_nans=clear_nans,
            weight_observations=weight_observations,
            pred_month_static=pred_month_static,
        )
        predictor.train(num_epochs=num_epochs, early_stopping=early_stopping)
        predictor.evaluate(save_preds=True)
        predictor.save_model()
    else:
        predictor = load_model(data_path / f"models/{experiment}/ealstm/model.pt")

    if explain:
        test_file = data_path / f"features/{experiment}/test/{test_file_name}"
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if not test_file.exists():
            raise FileNotFoundError(f"Test file not found: {test_file}")
        all_explanations_for_file(test_file, predictor, batch_size=100)