def train_encoder_model(model, sess, config, input_data, target_data, sample_trace_length_t0, train_mask,
                        pretrain_flag=False):
    if config.Learn.apply_lstm:
        feed_dict = {model.output_ph: target_data,
                     model.input_ph: input_data[:, :, config.Arch.Encoder.output_dim:],
                     model.trace_lengths_ph: sample_trace_length_t0
                     }
    else:
        feed_dict = {model.output_ph: target_data,
                     model.input_ph: input_data[:, config.Arch.Encoder.output_dim:],
                     }

    [
        output_x,
        likelihood_loss,
        _
    ] = sess.run([
        model.player_prediction,
        # model.cost,
        model.likelihood_loss,
        model.train_encoder_op],
        feed_dict=feed_dict
    )

    acc = compute_acc(target_label=target_data, output_prob=output_x)
    # TODO: acc is computed but not yet used; a convergence criterion still
    # needs to be defined (e.g. based on the likelihood loss).
    return likelihood_loss
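
# --- Illustration (not part of the original source): compute_acc is defined
# elsewhere in this repository. A minimal sketch of a comparable one-hot
# accuracy, assuming target_label and output_prob are [batch, num_classes]
# arrays, could look like the following (names here are hypothetical).
import numpy as np

def onehot_accuracy_sketch(target_label, output_prob):
    """Fraction of rows whose predicted argmax matches the one-hot target."""
    pred_classes = np.argmax(np.asarray(output_prob), axis=1)
    true_classes = np.argmax(np.asarray(target_label), axis=1)
    return float(np.mean(pred_classes == true_classes))
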
def train_model(model, sess, config, input_data, target_data, trace_lengths,
                terminal):
    [output_prob, _] = sess.run(
        [model.read_out, model.train_op],
        feed_dict={
            model.rnn_input_ph: input_data,
            model.y_ph: target_data,
            model.trace_lengths_ph: trace_lengths
        })
    acc = compute_acc(output_prob, target_data, if_print=False)
    return acc
def train_prediction(model, sess, config, input_data, target_data,
                     trace_lengths, selection_matrix):
    [output_prob, _] = sess.run(
        [model.action_pred_output, model.train_action_pred_op],
        feed_dict={
            model.selection_matrix_ph: selection_matrix,
            model.input_data_ph: input_data,
            model.action_pred_target_ph: target_data,
            model.trace_length_ph: trace_lengths
        })
    acc = compute_acc(output_prob, target_data, if_print=False)
    return acc
def train_prediction(model, sess, config, input_data, target_data,
                     trace_lengths, selection_matrix):
    [output_prob, _] = sess.run(
        [model.action_pred_output, model.train_action_pred_op],
        feed_dict={
            model.selection_matrix_ph: selection_matrix,
            model.input_data_ph: input_data,
            model.action_pred_target_ph: target_data,
            model.trace_length_ph: trace_lengths
        })
    TP, TN, FP, FN, acc, ll, auc = compute_acc(target_data,
                                               output_prob,
                                               if_binary_result=True,
                                               if_add_ll=True,
                                               if_print=False)
    precision = float(TP) / (TP + FP) if (TP + FP) > 0 else 0.0
    recall = float(TP) / (TP + FN) if (TP + FN) > 0 else 0.0
    if TRAINING_ITERATIONS % 20 == 0:
        print("Prediction acc is {0} with precision {1} and recall {2}".format(
            acc, precision, recall))
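
# --- Illustration (not part of the original source): a sketch of deriving
# precision, recall and F1 from the confusion counts returned above, with
# explicit guards against empty denominators (function name is hypothetical).
def precision_recall_f1_sketch(TP, TN, FP, FN):
    precision = float(TP) / (TP + FP) if (TP + FP) > 0 else 0.0
    recall = float(TP) / (TP + FN) if (TP + FN) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)
    return precision, recall, f1
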
def train_prediction(model, sess, config, sample_pred_input_data,
                     sample_pred_target_data, sample_pred_trace_lengths):
    if config.Learn.apply_lstm:
        feed_dict = {
            model.input_ph: sample_pred_input_data[:, :, config.Arch.Encoder.output_dim:],
            model.predict_target_ph: sample_pred_target_data,
            model.trace_lengths_ph: sample_pred_trace_lengths}
    else:
        feed_dict = {model.input_ph: sample_pred_input_data[:, config.Arch.Encoder.output_dim:],
                     model.predict_target_ph: sample_pred_target_data
                     }

    [
        predict_output,
        _
    ] = sess.run([
        model.prediction_prob,
        model.train_prediction_op],
        feed_dict=feed_dict
    )

    acc = compute_acc(predict_output, sample_pred_target_data, if_print=False)
    return acc
def validation_model(testing_dir_games_all, data_store, config, sess, model,
                     player_id_cluster_dir, source_data_store_dir):
    output_decoder_all = None
    target_data_all = None
    selection_matrix_all = None
    print('validating model')
    game_number = 0
    for dir_game in testing_dir_games_all:

        if dir_game == '.DS_Store':
            continue
        game_number += 1
        state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
            data_store=data_store,
            dir_game=dir_game,
            config=config,
            player_id_cluster_dir=player_id_cluster_dir)
        action_seq = transfer2seq(data=action,
                                  trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        team_id_seq = transfer2seq(data=team_id,
                                   trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        if config.Learn.predict_target == 'ActionGoal' or config.Learn.predict_target == 'Action':
            player_index_seq = transfer2seq(
                data=player_index,
                trace_length=state_trace_length,
                max_length=config.Learn.max_seq_length)
        else:
            player_index_seq = player_index
        if config.Learn.predict_target == 'ActionGoal':
            actions_all = read_feature_within_events(
                directory=dir_game,
                data_path=source_data_store_dir,
                feature_name='name')
            next_goal_label = []
            data_length = state_trace_length.shape[0]

            win_info = []
            new_reward = []
            new_action_seq = []
            new_state_input = []
            new_state_trace_length = []
            new_team_id_seq = []
            for action_index in range(0, data_length):
                action = actions_all[action_index]
                if 'shot' in action:
                    if action_index + 1 == data_length:
                        # the last event has no successor, so it cannot be
                        # labelled; skip it to keep features and labels aligned
                        continue
                    new_reward.append(reward[action_index])
                    new_action_seq.append(action_seq[action_index])
                    new_state_input.append(state_input[action_index])
                    new_state_trace_length.append(
                        state_trace_length[action_index])
                    new_team_id_seq.append(team_id_seq[action_index])
                    # label the shot by whether the next event is a goal
                    if actions_all[action_index + 1] == 'goal':
                        next_goal_label.append([1, 0])
                    else:
                        next_goal_label.append([0, 1])
            reward = np.asarray(new_reward)
            action_seq = np.asarray(new_action_seq)
            state_input = np.asarray(new_state_input)
            state_trace_length = np.asarray(new_state_trace_length)
            team_id_seq = np.asarray(new_team_id_seq)
            win_info = np.asarray(next_goal_label)
        elif config.Learn.predict_target == 'Action':
            win_info = action[1:, :]
            reward = reward[:-1]
            action_seq = action_seq[:-1, :, :]
            state_input = state_input[:-1, :, :]
            state_trace_length = state_trace_length[:-1]
            team_id_seq = team_id_seq[:-1, :, :]
        else:
            win_info = None
        # print ("\n training file" + str(dir_game))
        # reward_count = sum(reward)
        # print ("reward number" + str(reward_count))
        if len(state_input) != len(reward) or len(state_trace_length) != len(
                reward):
            raise Exception('state length does not equal reward length')

        train_len = len(state_input)
        train_number = 0
        s_t0 = state_input[train_number]
        train_number += 1

        while True:
            # try:
            batch_return, \
            train_number, \
            s_tl, \
            print_flag = get_together_training_batch(s_t0=s_t0,
                                                     state_input=state_input,
                                                     reward=reward,
                                                     player_index=player_index_seq,
                                                     train_number=train_number,
                                                     train_len=train_len,
                                                     state_trace_length=state_trace_length,
                                                     action=action_seq,
                                                     team_id=team_id_seq,
                                                     win_info=win_info,
                                                     config=config)

            # get the batch variables
            # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
            #                      team_id_t1, 0, 0
            s_t0_batch = [d[0] for d in batch_return]
            s_t1_batch = [d[1] for d in batch_return]
            r_t_batch = [d[2] for d in batch_return]
            trace_t0_batch = [d[3] for d in batch_return]
            trace_t1_batch = [d[4] for d in batch_return]
            action_id_t0 = [d[5] for d in batch_return]
            action_id_t1 = [d[6] for d in batch_return]
            team_id_t0_batch = [d[7] for d in batch_return]
            team_id_t1_batch = [d[8] for d in batch_return]
            player_id_t0_batch = [d[9] for d in batch_return]
            player_id_t1_batch = [d[10] for d in batch_return]
            win_info_t0_batch = [d[11] for d in batch_return]

            trace_lengths = trace_t0_batch

            if config.Learn.predict_target == 'ActionGoal':
                target_data = np.asarray(win_info_t0_batch)
            elif config.Learn.predict_target == 'Action':
                target_data = np.asarray(win_info_t0_batch)
            elif config.Learn.predict_target == 'PlayerLocalId':
                config.Learn.apply_pid = False
                target_data = np.asarray(player_id_t0_batch)
            else:
                raise ValueError('unknown predict_target {0}'.format(
                    config.Learn.predict_target))

            if config.Learn.apply_pid:
                input_data = np.concatenate([
                    np.asarray(action_id_t0),
                    np.asarray(s_t0_batch),
                    np.asarray(player_id_t0_batch)
                ],
                                            axis=2)
            else:
                input_data = np.concatenate(
                    [np.asarray(action_id_t0),
                     np.asarray(s_t0_batch)], axis=2)

            # the terminal flag of the batch comes from its last transition
            terminal = batch_return[-1][-2]

            [output_prob] = sess.run(
                [model.read_out],
                feed_dict={
                    model.rnn_input_ph: input_data,
                    model.y_ph: target_data,
                    model.trace_lengths_ph: trace_lengths
                })
            if output_decoder_all is None:
                output_decoder_all = output_prob
                target_data_all = target_data
            else:
                output_decoder_all = np.concatenate(
                    [output_decoder_all, output_prob], axis=0)
                target_data_all = np.concatenate(
                    [target_data_all, target_data], axis=0)
            s_t0 = s_tl
            if terminal:
                # save progress after a game
                # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                #                  global_step=game_number)
                # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                # game_diff_record_dict.update({dir_game: v_diff_record_average})
                break

    acc = compute_acc(output_prob=output_decoder_all,
                      target_label=target_data_all,
                      if_print=True)
    print("testing acc is {0}".format(str(acc)))
def validation_model(testing_dir_games_all, data_store, config, sess, model,
                     predicted_target, embedding2validate):
    model_output_all = None
    target_data_all = None
    selection_matrix_all = None
    print('validating model')
    game_number = 0
    for dir_game in testing_dir_games_all:

        if dir_game == '.DS_Store':
            continue
        game_number += 1
        state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
            data_store=data_store, dir_game=dir_game, config=config)
        # state_trace_length = np.asarray([10] * len(state_trace_length))
        action_seq = transfer2seq(data=action,
                                  trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        team_id_seq = transfer2seq(data=team_id,
                                   trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        player_id_seq = transfer2seq(data=player_index,
                                     trace_length=state_trace_length,
                                     max_length=config.Learn.max_seq_length)
        # print ("\n training file" + str(dir_game))
        # reward_count = sum(reward)
        # print ("reward number" + str(reward_count))
        if len(state_input) != len(reward) or len(state_trace_length) != len(
                reward):
            raise Exception('state length does not equal reward length')

        train_len = len(state_input)
        train_number = 0
        s_t0 = state_input[train_number]
        train_number += 1

        while True:
            # try:
            batch_return, \
            train_number, \
            s_tl, \
            print_flag = get_together_training_batch(s_t0=s_t0,
                                                     state_input=state_input,
                                                     reward=reward,
                                                     player_index=player_index,
                                                     train_number=train_number,
                                                     train_len=train_len,
                                                     state_trace_length=state_trace_length,
                                                     action=action_seq,
                                                     team_id=team_id_seq,
                                                     config=config)

            # get the batch variables
            # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
            #                      team_id_t1, 0, 0
            s_t0_batch = [d[0] for d in batch_return]
            s_t1_batch = [d[1] for d in batch_return]
            r_t_batch = [d[2] for d in batch_return]
            trace_t0_batch = [d[3] for d in batch_return]
            trace_t1_batch = [d[4] for d in batch_return]
            action_id_t0_batch = [d[5] for d in batch_return]
            action_id_t1_batch = [d[6] for d in batch_return]
            team_id_t0_batch = [d[7] for d in batch_return]
            team_id_t1_batch = [d[8] for d in batch_return]
            player_id_t0_batch = [d[9] for d in batch_return]
            player_id_t1_batch = [d[10] for d in batch_return]
            r_t_seq_batch = transfer2seq(
                data=np.asarray(r_t_batch),
                trace_length=trace_t0_batch,
                max_length=config.Learn.max_seq_length)
            current_state, history_state = handle_de_history(
                data_seq_all=s_t0_batch, trace_lengths=trace_t0_batch)
            current_action, history_action = handle_de_history(
                data_seq_all=action_id_t0_batch, trace_lengths=trace_t0_batch)
            current_reward, history_reward = handle_de_history(
                data_seq_all=r_t_seq_batch, trace_lengths=trace_t0_batch)

            if embedding2validate is not None:
                embed_data = []
                for player_id in player_id_t0_batch:
                    player_index_scalar = np.where(player_id == 1)[0][0]
                    embed_data.append(embedding2validate[player_index_scalar])
                embed_data = np.asarray(embed_data)
            else:
                embed_data = np.asarray(player_id_t0_batch)

            if predicted_target == 'action':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_state),
                     np.asarray(current_reward)],
                    axis=1)
                target_data = np.asarray(current_action)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            elif predicted_target == 'state':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_action),
                     np.asarray(current_reward)],
                    axis=1)
                target_data = np.asarray(current_state)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            elif predicted_target == 'reward':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_state),
                     np.asarray(current_action)],
                    axis=1)
                target_data = np.asarray(current_reward)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            else:
                raise ValueError('undefined predicted target')

            # the terminal flag of the batch comes from its last transition
            terminal = batch_return[-1][-2]

            # validation should not update the weights, so only the forward
            # pass is run here (no training op)
            [output_prob] = sess.run(
                [model.read_out],
                feed_dict={
                    model.rnn_input_ph: input_seq_data,
                    model.feature_input_ph: input_obs_data,
                    model.y_ph: target_data,
                    model.embed_label_ph: embed_data,
                    model.trace_lengths_ph: trace_lengths
                })
            if model_output_all is None:
                model_output_all = output_prob
                target_data_all = target_data
            else:
                model_output_all = np.concatenate(
                    [model_output_all, output_prob], axis=0)
                target_data_all = np.concatenate(
                    [target_data_all, target_data], axis=0)
            s_t0 = s_tl
            if terminal:
                # save progress after a game
                # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                #                  global_step=game_number)
                # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                # game_diff_record_dict.update({dir_game: v_diff_record_average})
                break
    if predicted_target == 'action':
        acc = compute_acc(output_prob=model_output_all,
                          target_label=target_data_all,
                          if_print=True)
        print("testing acc is {0}".format(str(acc)))
    else:
        mae = compute_mae(output_actions_prob=model_output_all,
                          target_actions_prob=target_data_all,
                          if_print=True)
        print("mae is {0}".format(str(mae)))
def validate_model(testing_dir_games_all,
                   data_store,
                   source_data_dir,
                   config,
                   sess,
                   model,
                   player_id_cluster_dir,
                   train_game_number,
                   validate_cvrnn_flag,
                   validate_td_flag,
                   validate_diff_flag,
                   validate_pred_flag,
                   file_writer=None):
    output_decoder_all = None
    target_data_all = None
    selection_matrix_all = None
    q_values_all = None
    validate_variance_flag = False
    pred_target_data_all = None
    pred_output_prob_all = None

    if validate_diff_flag:
        # -100 is an out-of-range placeholder for unfilled slots; the diff
        # computation below skips any entry that still holds it
        real_label_record = np.ones([len(testing_dir_games_all), 5000]) * -100
        output_label_record = np.ones([len(testing_dir_games_all), 5000]) * -100

    print('validating model')
    for dir_index in range(0, len(testing_dir_games_all)):

        real_label_all = None
        output_label_all = None

        dir_game = testing_dir_games_all[dir_index]
        print('validating game {0}'.format(str(dir_game)))
        if dir_game == '.DS_Store':
            continue

        [
            output_decoder_all, target_data_all, selection_matrix_all,
            q_values_all, real_label_all, output_label_all,
            pred_target_data_all, pred_output_prob_all,
            match_q_values_players_dict
        ] = gathering_data_and_run(
            dir_game,
            config,
            player_id_cluster_dir,
            data_store,
            source_data_dir,
            model,
            sess,
            training_flag=False,
            game_number=None,
            validate_cvrnn_flag=validate_cvrnn_flag,
            validate_td_flag=validate_td_flag,
            validate_diff_flag=validate_diff_flag,
            validate_variance_flag=validate_variance_flag,
            validate_predict_flag=validate_pred_flag,
            output_decoder_all=output_decoder_all,
            target_data_all=target_data_all,
            selection_matrix_all=selection_matrix_all,
            q_values_all=q_values_all,
            output_label_all=output_label_all,
            real_label_all=real_label_all,
            pred_target_data_all=pred_target_data_all,
            pred_output_prob_all=pred_output_prob_all)
        # validate_variance_flag = False
        # if match_q_values_players_dict is not None:
        #     plot_players_games(match_q_values_players_dict, train_game_number)

        if validate_diff_flag:
            real_label_record[dir_index][:len(
                real_label_all)] = real_label_all[:len(real_label_all)]
            output_label_record[dir_index][:len(
                output_label_all)] = output_label_all[:len(output_label_all)]

    if validate_cvrnn_flag:
        acc = compute_rnn_acc(output_prob=output_decoder_all,
                              target_label=target_data_all,
                              selection_matrix=selection_matrix_all,
                              config=config,
                              if_print=True)
        print("testing acc is {0}".format(str(acc)))
        if file_writer is not None:
            file_writer.write("testing acc is {0}\n".format(str(acc)))
    # if validate_td_flag:
    #     print ("testing avg qs is {0}".format(str(np.mean(q_values_all, axis=0))))
    #     if file_writer is not None:
    #         file_writer.write("testing avg qs is {0}\n".format(str(np.mean(q_values_all, axis=0))))

    if validate_diff_flag:
        # print ('general real label is {0}'.format(str(np.sum(real_label_record, axis=1))))
        # print ('general output label is {0}'.format(str(np.sum(output_label_record, axis=1))))
        for i in range(0, output_label_record.shape[1]):
            real_outcome_record_step = real_label_record[:, i]
            model_output_record_step = output_label_record[:, i]
            diff_sum = 0
            total_number = 0
            print_flag = True
            for win_index in range(0, len(real_outcome_record_step)):
                if model_output_record_step[
                        win_index] == -100 or real_outcome_record_step[
                            win_index] == -100:
                    print_flag = True
                    continue
                diff = abs(model_output_record_step[win_index] -
                           real_outcome_record_step[win_index])
                diff_sum += diff
                total_number += 1
            if print_flag:
                if i % 100 == 0 and total_number > 0:
                    print('diff of time {0} is {1}'.format(
                        str(i), str(float(diff_sum) / total_number)))
                    if file_writer is not None:
                        file_writer.write('diff of time {0} is {1}\n'.format(
                            str(i), str(float(diff_sum) / total_number)))
    if validate_pred_flag:
        TP, TN, FP, FN, acc, ll, auc = compute_acc(pred_target_data_all,
                                                   pred_output_prob_all,
                                                   if_binary_result=True,
                                                   if_print=False,
                                                   if_add_ll=True)
        precision = float(TP) / (TP + FP) if (TP + FP) > 0 else 0.0
        recall = float(TP) / (TP + FN) if (TP + FN) > 0 else 0.0
        print("Prediction acc is {0} with precision {1} and recall {2}".format(
            acc, precision, recall))
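
# --- Illustration (not part of the original source): the per-timestep diff
# loop above can be expressed with numpy masking. This is a sketch under the
# assumption that unfilled slots still carry the placeholder value used when
# the record arrays were initialised.
import numpy as np

def mean_abs_diff_per_step_sketch(real_label_record, output_label_record,
                                  placeholder=-100):
    """Mean absolute difference per timestep, ignoring placeholder slots."""
    diffs = []
    for i in range(output_label_record.shape[1]):
        real_step = real_label_record[:, i]
        out_step = output_label_record[:, i]
        valid = (real_step != placeholder) & (out_step != placeholder)
        if np.any(valid):
            diffs.append(float(np.mean(np.abs(out_step[valid] - real_step[valid]))))
        else:
            diffs.append(None)
    return diffs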