コード例 #1
0
 def __init__(self,
              bins,
              data_path,
              calibration_features,
              tt_lstm_config_path,
              soccer_data_store_dir,
              apply_old,
              apply_difference,
              focus_actions_list=None):
     """Initialise calibration state and open the results file for writing.

     Args:
         bins: stored as ``self.bins``.
         data_path: stored as ``self.data_path``.
         calibration_features: stored as ``self.calibration_features``.
         tt_lstm_config_path: path handed to ``TTLSTMCongfig.load``; the
             loaded config is stored as ``self.tt_lstm_config``.
         soccer_data_store_dir: stored as ``self.soccer_data_store_dir``.
         apply_old: flag stored as ``self.apply_old``.
         apply_difference: when true, the 'all' accumulators hold a single
             entry and ``self.teams`` is ``['home-away']``; otherwise they
             hold three entries and ``self.teams`` is
             ``['home', 'away', 'end']``. It also selects the file-name prefix.
         focus_actions_list: optional list stored as
             ``self.focus_actions_list``. It defaults to a fresh empty list,
             and its ``str()`` form is embedded in the results file name.

     Side effects:
         Opens ``self.save_calibration_dir`` in ``'w'`` mode (truncating any
         existing file) and keeps the handle in
         ``self.save_calibration_file``.
     """
     # A fresh list per instance avoids the shared-mutable-default pitfall
     # (the old ``=[]`` default was one list shared by every instance).
     if focus_actions_list is None:
         focus_actions_list = []

     self.bins = bins
     self.apply_old = apply_old
     self.apply_difference = apply_difference
     self.data_path = data_path
     self.calibration_features = calibration_features
     self.soccer_data_store_dir = soccer_data_store_dir
     self.tt_lstm_config = TTLSTMCongfig.load(tt_lstm_config_path)
     self.focus_actions_list = focus_actions_list

     # All mode-dependent settings are decided in one place.
     if self.apply_difference:
         value_count = 1
         file_prefix = 'difference-calibration'
         self.teams = ['home-away']
     else:
         value_count = 3
         file_prefix = 'calibration'
         self.teams = ['home', 'away', 'end']

     # Running sums, one slot per tracked value, plus a sample counter.
     self.calibration_values_all_dict = {
         'all': {
             'cali_sum': [0] * value_count,
             'model_sum': [0] * value_count,
             'number': 0,
         }
     }

     # NOTE: "%B" is the locale's full month name, so the file name
     # depends on the process locale.
     self.save_calibration_dir = './calibration_results/{0}-{1}-{2}.txt'.format(
         file_prefix,
         str(self.focus_actions_list),
         datetime.date.today().strftime("%Y%B%d"))
     # The handle stays open on the instance. Closing it is presumably done
     # elsewhere in the class (not visible here); TODO confirm.
     self.save_calibration_file = open(self.save_calibration_dir, 'w')
コード例 #2
0
from td_three_prediction_two_tower_lstm_v_correct_dir.support.data_processing_tools import normalize_data
from td_three_prediction_two_tower_lstm_v_correct_dir.nn_structure.td_tt_lstm import td_prediction_tt_embed
from td_three_prediction_two_tower_lstm_v_correct_dir.support.plot_tools import compute_game_values, read_plot_model

if __name__ == '__main__':
    # Hard-coded locations of the stored data, the raw match data and the
    # model configuration used by this script.
    data_store_dir = "/cs/oschulte/Galen/Hockey-data-entire/Hybrid-RNN-Hockey-Training-All-feature5-scale" \
                     "-neg_reward_v_correct__length-dynamic/"
    data_path = "/cs/oschulte/Galen/Hockey-data-entire/Hockey-Match-All-data/"
    tt_lstm_config_path = '../icehockey-config.yaml'

    # Game selection: find_game_dir matches target_game_id against the
    # directory entries of data_path.
    home_team = 'Penguins'
    away_team = 'Canadiens'
    target_game_id = str(1403)
    dir_all = os.listdir(data_path)
    game_name_dir = find_game_dir(dir_all, data_path, target_game_id)

    tt_lstm_config = TTLSTMCongfig.load(tt_lstm_config_path)
    learning_rate = tt_lstm_config.learn.learning_rate
    # learning_rate_write maps the learning rate to a short integer tag
    # (presumably used in names or paths further down; not visible here).
    # Only 1e-5 and 1e-4 are mapped. Fail fast with a clear message instead
    # of leaving the name unbound and hitting a NameError later.
    if learning_rate == 1e-5:
        learning_rate_write = 5
    elif learning_rate == 1e-4:
        learning_rate_write = 4
    else:
        raise ValueError(
            'Unsupported learning rate {0!r}: expected 1e-5 or 1e-4'.format(learning_rate))

    sess_nn = tf.InteractiveSession()

    model_nn = td_prediction_tt_embed(
        feature_number=tt_lstm_config.learn.feature_number,
        home_h_size=tt_lstm_config.Arch.HomeTower.home_h_size,
        away_h_size=tt_lstm_config.Arch.AwayTower.away_h_size,
        max_trace_length=tt_lstm_config.learn.max_trace_length,
        learning_rate=tt_lstm_config.learn.learning_rate,
        embed_size=tt_lstm_config.learn.embed_size,