import csv

import numpy as np
import tensorflow as tf

# Project helpers used below (get_icehockey_game_data, transfer2seq,
# get_together_training_batch, read_feature_within_events, handle_de_history,
# generate_selection_matrix, compute_game_score_diff_vec, the train_*/validation
# routines, compute_acc/compute_mae, save_model and the memory buffers) are
# assumed to be imported from the surrounding project modules.


def gather_plot_values(dir_games_all, data_store, config, model, sess):

    mu_all = []
    var_all = []

    for dir_game in dir_games_all:
        if dir_game == '.DS_Store':
            continue
        state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
            data_store=data_store,
            dir_game=dir_game,
            config=config,
            player_id_cluster_dir=None)
        action_seq = transfer2seq(data=action,
                                  trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        team_id_seq = transfer2seq(data=team_id,
                                   trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        player_index_seq = transfer2seq(data=player_index,
                                        trace_length=state_trace_length,
                                        max_length=config.Learn.max_seq_length)

        input_data = np.concatenate([state_input, action_seq], axis=2)

        [mu, var] = sess.run(
            [model.mu_out, model.var_out],
            feed_dict={
                model.rnn_input_ph: input_data,
                model.trace_lengths_ph: state_trace_length
            })
        mu_all.append(mu)
        var_all.append(var)
    return mu_all, var_all
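
# A minimal usage sketch for gather_plot_values, assuming the TensorFlow 1.x
# session / Saver workflow used elsewhere in this file. build_model() and
# checkpoint_path are hypothetical placeholders, not names from the project.
def plot_values_example(dir_games_all, data_store, config, checkpoint_path):
    model = build_model(config)  # hypothetical: build the graph of the trained model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)  # load the trained weights
        mu_all, var_all = gather_plot_values(dir_games_all, data_store,
                                             config, model, sess)
    # each list entry holds one game's outputs; stack them for plotting
    return np.concatenate(mu_all, axis=0), np.concatenate(var_all, axis=0)
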
def gathering_running_and_run(dir_game, config, player_id_cluster_dir, data_store,
                              model, sess, training_flag, game_number, validate_cvrnn_flag=False,
                              validate_td_flag=False, validate_variance_flag=False,
                              output_decoder_all=None,
                              target_data_all=None, selection_matrix_all=None,
                              q_values_all=None, pretrain_flag=None):
    if validate_variance_flag:
        match_q_values_players_dict = {}
        for i in range(config.Learn.player_cluster_number):
            match_q_values_players_dict.update({i: []})
    else:
        match_q_values_players_dict = None

    state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
        data_store=data_store, dir_game=dir_game, config=config, player_id_cluster_dir=player_id_cluster_dir)
    action_seq = transfer2seq(data=action, trace_length=state_trace_length,
                              max_length=config.Learn.max_seq_length)
    team_id_seq = transfer2seq(data=team_id, trace_length=state_trace_length,
                               max_length=config.Learn.max_seq_length)
    player_index_seq = transfer2seq(data=player_index, trace_length=state_trace_length,
                                    max_length=config.Learn.max_seq_length)
    # reward_count = sum(reward)
    # print ("reward number" + str(reward_count))
    if len(state_input) != len(reward) or len(state_trace_length) != len(reward):
        raise Exception('state length does not equal reward length')

    train_len = len(state_input)
    train_number = 0
    batch_number = 0
    s_t0 = state_input[train_number]
    train_number += 1
    while True:
        # try:
        batch_return, \
        train_number, \
        s_tl, \
        print_flag = get_together_training_batch(s_t0=s_t0,
                                                 state_input=state_input,
                                                 reward=reward,
                                                 player_index=player_index_seq,
                                                 train_number=train_number,
                                                 train_len=train_len,
                                                 state_trace_length=state_trace_length,
                                                 action=action_seq,
                                                 team_id=team_id_seq,
                                                 # win_info=None,
                                                 config=config)

        # get the batch variables
        # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
        #                      team_id_t1, 0, 0
        s_t0_batch = [d[0] for d in batch_return]
        s_t1_batch = [d[1] for d in batch_return]
        r_t_batch = [d[2] for d in batch_return]
        trace_t0_batch = [d[3] for d in batch_return]
        trace_t1_batch = [d[4] for d in batch_return]
        action_id_t0 = [d[5] for d in batch_return]
        action_id_t1 = [d[6] for d in batch_return]
        team_id_t0_batch = [d[7] for d in batch_return]
        team_id_t1_batch = [d[8] for d in batch_return]
        player_id_t0_batch = [d[9] for d in batch_return]
        player_id_t1_batch = [d[10] for d in batch_return]

        if config.Learn.apply_pid:
            input_data_t0 = np.concatenate([np.asarray(player_id_t0_batch), np.asarray(s_t0_batch),
                                            np.asarray(action_id_t0)], axis=2)
            input_data_t1 = np.concatenate([np.asarray(player_id_t1_batch), np.asarray(s_t1_batch),
                                            np.asarray(action_id_t1)], axis=2)
        else:
            input_data_t0 = np.concatenate([np.asarray(s_t0_batch), np.asarray(action_id_t0)], axis=2)
            input_data_t1 = np.concatenate([np.asarray(s_t1_batch), np.asarray(action_id_t1)], axis=2)

        trace_lengths_t0 = trace_t0_batch
        trace_lengths_t1 = trace_t1_batch

        # the last batch element carries the terminal / cut flags
        terminal = batch_return[-1][-2]
        cut = batch_return[-1][-1]

        if training_flag:
            # print (len(state_input) / (config.Learn.batch_size*10))
            print_flag = batch_number % max(1, len(state_input) // (config.Learn.batch_size * 10)) == 0
            train_td_model(model, sess, config, input_data_t0, trace_lengths_t0,
                           input_data_t1, trace_lengths_t1, r_t_batch, terminal, cut,
                           pretrain_flag, print_flag)
        else:
            pass

        batch_number += 1
        s_t0 = s_tl
        if terminal:
            break

    return [output_decoder_all, target_data_all, selection_matrix_all, q_values_all, match_q_values_players_dict]
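
# The training/validation loops in this file all unpack get_together_training_batch()
# results positionally (d[0] ... d[12], with the terminal / cut flags in the last
# two slots). The hedged helper below centralises that convention; the field names
# mirror the comments in this file and are assumptions, not the project's API.
def unpack_batch(batch_return):
    fields = {
        's_t0': [d[0] for d in batch_return],
        's_t1': [d[1] for d in batch_return],
        'reward': [d[2] for d in batch_return],
        'trace_t0': [d[3] for d in batch_return],
        'trace_t1': [d[4] for d in batch_return],
        'action_t0': [d[5] for d in batch_return],
        'action_t1': [d[6] for d in batch_return],
        'team_t0': [d[7] for d in batch_return],
        'team_t1': [d[8] for d in batch_return],
        'player_t0': [d[9] for d in batch_return],
        'player_t1': [d[10] for d in batch_return],
    }
    # the last two entries of every tuple flag end-of-game / cut points
    fields['terminal'] = batch_return[-1][-2]
    fields['cut'] = batch_return[-1][-1]
    return fields
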
def run_network(sess, model, config, training_dir_games_all,
                testing_dir_games_all, model_data_store_dir,
                player_id_cluster_dir, source_data_store_dir,
                save_network_dir):
    game_number = 0
    converge_flag = False
    saver = tf.train.Saver(max_to_keep=300)

    while True:
        game_diff_record_dict = {}
        iteration_now = game_number // config.Learn.number_of_total_game + 1
        game_diff_record_dict.update({"Iteration": iteration_now})
        if converge_flag:
            break
        elif game_number >= len(
                training_dir_games_all) * config.Learn.iterate_num:
            break
        # else:
        #     converge_flag = True
        for dir_game in training_dir_games_all:
            if dir_game == '.DS_Store':
                continue
            game_number += 1
            game_cost_record = []
            state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
                data_store=model_data_store_dir,
                dir_game=dir_game,
                config=config,
                player_id_cluster_dir=player_id_cluster_dir)
            action_seq = transfer2seq(data=action,
                                      trace_length=state_trace_length,
                                      max_length=config.Learn.max_seq_length)
            team_id_seq = transfer2seq(data=team_id,
                                       trace_length=state_trace_length,
                                       max_length=config.Learn.max_seq_length)
            if config.Learn.predict_target == 'ActionGoal' or config.Learn.predict_target == 'Action':
                player_index_seq = transfer2seq(
                    data=player_index,
                    trace_length=state_trace_length,
                    max_length=config.Learn.max_seq_length)
            else:
                player_index_seq = player_index
            if config.Learn.predict_target == 'ActionGoal':
                actions_all = read_feature_within_events(
                    directory=dir_game,
                    data_path=source_data_store_dir,
                    feature_name='name')
                next_goal_label = []
                data_length = state_trace_length.shape[0]

                win_info = []
                new_reward = []
                new_action_seq = []
                new_state_input = []
                new_state_trace_length = []
                new_team_id_seq = []
                for action_index in range(0, data_length):
                    action = actions_all[action_index]
                    if 'shot' in action:
                        if action_index + 1 == data_length:
                            continue
                        new_reward.append(reward[action_index])
                        new_action_seq.append(action_seq[action_index])
                        new_state_input.append(state_input[action_index])
                        new_state_trace_length.append(
                            state_trace_length[action_index])
                        new_team_id_seq.append(team_id_seq[action_index])
                        if actions_all[action_index + 1] == 'goal':
                            # print(actions_all[action_index+1])
                            next_goal_label.append([1, 0])
                        else:
                            # print(actions_all[action_index + 1])
                            next_goal_label.append([0, 1])
                reward = np.asarray(new_reward)
                action_seq = np.asarray(new_action_seq)
                state_input = np.asarray(new_state_input)
                state_trace_length = np.asarray(new_state_trace_length)
                team_id_seq = np.asarray(new_team_id_seq)
                win_info = np.asarray(next_goal_label)
            elif config.Learn.predict_target == 'Action':
                win_info = action[1:, :]
                reward = reward[:-1]
                action_seq = action_seq[:-1, :, :]
                state_input = state_input[:-1, :, :]
                state_trace_length = state_trace_length[:-1]
                team_id_seq = team_id_seq[:-1, :, :]
            else:
                win_info = None
            print("\n training file" + str(dir_game))
            # reward_count = sum(reward)
            # print ("reward number" + str(reward_count))
            if len(state_input) != len(reward) or len(
                    state_trace_length) != len(reward):
                raise Exception('state length does not equal reward length')

            train_len = len(state_input)
            train_number = 0
            s_t0 = state_input[train_number]
            train_number += 1

            while True:
                # try:
                batch_return, \
                train_number, \
                s_tl, \
                print_flag = get_together_training_batch(s_t0=s_t0,
                                                         state_input=state_input,
                                                         reward=reward,
                                                         player_index=player_index_seq,
                                                         train_number=train_number,
                                                         train_len=train_len,
                                                         state_trace_length=state_trace_length,
                                                         action=action_seq,
                                                         team_id=team_id_seq,
                                                         win_info=win_info,
                                                         config=config)

                # get the batch variables
                # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
                #                      team_id_t1, 0, 0
                s_t0_batch = [d[0] for d in batch_return]
                s_t1_batch = [d[1] for d in batch_return]
                r_t_batch = [d[2] for d in batch_return]
                trace_t0_batch = [d[3] for d in batch_return]
                trace_t1_batch = [d[4] for d in batch_return]
                action_id_t0 = [d[5] for d in batch_return]
                action_id_t1 = [d[6] for d in batch_return]
                team_id_t0_batch = [d[7] for d in batch_return]
                team_id_t1_batch = [d[8] for d in batch_return]
                player_id_t0_batch = [d[9] for d in batch_return]
                player_id_t1_batch = [d[10] for d in batch_return]
                win_info_t0_batch = [d[11] for d in batch_return]

                if config.Learn.predict_target == 'ActionGoal':
                    target_data = np.asarray(win_info_t0_batch)
                    m2balanced = True
                elif config.Learn.predict_target == 'Action':
                    target_data = np.asarray(win_info_t0_batch)
                    m2balanced = True
                elif config.Learn.predict_target == 'PlayerLocalId':
                    config.Learn.apply_pid = False
                    target_data = np.asarray(player_id_t0_batch)
                    m2balanced = False
                else:
                    raise ValueError('unknown predict_target {0}'.format(
                        config.Learn.predict_target))

                if config.Learn.apply_pid:
                    input_data = np.concatenate([
                        np.asarray(action_id_t0),
                        np.asarray(s_t0_batch),
                        np.asarray(player_id_t0_batch)
                    ],
                                                axis=2)
                else:
                    input_data = np.concatenate(
                        [np.asarray(action_id_t0),
                         np.asarray(s_t0_batch)],
                        axis=2)
                trace_lengths = trace_t0_batch

                # the last batch element carries the terminal flag
                terminal = batch_return[-1][-2]

                if config.Learn.apply_stochastic:
                    for i in range(len(input_data)):
                        if m2balanced:
                            cache_label = 0 if target_data[i][0] == 0 else 1
                            BalancedMemoryBuffer.push([
                                input_data[i], target_data[i], trace_lengths[i]
                            ],
                                                      cache_label=cache_label)
                        else:
                            MemoryBuffer.push([
                                input_data[i], target_data[i], trace_lengths[i]
                            ])
                    if game_number <= 10:
                        s_t0 = s_tl
                        if terminal:
                            break
                        else:
                            continue
                    if m2balanced:
                        sampled_data = BalancedMemoryBuffer.sample(
                            batch_size=config.Learn.batch_size)
                    else:
                        sampled_data = MemoryBuffer.sample(
                            batch_size=config.Learn.batch_size)
                    sample_input_data = np.asarray(
                        [sampled_data[j][0] for j in range(len(sampled_data))])
                    sample_target_data = np.asarray(
                        [sampled_data[j][1] for j in range(len(sampled_data))])
                    sample_trace_lengths = np.asarray(
                        [sampled_data[j][2] for j in range(len(sampled_data))])
                else:
                    sample_input_data = input_data
                    sample_target_data = target_data
                    sample_trace_lengths = trace_lengths

                train_model(model, sess, config, sample_input_data,
                            sample_target_data, sample_trace_lengths, terminal)
                s_t0 = s_tl
                if terminal:
                    # save progress after a game
                    # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                    #                  global_step=game_number)
                    # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                    # game_diff_record_dict.update({dir_game: v_diff_record_average})
                    break

            save_model(game_number, saver, sess, save_network_dir, config)

            if game_number % 100 == 1:
                validation_model(testing_dir_games_all, model_data_store_dir,
                                 config, sess, model, player_id_cluster_dir,
                                 source_data_store_dir)
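
# Self-contained sketch of the 'ActionGoal' labelling rule used in run_network
# and validation_model above: every 'shot' event is kept and labelled [1, 0] if
# the next event is a 'goal', otherwise [0, 1]; the toy event list is made up.
def label_shots(event_names):
    labels = []
    kept_indices = []
    for idx in range(len(event_names) - 1):  # the last event has no successor
        if 'shot' in event_names[idx]:
            kept_indices.append(idx)
            labels.append([1, 0] if event_names[idx + 1] == 'goal' else [0, 1])
    return kept_indices, labels

# toy example: events 1 and 3 are shots, only the first is followed by a goal
print(label_shots(['pass', 'shot', 'goal', 'shot', 'pass']))
# -> ([1, 3], [[1, 0], [0, 1]])
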
def validation_model(testing_dir_games_all, data_store, config, sess, model,
                     player_id_cluster_dir, source_data_store_dir):
    output_decoder_all = None
    target_data_all = None
    selection_matrix_all = None
    print('validating model')
    game_number = 0
    for dir_game in testing_dir_games_all:

        if dir_game == '.DS_Store':
            continue
        game_number += 1
        state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
            data_store=data_store,
            dir_game=dir_game,
            config=config,
            player_id_cluster_dir=player_id_cluster_dir)
        action_seq = transfer2seq(data=action,
                                  trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        team_id_seq = transfer2seq(data=team_id,
                                   trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        if config.Learn.predict_target == 'ActionGoal' or config.Learn.predict_target == 'Action':
            player_index_seq = transfer2seq(
                data=player_index,
                trace_length=state_trace_length,
                max_length=config.Learn.max_seq_length)
        else:
            player_index_seq = player_index
        if config.Learn.predict_target == 'ActionGoal':
            actions_all = read_feature_within_events(
                directory=dir_game,
                data_path=source_data_store_dir,
                feature_name='name')
            next_goal_label = []
            data_length = state_trace_length.shape[0]

            win_info = []
            new_reward = []
            new_action_seq = []
            new_state_input = []
            new_state_trace_length = []
            new_team_id_seq = []
            for action_index in range(0, data_length):
                action = actions_all[action_index]
                if 'shot' in action:
                    # the last event has no successor, so it cannot be labelled
                    if action_index + 1 == data_length:
                        continue
                    new_reward.append(reward[action_index])
                    new_action_seq.append(action_seq[action_index])
                    new_state_input.append(state_input[action_index])
                    new_state_trace_length.append(
                        state_trace_length[action_index])
                    new_team_id_seq.append(team_id_seq[action_index])
                    if actions_all[action_index + 1] == 'goal':
                        next_goal_label.append([1, 0])
                    else:
                        next_goal_label.append([0, 1])
            reward = np.asarray(new_reward)
            action_seq = np.asarray(new_action_seq)
            state_input = np.asarray(new_state_input)
            state_trace_length = np.asarray(new_state_trace_length)
            team_id_seq = np.asarray(new_team_id_seq)
            win_info = np.asarray(next_goal_label)
        elif config.Learn.predict_target == 'Action':
            win_info = action[1:, :]
            reward = reward[:-1]
            action_seq = action_seq[:-1, :, :]
            state_input = state_input[:-1, :, :]
            state_trace_length = state_trace_length[:-1]
            team_id_seq = team_id_seq[:-1, :, :]
        else:
            win_info = None
        # print ("\n training file" + str(dir_game))
        # reward_count = sum(reward)
        # print ("reward number" + str(reward_count))
        if len(state_input) != len(reward) or len(state_trace_length) != len(
                reward):
            raise Exception('state length does not equal reward length')

        train_len = len(state_input)
        train_number = 0
        s_t0 = state_input[train_number]
        train_number += 1

        while True:
            # try:
            batch_return, \
            train_number, \
            s_tl, \
            print_flag = get_together_training_batch(s_t0=s_t0,
                                                     state_input=state_input,
                                                     reward=reward,
                                                     player_index=player_index_seq,
                                                     train_number=train_number,
                                                     train_len=train_len,
                                                     state_trace_length=state_trace_length,
                                                     action=action_seq,
                                                     team_id=team_id_seq,
                                                     win_info=win_info,
                                                     config=config)

            # get the batch variables
            # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
            #                      team_id_t1, 0, 0
            s_t0_batch = [d[0] for d in batch_return]
            s_t1_batch = [d[1] for d in batch_return]
            r_t_batch = [d[2] for d in batch_return]
            trace_t0_batch = [d[3] for d in batch_return]
            trace_t1_batch = [d[4] for d in batch_return]
            action_id_t0 = [d[5] for d in batch_return]
            action_id_t1 = [d[6] for d in batch_return]
            team_id_t0_batch = [d[7] for d in batch_return]
            team_id_t1_batch = [d[8] for d in batch_return]
            player_id_t0_batch = [d[9] for d in batch_return]
            player_id_t1_batch = [d[10] for d in batch_return]
            win_info_t0_batch = [d[11] for d in batch_return]

            trace_lengths = trace_t0_batch

            if config.Learn.predict_target == 'ActionGoal':
                target_data = np.asarray(win_info_t0_batch)
            elif config.Learn.predict_target == 'Action':
                target_data = np.asarray(win_info_t0_batch)
            elif config.Learn.predict_target == 'PlayerLocalId':
                config.Learn.apply_pid = False
                target_data = np.asarray(player_id_t0_batch)
            else:
                raise ValueError('unknown predict_target {0}'.format(
                    config.Learn.predict_target))

            if config.Learn.apply_pid:
                input_data = np.concatenate([
                    np.asarray(action_id_t0),
                    np.asarray(s_t0_batch),
                    np.asarray(player_id_t0_batch)
                ],
                                            axis=2)
            else:
                input_data = np.concatenate(
                    [np.asarray(action_id_t0),
                     np.asarray(s_t0_batch)], axis=2)

            # the last batch element carries the terminal flag
            terminal = batch_return[-1][-2]

            [output_prob] = sess.run(
                [model.read_out],
                feed_dict={
                    model.rnn_input_ph: input_data,
                    model.y_ph: target_data,
                    model.trace_lengths_ph: trace_lengths
                })
            if output_decoder_all is None:
                output_decoder_all = output_prob
                target_data_all = target_data
            else:
                output_decoder_all = np.concatenate(
                    [output_decoder_all, output_prob], axis=0)
                target_data_all = np.concatenate(
                    [target_data_all, target_data], axis=0)
            s_t0 = s_tl
            if terminal:
                # save progress after a game
                # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                #                  global_step=game_number)
                # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                # game_diff_record_dict.update({dir_game: v_diff_record_average})
                break

    acc = compute_acc(output_prob=output_decoder_all,
                      target_label=target_data_all,
                      if_print=True)
    print("testing acc is {0}".format(str(acc)))
def gathering_data_and_run(
    dir_game,
    config,
    player_id_cluster_dir,
    data_store,
    source_data_dir,
    model,
    sess,
    training_flag,
    game_number,
    validate_cvrnn_flag=False,
    validate_td_flag=False,
    validate_diff_flag=False,
    validate_variance_flag=False,
    validate_predict_flag=False,
    output_decoder_all=None,
    target_data_all=None,
    selection_matrix_all=None,
    q_values_all=None,
    output_label_all=None,
    real_label_all=None,
    pred_target_data_all=None,
    pred_output_prob_all=None,
    training_file=None,
):
    if validate_variance_flag:
        match_q_values_players_dict = {}
        for i in range(config.Learn.player_cluster_number):
            match_q_values_players_dict.update({i: []})
    else:
        match_q_values_players_dict = None

    state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
        data_store=data_store,
        dir_game=dir_game,
        config=config,
        player_id_cluster_dir=player_id_cluster_dir)
    action_seq = transfer2seq(data=action,
                              trace_length=state_trace_length,
                              max_length=config.Learn.max_seq_length)
    reward_seq = transfer2seq(data=np.expand_dims(reward, axis=-1),
                              trace_length=state_trace_length,
                              max_length=config.Learn.max_seq_length)
    team_id_seq = transfer2seq(data=team_id,
                               trace_length=state_trace_length,
                               max_length=config.Learn.max_seq_length)
    player_index_seq = transfer2seq(data=player_index,
                                    trace_length=state_trace_length,
                                    max_length=config.Learn.max_seq_length)
    # win_one_hot = compute_game_win_vec(rewards=reward)
    score_diff = compute_game_score_diff_vec(rewards=reward)

    score_difference_game = read_feature_within_events(
        dir_game,
        source_data_dir,
        'scoreDifferential',
        transfer_home_number=True,
        data_store=data_store)
    state_reward_input = np.concatenate(
        [reward_seq, state_input],
        axis=-1)  # concatenate the sequence of state and reward.

    add_pred_flag = False

    if config.Arch.Predict.predict_target == 'ActionGoal':
        add_pred_flag = True
        actions_all = read_feature_within_events(directory=dir_game,
                                                 data_path=source_data_dir,
                                                 feature_name='name')
        next_goal_label = []
        data_length = state_trace_length.shape[0]
        new_reward = []
        new_action_seq = []
        new_state_input = []
        new_state_trace_length = []
        new_team_id_seq = []
        new_player_index_seq = []
        for action_index in range(0, data_length):
            action = actions_all[action_index]
            if 'shot' in action:
                if action_index + 1 == data_length:
                    continue
                new_reward.append(reward[action_index])
                new_action_seq.append(action_seq[action_index])
                new_state_input.append(state_reward_input[action_index])
                new_state_trace_length.append(state_trace_length[action_index])
                new_team_id_seq.append(team_id_seq[action_index])
                new_player_index_seq.append(player_index_seq[action_index])
                if actions_all[action_index + 1] == 'goal':
                    # print(actions_all[action_index+1])
                    next_goal_label.append([1, 0])
                else:
                    # print(actions_all[action_index + 1])
                    next_goal_label.append([0, 1])
        pred_target = next_goal_label
    elif config.Arch.Predict.predict_target == 'Action':
        add_pred_flag = True
        pred_target = action[1:, :]
        new_reward = reward[:-1]
        new_action_seq = action_seq[:-1, :, :]
        new_state_input = state_reward_input[:-1, :, :]
        new_state_trace_length = state_trace_length[:-1]
        new_team_id_seq = team_id_seq[:-1, :, :]
        new_player_index_seq = player_index_seq[:-1, :, :]
    else:
        # raise ValueError()
        add_pred_flag = False

    if add_pred_flag:
        if training_flag:
            train_mask = np.asarray([[[1]] * config.Learn.max_seq_length] *
                                    len(new_state_input))
        else:
            train_mask = np.asarray([[[0]] * config.Learn.max_seq_length] *
                                    len(new_state_input))
        if config.Learn.predict_target == 'PlayerLocalId':
            pred_input_data = np.concatenate([
                np.asarray(new_player_index_seq),
                np.asarray(new_team_id_seq),
                np.asarray(new_state_input),
                np.asarray(new_action_seq), train_mask
            ],
                                             axis=2)
            pred_target_data = np.asarray(pred_target)
            pred_trace_lengths = new_state_trace_length
            pred_selection_matrix = generate_selection_matrix(
                new_state_trace_length,
                max_trace_length=config.Learn.max_seq_length)
        else:
            pred_input_data = np.concatenate([
                np.asarray(new_player_index_seq),
                np.asarray(new_state_input),
                np.asarray(new_action_seq), train_mask
            ],
                                             axis=2)
            pred_target_data = np.asarray(pred_target)
            pred_trace_lengths = new_state_trace_length
            pred_selection_matrix = generate_selection_matrix(
                new_state_trace_length,
                max_trace_length=config.Learn.max_seq_length)
        if training_flag:
            for i in range(len(new_state_input)):
                cache_label = np.argmax(pred_target_data[i], axis=0)
                Prediction_MemoryBuffer.push([
                    pred_input_data[i], pred_target_data[i],
                    pred_trace_lengths[i], pred_selection_matrix[i]
                ], cache_label)

    # reward_count = sum(reward)
    # print ("reward number" + str(reward_count))
    if len(state_reward_input) != len(reward) or len(
            state_trace_length) != len(reward):
        raise Exception('state length does not equal reward length')

    kl_loss_game = []
    ll_game = []

    train_len = len(state_reward_input)
    train_number = 0
    s_t0 = state_reward_input[train_number]
    train_number += 1
    while True:
        # try:
        batch_return, \
        train_number, \
        s_tl, \
        print_flag = get_together_training_batch(s_t0=s_t0,
                                                 state_input=state_reward_input,
                                                 reward=reward,
                                                 player_index=player_index_seq,
                                                 train_number=train_number,
                                                 train_len=train_len,
                                                 state_trace_length=state_trace_length,
                                                 action=action_seq,
                                                 team_id=team_id_seq,
                                                 win_info=score_diff,
                                                 score_info=score_difference_game,
                                                 config=config)

        # get the batch variables
        # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
        #                      team_id_t1, 0, 0
        s_t0_batch = [d[0] for d in batch_return]
        s_t1_batch = [d[1] for d in batch_return]
        r_t_batch = [d[2] for d in batch_return]
        trace_t0_batch = [d[3] for d in batch_return]
        trace_t1_batch = [d[4] for d in batch_return]
        action_id_t0 = [d[5] for d in batch_return]
        action_id_t1 = [d[6] for d in batch_return]
        team_id_t0_batch = [d[7] for d in batch_return]
        team_id_t1_batch = [d[8] for d in batch_return]
        player_id_t0_batch = [d[9] for d in batch_return]
        player_id_t1_batch = [d[10] for d in batch_return]
        # index 11 carries the win / score-difference info passed in as win_info
        win_id_t_batch = [d[11] for d in batch_return]
        terminal_batch = [d[-2] for d in batch_return]
        cut_batch = [d[-1] for d in batch_return]
        score_diff_t_batch = [d[11] for d in batch_return]
        score_diff_base_t0_batch = [d[12] for d in batch_return]
        outcome_data = score_diff_t_batch
        score_diff_base_t0 = score_diff_base_t0_batch
        if training_flag:
            train_mask = np.asarray([[[1]] * config.Learn.max_seq_length] *
                                    len(s_t0_batch))
        else:
            train_mask = np.asarray([[[0]] * config.Learn.max_seq_length] *
                                    len(s_t0_batch))
        # (player_id, state ,action flag)
        # the last batch element carries the terminal / cut flags
        terminal = terminal_batch[-1]
        cut = cut_batch[-1]

        input_data_t0 = np.concatenate([
            np.asarray(player_id_t0_batch),
            np.asarray(team_id_t0_batch),
            np.asarray(s_t0_batch),
            np.asarray(action_id_t0), train_mask
        ],
                                       axis=2)
        target_data_t0 = np.asarray(player_id_t0_batch)
        trace_lengths_t0 = trace_t0_batch
        selection_matrix_t0 = generate_selection_matrix(
            trace_lengths_t0, max_trace_length=config.Learn.max_seq_length)

        input_data_t1 = np.concatenate([
            np.asarray(player_id_t1_batch),
            np.asarray(team_id_t1_batch),
            np.asarray(s_t1_batch),
            np.asarray(action_id_t1), train_mask
        ],
                                       axis=2)
        target_data_t1 = np.asarray(player_id_t1_batch)
        trace_lengths_t1 = trace_t1_batch
        selection_matrix_t1 = generate_selection_matrix(
            trace_t1_batch, max_trace_length=config.Learn.max_seq_length)
        if training_flag:

            if config.Learn.apply_stochastic:
                for i in range(len(input_data_t0)):
                    General_MemoryBuffer.push([
                        input_data_t0[i], target_data_t0[i],
                        trace_lengths_t0[i], selection_matrix_t0[i],
                        input_data_t1[i], target_data_t1[i],
                        trace_lengths_t1[i], selection_matrix_t1[i],
                        r_t_batch[i], win_id_t_batch[i], terminal_batch[i],
                        cut_batch[i]
                    ])
                sampled_data = General_MemoryBuffer.sample(
                    batch_size=config.Learn.batch_size)
                sample_input_data_t0 = np.asarray(
                    [sampled_data[j][0] for j in range(len(sampled_data))])
                sample_target_data_t0 = np.asarray(
                    [sampled_data[j][1] for j in range(len(sampled_data))])
                sample_trace_lengths_t0 = np.asarray(
                    [sampled_data[j][2] for j in range(len(sampled_data))])
                sample_selection_matrix_t0 = np.asarray(
                    [sampled_data[j][3] for j in range(len(sampled_data))])
                # sample_input_data_t1 = np.asarray([sampled_data[j][4] for j in range(len(sampled_data))])
                # sample_target_data_t1 = np.asarray([sampled_data[j][5] for j in range(len(sampled_data))])
                # sample_trace_lengths_t1 = np.asarray([sampled_data[j][6] for j in range(len(sampled_data))])
                # sample_selection_matrix_t1 = np.asarray([sampled_data[j][7] for j in range(len(sampled_data))])
                # sample_r_t_batch = np.asarray([sampled_data[j][8] for j in range(len(sampled_data))])
                # sample_terminal_batch = np.asarray([sampled_data[j][10] for j in range(len(sampled_data))])
                # sample_cut_batch = np.asarray([sampled_data[j][11] for j in range(len(sampled_data))])
                sampled_outcome_t = np.asarray(
                    [sampled_data[j][9] for j in range(len(sampled_data))])

                # terminal / cut flags of the un-sampled batch (currently unused downstream)
                batch_terminal = terminal_batch[-1]
                batch_cut = cut_batch[-1]
            else:
                sample_input_data_t0 = input_data_t0
                sample_target_data_t0 = target_data_t0
                sample_trace_lengths_t0 = trace_lengths_t0
                sample_selection_matrix_t0 = selection_matrix_t0
                sampled_outcome_t = win_id_t_batch

            kld_all, ll_all = train_cvrnn_model(model, sess, config,
                                                sample_input_data_t0,
                                                sample_target_data_t0,
                                                sample_trace_lengths_t0,
                                                sample_selection_matrix_t0)
            kl_loss_game += kld_all
            ll_game += ll_all

            # """we skip sampling for TD learning"""
            # train_td_model(model, sess, config, input_data_t0, trace_lengths_t0, selection_matrix_t0,
            #                input_data_t1, trace_lengths_t1, selection_matrix_t1, r_t_batch, terminal, cut)
            #
            # train_score_diff(model, sess, config, input_data_t0, trace_lengths_t0, selection_matrix_t0,
            #                  input_data_t1, trace_lengths_t1, selection_matrix_t1, r_t_batch, outcome_data,
            #                  score_diff_base_t0, terminal, cut)

            # if add_pred_flag:
            #     sampled_data = Prediction_MemoryBuffer.sample(batch_size=config.Learn.batch_size)
            #     sample_pred_input_data = np.asarray([sampled_data[j][0] for j in range(len(sampled_data))])
            #     sample_pred_target_data = np.asarray([sampled_data[j][1] for j in range(len(sampled_data))])
            #     sample_pred_trace_lengths = np.asarray([sampled_data[j][2] for j in range(len(sampled_data))])
            #     sample_pred_selection_matrix = np.asarray([sampled_data[j][3] for j in range(len(sampled_data))])
            #
            #     train_prediction(model, sess, config,
            #                      sample_pred_input_data,
            #                      sample_pred_target_data,
            #                      sample_pred_trace_lengths,
            #                      sample_pred_selection_matrix)

        else:
            # for i in range(0, len(r_t_batch)):
            #     if i == len(r_t_batch) - 1:
            #         if terminal or cut:
            #             print(r_t_batch[i])
            if validate_cvrnn_flag:
                output_decoder = cvrnn_validation(sess, model, input_data_t0,
                                                  target_data_t0,
                                                  trace_lengths_t0,
                                                  selection_matrix_t0, config)

                if output_decoder_all is None:
                    output_decoder_all = output_decoder
                    target_data_all = target_data_t0
                    selection_matrix_all = selection_matrix_t0
                else:
                    # try:
                    output_decoder_all = np.concatenate(
                        [output_decoder_all, output_decoder], axis=0)
                    # except:
                    #     print output_decoder_all.shape
                    #     print  output_decoder.shape
                    target_data_all = np.concatenate(
                        [target_data_all, target_data_t0], axis=0)
                    selection_matrix_all = np.concatenate(
                        [selection_matrix_all, selection_matrix_t0], axis=0)

            if validate_td_flag:
                # validate_variance_flag = validate_variance_flag if train_number <= 500 else False
                q_values, match_q_values_players_dict = \
                    td_validation(sess, model, trace_lengths_t0, selection_matrix_t0,
                                  player_id_t0_batch, s_t0_batch, action_id_t0, input_data_t0,
                                  train_mask, config, match_q_values_players_dict,
                                  r_t_batch, terminal, cut, train_number, validate_variance_flag=False)

                if q_values_all is None:
                    q_values_all = q_values
                else:
                    q_values_all = np.concatenate([q_values_all, q_values],
                                                  axis=0)

            if validate_diff_flag:
                output_label, real_label = diff_validation(
                    sess, model, input_data_t0, trace_lengths_t0,
                    selection_matrix_t0, score_diff_base_t0, config,
                    outcome_data)
                if real_label_all is None:
                    real_label_all = real_label
                else:
                    real_label_all = np.concatenate(
                        [real_label_all, real_label], axis=0)

                if output_label_all is None:
                    output_label_all = output_label
                else:
                    output_label_all = np.concatenate(
                        [output_label_all, output_label], axis=0)

        s_t0 = s_tl
        if terminal:
            break

    if training_file is not None:
        training_file.write("Mean kl:{0}, Mean ll:{1}\n".format(
            np.mean(kl_loss_game), np.mean(ll_game)))
        training_file.flush()

    if validate_predict_flag and add_pred_flag:
        input_data, pred_output_prob = prediction_validation(
            model, sess, config, pred_input_data, pred_target_data,
            pred_trace_lengths, pred_selection_matrix)
        if pred_target_data_all is None:
            pred_target_data_all = pred_target_data
        else:
            pred_target_data_all = np.concatenate(
                [pred_target_data_all, pred_target_data], axis=0)

        if pred_output_prob_all is None:
            pred_output_prob_all = pred_output_prob
        else:
            pred_output_prob_all = np.concatenate(
                [pred_output_prob_all, pred_output_prob], axis=0)

    return [
        output_decoder_all, target_data_all, selection_matrix_all,
        q_values_all, real_label_all, output_label_all, pred_target_data_all,
        pred_output_prob_all, match_q_values_players_dict
    ]
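
# generate_selection_matrix is used throughout gathering_data_and_run but its
# implementation is not part of this file; the sketch below shows the behaviour
# it is assumed to have here: a [batch, max_trace_length] 0/1 mask selecting the
# valid (non-padded) steps of every sequence.
def generate_selection_matrix_sketch(trace_lengths, max_trace_length):
    selection_matrix = np.zeros((len(trace_lengths), max_trace_length),
                                dtype=np.float32)
    for row, length in enumerate(trace_lengths):
        selection_matrix[row, :int(length)] = 1.0  # mark real steps, leave padding at 0
    return selection_matrix

# e.g. generate_selection_matrix_sketch([2, 1], max_trace_length=3)
# -> [[1., 1., 0.],
#     [1., 0., 0.]]
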
def validation_model(testing_dir_games_all, data_store, config, sess, model,
                     predicted_target, embedding2validate):
    model_output_all = None
    target_data_all = None
    selection_matrix_all = None
    print('validating model')
    game_number = 0
    for dir_game in testing_dir_games_all:

        if dir_game == '.DS_Store':
            continue
        game_number += 1
        state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
            data_store=data_store, dir_game=dir_game, config=config)
        # state_trace_length = np.asarray([10] * len(state_trace_length))
        action_seq = transfer2seq(data=action,
                                  trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        team_id_seq = transfer2seq(data=team_id,
                                   trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        player_id_seq = transfer2seq(data=player_index,
                                     trace_length=state_trace_length,
                                     max_length=config.Learn.max_seq_length)
        # print ("\n training file" + str(dir_game))
        # reward_count = sum(reward)
        # print ("reward number" + str(reward_count))
        if len(state_input) != len(reward) or len(state_trace_length) != len(
                reward):
            raise Exception('state length does not equal reward length')

        train_len = len(state_input)
        train_number = 0
        s_t0 = state_input[train_number]
        train_number += 1

        while True:
            # try:
            batch_return, \
            train_number, \
            s_tl, \
            print_flag = get_together_training_batch(s_t0=s_t0,
                                                     state_input=state_input,
                                                     reward=reward,
                                                     player_index=player_index,
                                                     train_number=train_number,
                                                     train_len=train_len,
                                                     state_trace_length=state_trace_length,
                                                     action=action_seq,
                                                     team_id=team_id_seq,
                                                     config=config)

            # get the batch variables
            # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
            #                      team_id_t1, 0, 0
            s_t0_batch = [d[0] for d in batch_return]
            s_t1_batch = [d[1] for d in batch_return]
            r_t_batch = [d[2] for d in batch_return]
            trace_t0_batch = [d[3] for d in batch_return]
            trace_t1_batch = [d[4] for d in batch_return]
            action_id_t0_batch = [d[5] for d in batch_return]
            action_id_t1_batch = [d[6] for d in batch_return]
            team_id_t0_batch = [d[7] for d in batch_return]
            team_id_t1_batch = [d[8] for d in batch_return]
            player_id_t0_batch = [d[9] for d in batch_return]
            player_id_t1_batch = [d[10] for d in batch_return]
            r_t_seq_batch = transfer2seq(
                data=np.asarray(r_t_batch),
                trace_length=trace_t0_batch,
                max_length=config.Learn.max_seq_length)
            current_state, history_state = handle_de_history(
                data_seq_all=s_t0_batch, trace_lengths=trace_t0_batch)
            current_action, history_action = handle_de_history(
                data_seq_all=action_id_t0_batch, trace_lengths=trace_t0_batch)
            current_reward, history_reward = handle_de_history(
                data_seq_all=r_t_seq_batch, trace_lengths=trace_t0_batch)

            if embedding2validate:
                embed_data = []
                for player_id in player_id_t0_batch:
                    player_index_scalar = np.where(player_id == 1)[0][0]
                    embed_data.append(embedding2validate[player_index_scalar])
                embed_data = np.asarray(embed_data)
            else:
                embed_data = np.asarray(player_id_t0_batch)

            if predicted_target == 'action':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_state),
                     np.asarray(current_reward)],
                    axis=1)
                target_data = np.asarray(current_action)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            elif predicted_target == 'state':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_action),
                     np.asarray(current_reward)],
                    axis=1)
                target_data = np.asarray(current_state)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            elif predicted_target == 'reward':
                input_seq_data = np.concatenate([
                    np.asarray(history_state),
                    np.asarray(history_action),
                    np.asarray(history_reward)
                ],
                                                axis=2)
                input_obs_data = np.concatenate(
                    [np.asarray(current_state),
                     np.asarray(current_action)],
                    axis=1)
                target_data = np.asarray(current_reward)
                trace_lengths = [tl - 1 for tl in trace_t0_batch
                                 ]  # reduce 1 from trace length
            else:
                raise ValueError('undefined predicted target')

            # the last batch element carries the terminal flag
            terminal = batch_return[-1][-2]

            # only evaluate the read-out during validation; running train_op here
            # would keep updating the weights on the test games
            [output_prob] = sess.run(
                [model.read_out],
                feed_dict={
                    model.rnn_input_ph: input_seq_data,
                    model.feature_input_ph: input_obs_data,
                    model.y_ph: target_data,
                    model.embed_label_ph: embed_data,
                    model.trace_lengths_ph: trace_lengths
                })
            if model_output_all is None:
                model_output_all = output_prob
                target_data_all = target_data
            else:
                model_output_all = np.concatenate(
                    [model_output_all, output_prob], axis=0)
                target_data_all = np.concatenate(
                    [target_data_all, target_data], axis=0)
            s_t0 = s_tl
            if terminal:
                # save progress after a game
                # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                #                  global_step=game_number)
                # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                # game_diff_record_dict.update({dir_game: v_diff_record_average})
                break
    if predicted_target == 'action':
        acc = compute_acc(output_prob=model_output_all,
                          target_label=target_data_all,
                          if_print=True)
        print("testing acc is {0}".format(str(acc)))
    else:
        mae = compute_mae(output_actions_prob=model_output_all,
                          target_actions_prob=target_data_all,
                          if_print=True)
        print("mae is {0}".format(str(mae)))
def run_network(
    sess,
    model,
    config,
    training_dir_games_all,
    testing_dir_games_all,
    data_store,
    predicted_target,
    save_network_dir,
    validate_embedding_tag,
):
    if validate_embedding_tag is not None:
        validate_msg = save_network_dir.split('/')[-1].split('_')[0] + '_'
        save_embed_dir = save_network_dir.replace('de_embed_saved_networks', 'store_embedding'). \
            replace('de_model_saved_NN', 'de_model_save_embedding').replace(validate_msg, '')
        print('Applying embedding_matrix_game{0}.csv'.format(
            str(validate_embedding_tag)))
        with open(
                save_embed_dir + '/embedding_matrix_game{0}.csv'.format(
                    str(validate_embedding_tag)), 'r') as csv_file:

            csv_reader = csv.reader(csv_file, delimiter=',')
            embedding2validate = []
            for row in csv_reader:
                embedding2validate.append(row)
    else:
        embedding2validate = None

    game_number = 0
    converge_flag = False
    saver = tf.train.Saver(max_to_keep=300)

    while True:
        game_diff_record_dict = {}
        iteration_now = game_number // config.Learn.number_of_total_game + 1
        game_diff_record_dict.update({"Iteration": iteration_now})
        if converge_flag:
            break
        elif game_number >= len(
                training_dir_games_all) * config.Learn.iterate_num:
            break
        # else:
        #     converge_flag = True
        for dir_game in training_dir_games_all:
            if dir_game == '.DS_Store':
                continue
            print("\n training file" + str(dir_game))
            game_number += 1
            game_cost_record = []
            state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
                data_store=data_store, dir_game=dir_game, config=config)
            # state_trace_length = np.asarray([10] * len(state_trace_length))
            action_seq = transfer2seq(data=action,
                                      trace_length=state_trace_length,
                                      max_length=config.Learn.max_seq_length)
            team_id_seq = transfer2seq(data=team_id,
                                       trace_length=state_trace_length,
                                       max_length=config.Learn.max_seq_length)
            player_id_seq = transfer2seq(
                data=player_index,
                trace_length=state_trace_length,
                max_length=config.Learn.max_seq_length)
            # reward_count = sum(reward)
            # print ("reward number" + str(reward_count))
            if len(state_input) != len(reward) or len(
                    state_trace_length) != len(reward):
                raise Exception('state length does not equal reward length')

            train_len = len(state_input)
            train_number = 0
            s_t0 = state_input[train_number]
            train_number += 1

            while True:
                # try:
                batch_return, \
                train_number, \
                s_tl, \
                print_flag = get_together_training_batch(s_t0=s_t0,
                                                         state_input=state_input,
                                                         reward=reward,
                                                         player_index=player_index,
                                                         train_number=train_number,
                                                         train_len=train_len,
                                                         state_trace_length=state_trace_length,
                                                         action=action_seq,
                                                         team_id=team_id_seq,
                                                         config=config)

                # get the batch variables
                # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
                #                      team_id_t1, 0, 0
                s_t0_batch = [d[0] for d in batch_return]
                s_t1_batch = [d[1] for d in batch_return]
                r_t_batch = [d[2] for d in batch_return]
                trace_t0_batch = [d[3] for d in batch_return]
                trace_t1_batch = [d[4] for d in batch_return]
                action_id_t0_batch = [d[5] for d in batch_return]
                action_id_t1_batch = [d[6] for d in batch_return]
                team_id_t0_batch = [d[7] for d in batch_return]
                team_id_t1_batch = [d[8] for d in batch_return]
                player_id_t0_batch = [d[9] for d in batch_return]
                player_id_t1_batch = [d[10] for d in batch_return]

                r_t_seq_batch = transfer2seq(
                    data=np.asarray(r_t_batch),
                    trace_length=trace_t0_batch,
                    max_length=config.Learn.max_seq_length)
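                # handle_de_history separates each padded sequence into the current step
                # (current_*) and its preceding history (history_*)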

                current_state, history_state = handle_de_history(
                    data_seq_all=s_t0_batch, trace_lengths=trace_t0_batch)
                current_action, history_action = handle_de_history(
                    data_seq_all=action_id_t0_batch,
                    trace_lengths=trace_t0_batch)
                current_reward, history_reward = handle_de_history(
                    data_seq_all=r_t_seq_batch, trace_lengths=trace_t0_batch)
                if embedding2validate:
                    embed_data = []
                    for player_id in player_id_t0_batch:
                        player_index_scalar = np.where(player_id == 1)[0][0]
                        embed_data.append(
                            embedding2validate[player_index_scalar])
                    embed_data = np.asarray(embed_data)
                else:
                    embed_data = np.asarray(player_id_t0_batch)
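                # build the de-model inputs: the history sequences plus two of the three
                # current observations; the remaining one (action, state, or reward) becomes
                # the prediction target, and trace lengths drop by one for the held-out step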
                if predicted_target == 'action':
                    input_seq_data = np.concatenate([
                        np.asarray(history_state),
                        np.asarray(history_action),
                        np.asarray(history_reward)
                    ],
                                                    axis=2)
                    input_obs_data = np.concatenate([
                        np.asarray(current_state),
                        np.asarray(current_reward)
                    ],
                                                    axis=1)
                    target_data = np.asarray(current_action)
                    trace_lengths = [tl - 1 for tl in trace_t0_batch]  # reduce trace length by 1
                elif predicted_target == 'state':
                    input_seq_data = np.concatenate([
                        np.asarray(history_state),
                        np.asarray(history_action),
                        np.asarray(history_reward)
                    ],
                                                    axis=2)
                    input_obs_data = np.concatenate([
                        np.asarray(current_action),
                        np.asarray(current_reward)
                    ],
                                                    axis=1)
                    target_data = np.asarray(current_state)
                    trace_lengths = [tl - 1 for tl in trace_t0_batch]  # reduce trace length by 1
                elif predicted_target == 'reward':
                    input_seq_data = np.concatenate([
                        np.asarray(history_state),
                        np.asarray(history_action),
                        np.asarray(history_reward)
                    ],
                                                    axis=2)
                    input_obs_data = np.concatenate([
                        np.asarray(current_state),
                        np.asarray(current_action)
                    ],
                                                    axis=1)
                    target_data = np.asarray(current_reward)
                    trace_lengths = [tl - 1 for tl in trace_t0_batch]  # reduce trace length by 1
                else:
                    raise ValueError('undefined predicted target')

                for i in range(0, len(batch_return)):
                    terminal = batch_return[i][-2]
                    # cut = batch_return[i][8]

                # pretrain_flag = True if game_number <= 10 else False
                train_model(model, sess, config, input_seq_data,
                            input_obs_data, target_data, trace_lengths,
                            embed_data, terminal)
                s_t0 = s_tl
                if terminal:
                    # save progress after a game
                    # model.saver.save(sess, saved_network + '/' + config.learn.sport + '-game-',
                    #                  global_step=game_number)
                    # v_diff_record_average = sum(v_diff_record) / len(v_diff_record)
                    # game_diff_record_dict.update({dir_game: v_diff_record_average})
                    break
            if game_number % 100 == 1:
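                # every 100 games: dump the learned player embeddings (unless an external
                # embedding is being validated) and save a network checkpoint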
                if validate_embedding_tag is None:
                    collect_de_player_embeddings(
                        sess=sess,
                        model=model,
                        save_network_dir=save_network_dir,
                        game_number=game_number)
                if save_network_dir is not None:
                    print('saving game {0} in {1}'.format(
                        str(game_number), save_network_dir))
                    saver.save(sess,
                               save_network_dir + '/' + config.Learn.sport +
                               '-game-',
                               global_step=game_number)
            if game_number % 10 == 1:
                validation_model(testing_dir_games_all, data_store, config,
                                 sess, model, predicted_target,
                                 embedding2validate)
def gathering_running_and_run(dir_game,
                              config,
                              player_id_cluster_dir,
                              data_store,
                              model,
                              sess,
                              training_flag,
                              game_number,
                              validate_cvrnn_flag=False,
                              validate_td_flag=False,
                              validate_variance_flag=False,
                              output_decoder_all=None,
                              target_data_all=None,
                              selection_matrix_all=None,
                              q_values_all=None):
    if validate_variance_flag:
        match_q_values_players_dict = {}
        for i in range(config.Learn.player_cluster_number):
            match_q_values_players_dict.update({i: []})
    else:
        match_q_values_players_dict = None

    state_trace_length, state_input, reward, action, team_id, player_index = get_icehockey_game_data(
        data_store=data_store,
        dir_game=dir_game,
        config=config,
        player_id_cluster_dir=player_id_cluster_dir)
    action_seq = transfer2seq(data=action,
                              trace_length=state_trace_length,
                              max_length=config.Learn.max_seq_length)
    team_id_seq = transfer2seq(data=team_id,
                               trace_length=state_trace_length,
                               max_length=config.Learn.max_seq_length)
    player_index_seq = transfer2seq(data=player_index,
                                    trace_length=state_trace_length,
                                    max_length=config.Learn.max_seq_length)
    # reward_count = sum(reward)
    # print ("reward number" + str(reward_count))
    if len(state_input) != len(reward) or len(state_trace_length) != len(reward):
        raise Exception('state length does not equal reward length')

    initial_prior_mu_train = np.zeros(
        shape=[config.Learn.batch_size, config.Arch.CVRNN.latent_dim])
    previous_prior_mu_train = initial_prior_mu_train
    initial_prior_sigma_train = np.zeros(
        shape=[config.Learn.batch_size, config.Arch.CVRNN.latent_dim])
    previous_prior_sigma_train = initial_prior_sigma_train

    train_len = len(state_input)
    train_number = 0
    s_t0 = state_input[train_number]
    train_number += 1
    while True:
        # try:
        batch_return, \
        train_number, \
        s_t1, \
        print_flag = get_together_training_batch(s_t0=s_t0,
                                                 state_input=state_input,
                                                 reward=reward,
                                                 player_index=player_index_seq,
                                                 train_number=train_number,
                                                 train_len=train_len,
                                                 state_trace_length=state_trace_length,
                                                 action=action_seq,
                                                 team_id=team_id_seq,
                                                 config=config)

        # get the batch variables
        # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
        #                      team_id_t1, 0, 0
        s_t0_batch = [d[0] for d in batch_return]
        s_t1_batch = [d[1] for d in batch_return]
        r_t_batch = [d[2] for d in batch_return]
        trace_t0_batch = [d[3] for d in batch_return]
        trace_t1_batch = [d[4] for d in batch_return]
        action_id_t0 = [d[5] for d in batch_return]
        action_id_t1 = [d[6] for d in batch_return]
        team_id_t0_batch = [d[7] for d in batch_return]
        team_id_t1_batch = [d[8] for d in batch_return]
        player_id_t0_batch = [d[9] for d in batch_return]
        player_id_t1_batch = [d[10] for d in batch_return]
        terminal_batch = [d[-2] for d in batch_return]
        cut_batch = [d[-1] for d in batch_return]
        if training_flag:
            train_mask = np.asarray([[[1]] * config.Learn.max_seq_length] *
                                    len(s_t0_batch))
        else:
            train_mask = np.asarray([[[0]] * config.Learn.max_seq_length] *
                                    len(s_t0_batch))
        # (player_id, state ,action flag)
        for i in range(0, len(terminal_batch)):
            terminal = terminal_batch[i]
            cut = cut_batch[i]
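        # per-step prior mean/sigma sequences fed to the CVRNN: zero everywhere, with the
        # first time step explicitly seeded with the (zero) initial prior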

        previous_prior_mu_train_seq = np.zeros(shape=[
            len(batch_return), config.Learn.max_seq_length,
            config.Arch.CVRNN.latent_dim
        ])
        previous_prior_mu_train_seq[:, 0, :] = initial_prior_mu_train
        previous_prior_sigma_train_seq = np.zeros(shape=[
            len(batch_return), config.Learn.max_seq_length,
            config.Arch.CVRNN.latent_dim
        ])
        previous_prior_sigma_train_seq[:, 0, :] = initial_prior_sigma_train

        if config.Learn.predict_target == 'PlayerLocalId':
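            # conditioning input for the CVRNN: player one-hot, team one-hot, state, action,
            # the prior parameter sequences and the train mask, concatenated on the feature axis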
            input_data_t0 = np.concatenate([
                np.asarray(player_id_t0_batch),
                np.asarray(team_id_t0_batch),
                np.asarray(s_t0_batch),
                np.asarray(action_id_t0), previous_prior_mu_train_seq,
                previous_prior_sigma_train_seq, train_mask
            ],
                                           axis=2)
            target_data_t0 = np.asarray(player_id_t0_batch)
            trace_lengths_t0 = trace_t0_batch
            selection_matrix_t0 = generate_selection_matrix(
                trace_lengths_t0, max_trace_length=config.Learn.max_seq_length)

            input_data_t1 = np.concatenate([
                np.asarray(player_id_t1_batch),
                np.asarray(team_id_t1_batch),
                np.asarray(s_t1_batch),
                np.asarray(action_id_t1), previous_prior_mu_train_seq,
                previous_prior_sigma_train_seq, train_mask
            ],
                                           axis=2)
            target_data_t1 = np.asarray(player_id_t1_batch)
            trace_lengths_t1 = trace_t1_batch
            selection_matrix_t1 = generate_selection_matrix(
                trace_t1_batch, max_trace_length=config.Learn.max_seq_length)
        else:
            input_data_t0 = np.concatenate([
                np.asarray(player_id_t0_batch),
                np.asarray(s_t0_batch),
                np.asarray(action_id_t0), previous_prior_mu_train_seq,
                previous_prior_sigma_train_seq, train_mask
            ],
                                           axis=2)
            target_data_t0 = np.asarray(player_id_t0_batch)
            trace_lengths_t0 = trace_t0_batch
            selection_matrix_t0 = generate_selection_matrix(
                trace_lengths_t0, max_trace_length=config.Learn.max_seq_length)

            input_data_t1 = np.concatenate([
                np.asarray(player_id_t1_batch),
                np.asarray(s_t1_batch),
                np.asarray(action_id_t1), previous_prior_mu_train_seq,
                previous_prior_sigma_train_seq, train_mask
            ],
                                           axis=2)
            target_data_t1 = np.asarray(player_id_t1_batch)
            trace_lengths_t1 = trace_t1_batch
            selection_matrix_t1 = generate_selection_matrix(
                trace_t1_batch, max_trace_length=config.Learn.max_seq_length)

        if training_flag:

            if config.Learn.apply_stochastic:
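                # experience replay: push each transition into MemoryBuffer and train the
                # CVRNN on a randomly sampled minibatch instead of the in-order batch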
                for i in range(len(input_data_t0)):
                    MemoryBuffer.push([
                        input_data_t0[i], target_data_t0[i],
                        trace_lengths_t0[i], selection_matrix_t0[i],
                        input_data_t1[i], target_data_t1[i],
                        trace_lengths_t1[i], selection_matrix_t1[i],
                        r_t_batch[i], terminal_batch[i], cut_batch[i]
                    ])
                sampled_data = MemoryBuffer.sample(
                    batch_size=config.Learn.batch_size)
                sample_input_data_t0 = np.asarray(
                    [sampled_data[j][0] for j in range(len(sampled_data))])
                sample_target_data_t0 = np.asarray(
                    [sampled_data[j][1] for j in range(len(sampled_data))])
                sample_trace_lengths_t0 = np.asarray(
                    [sampled_data[j][2] for j in range(len(sampled_data))])
                sample_selection_matrix_t0 = np.asarray(
                    [sampled_data[j][3] for j in range(len(sampled_data))])
                # sample_input_data_t1 = np.asarray([sampled_data[j][4] for j in range(len(sampled_data))])
                # sample_target_data_t1 = np.asarray([sampled_data[j][5] for j in range(len(sampled_data))])
                # sample_trace_lengths_t1 = np.asarray([sampled_data[j][6] for j in range(len(sampled_data))])
                # sample_selection_matrix_t1 = np.asarray([sampled_data[j][7] for j in range(len(sampled_data))])
                # sample_r_t_batch = np.asarray([sampled_data[j][8] for j in range(len(sampled_data))])
                # sample_terminal_batch = np.asarray([sampled_data[j][9] for j in range(len(sampled_data))])
                # sample_cut_batch = np.asarray([sampled_data[j][10] for j in range(len(sampled_data))])
                pretrain_flag = False

                for i in range(0, len(terminal_batch)):
                    batch_terminal = terminal_batch[i]
                    batch_cut = cut_batch[i]
            else:
                sample_input_data_t0 = input_data_t0
                sample_target_data_t0 = target_data_t0
                sample_trace_lengths_t0 = trace_lengths_t0
                sample_selection_matrix_t0 = selection_matrix_t0
                # keep pretrain_flag defined when replay sampling is disabled
                pretrain_flag = False

            train_cvrnn_model(model, sess, config, sample_input_data_t0,
                              sample_target_data_t0, sample_trace_lengths_t0,
                              sample_selection_matrix_t0, pretrain_flag)

            # train_td_model(model, sess, config, input_data_t0, trace_lengths_t0, selection_matrix_t0,
            #                input_data_t1, trace_lengths_t1, selection_matrix_t1, r_t_batch, terminal, cut)

        else:
            if validate_cvrnn_flag:
                output_decoder = cvrnn_validation(sess, model, input_data_t0,
                                                  target_data_t0,
                                                  trace_lengths_t0,
                                                  selection_matrix_t0, config)

                if output_decoder_all is None:
                    output_decoder_all = output_decoder
                    target_data_all = target_data_t0
                    selection_matrix_all = selection_matrix_t0
                else:
                    # try:
                    output_decoder_all = np.concatenate(
                        [output_decoder_all, output_decoder], axis=0)
                    # except:
                    #     print output_decoder_all.shape
                    #     print  output_decoder.shape
                    target_data_all = np.concatenate(
                        [target_data_all, target_data_t0], axis=0)
                    selection_matrix_all = np.concatenate(
                        [selection_matrix_all, selection_matrix_t0], axis=0)

            # if validate_td_flag:
            #     # validate_variance_flag = validate_variance_flag if train_number <= 500 else False
            #     q_values, match_q_values_players_dict = \
            #         td_validation(sess, model, trace_lengths_t0, selection_matrix_t0,
            #                       player_id_t0_batch, s_t0_batch, action_id_t0, input_data_t0,
            #                       train_mask, config, match_q_values_players_dict,
            #                       r_t_batch, terminal, cut, train_number, validate_variance_flag=False)
            #
            #     if q_values_all is None:
            #         q_values_all = q_values
            #     else:
            #         q_values_all = np.concatenate([q_values_all, q_values], axis=0)

        s_t0 = s_t1
        if terminal:
            break

    return [
        output_decoder_all, target_data_all, selection_matrix_all,
        q_values_all, match_q_values_players_dict
    ]
def gathering_running_and_run(dir_game, config, player_id_cluster_dir, data_store, source_data_dir,
                              model, sess, training_flag, game_number,
                              validate_cvrnn_flag=False,
                              validate_td_flag=False,
                              validate_diff_flag=False,
                              validate_variance_flag=False,
                              validate_predict_flag=False,
                              output_decoder_all=None,
                              target_data_all=None,
                              q_values_all=None,
                              output_label_all=None,
                              real_label_all=None,
                              pred_target_data_all=None,
                              pred_output_prob_all=None
                              ):
    if validate_variance_flag:
        match_q_values_players_dict = {}
        for i in range(config.Learn.player_cluster_number):
            match_q_values_players_dict.update({i: []})
    else:
        match_q_values_players_dict = None

    state_trace_length, state_input, reward, actions, team_id, player_index = get_icehockey_game_data(
        data_store=data_store, dir_game=dir_game, config=config, player_id_cluster_dir=player_id_cluster_dir)

    if config.Learn.apply_lstm:
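        # with the LSTM encoder the full padded sequences are kept; otherwise only the
        # last valid observation of each trace is used and every trace length collapses to 1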
        # raise ValueError('no! do not use LSTM!')
        encoder_trace = state_trace_length
        encoder_state_input = state_input
        player_index_seq = transfer2seq(data=player_index, trace_length=state_trace_length,
                                        max_length=config.Learn.max_seq_length)
        encoder_player_index = player_index_seq
        action_seq = transfer2seq(data=actions, trace_length=state_trace_length,
                                  max_length=config.Learn.max_seq_length)
        encoder_action = action_seq
        team_id_seq = transfer2seq(data=team_id, trace_length=state_trace_length,
                                   max_length=config.Learn.max_seq_length)
        encoder_team_id = team_id_seq
        axis_concat = 2

    else:
        state_zero_trace = [1] * len(state_trace_length)
        state_zero_input = []
        for trace_index in range(0, len(state_trace_length)):
            trace_length = state_trace_length[trace_index]
            trace_length = trace_length - 1
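            # clamp to the last padded index (hard-coded 9, i.e. max_seq_length - 1 when max_seq_length is 10)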
            if trace_length > 9:
                trace_length = 9
            state_zero_input.append(state_input[trace_index, trace_length, :])
        state_zero_input = np.asarray(state_zero_input)
        encoder_trace = state_zero_trace
        encoder_state_input = state_zero_input
        encoder_player_index = player_index
        encoder_action = actions
        encoder_team_id = team_id
        axis_concat = 1

    if config.Arch.Predict.predict_target == 'ActionGoal':
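        # keep only shot events and label each one by whether the next event is a goal
        # ([1, 0]) or not ([0, 1]); these labels become the prediction targets for this game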
        add_pred_flag = True
        actions_all = read_feature_within_events(directory=dir_game,
                                                 data_path=source_data_dir,
                                                 feature_name='name')
        next_goal_label = []
        data_length = state_trace_length.shape[0]
        new_reward = []
        new_action = []
        new_state_input = []
        new_state_trace = []
        new_team_id = []
        new_player_index = []
        for action_index in range(0, data_length):
            action = actions_all[action_index]
            if 'shot' in action:
                if action_index + 1 == data_length:
                    continue
                # both encoder types keep the same per-event features here
                new_reward.append(reward[action_index])
                new_action.append(encoder_action[action_index])
                new_state_input.append(encoder_state_input[action_index])
                new_state_trace.append(encoder_trace[action_index])
                new_team_id.append(encoder_team_id[action_index])
                new_player_index.append(encoder_player_index[action_index])
                if actions_all[action_index + 1] == 'goal':
                    # print(actions_all[action_index+1])
                    next_goal_label.append([1, 0])
                else:
                    # print(actions_all[action_index + 1])
                    next_goal_label.append([0, 1])
        pred_target = next_goal_label
    # elif config.Arch.Predict.predict_target == 'Action':
    #     add_pred_flag = True
    #     pred_target = actions[1:, :]
    #     new_reward = reward[:-1]
    #     new_action = actions[:-1, :, :]
    #     new_state_zero_input = state_zero_input[:-1, :, :]
    #     new_state_zero_trace = state_zero_trace[:-1]
    #     new_team_id = team_id[:-1, :, :]
    #     new_player_index = player_index[:-1, :, :]
    else:
        # raise ValueError()
        add_pred_flag = False

    if add_pred_flag:
        # if training_flag:
        #     pred_train_mask = np.asarray([1] * len(new_state_zero_input))
        # else:
        #     pred_train_mask = np.asarray([1] * len(new_state_zero_input))  # we should apply encoder output
        if config.Learn.predict_target == 'PlayerLocalId':
            pred_input_data = np.concatenate([np.asarray(new_player_index),
                                              np.asarray(new_team_id),
                                              np.asarray(new_state_input),
                                              np.asarray(new_action), ], axis=axis_concat)
            pred_target_data = np.asarray(pred_target)
            pred_trace_lengths = new_state_trace
        else:
            pred_input_data = np.concatenate([np.asarray(new_player_index),
                                              np.asarray(new_state_input),
                                              np.asarray(new_action)], axis=axis_concat)
            pred_target_data = np.asarray(pred_target)
            pred_trace_lengths = new_state_trace
        if training_flag:
            for i in range(len(new_state_input)):
                cache_label = np.argmax(pred_target_data[i], axis=0)
                Prediction_MemoryBuffer.push([pred_input_data[i],
                                              pred_target_data[i],
                                              pred_trace_lengths[i],
                                              # pred_train_mask[i]
                                              ], cache_label)
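    # game-level score-difference signals: one accumulated from the reward sequence, the
    # other read from the raw event feature 'scoreDifferential'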

    score_diff = compute_game_score_diff_vec(rewards=reward)
    score_difference_game = read_feature_within_events(dir_game,
                                                       source_data_dir,
                                                       'scoreDifferential',
                                                       transfer_home_number=True,
                                                       data_store=data_store)

    # reward_count = sum(reward)
    # print ("reward number" + str(reward_count))
    if len(state_input) != len(reward) or len(state_trace_length) != len(reward):
        raise Exception('state length does not equal reward length')

    train_len = len(state_input)
    train_number = 0
    s_t0 = encoder_state_input[train_number]
    train_number += 1
    while True:
        # try:
        batch_return, \
        train_number, \
        s_tl, \
        print_flag = get_together_training_batch(s_t0=s_t0,
                                                 state_input=encoder_state_input,
                                                 reward=reward,
                                                 player_index=encoder_player_index,
                                                 train_number=train_number,
                                                 train_len=train_len,
                                                 state_trace_length=encoder_trace,
                                                 action=encoder_action,
                                                 team_id=encoder_team_id,
                                                 win_info=score_diff,
                                                 score_info=score_difference_game,
                                                 config=config)

        # get the batch variables
        # s_t0, s_t1, r_t0_combine, s_length_t0, s_length_t1, action_id_t0, action_id_t1, team_id_t0,
        #                      team_id_t1, 0, 0
        s_t0_batch = [d[0] for d in batch_return]
        s_t1_batch = [d[1] for d in batch_return]
        r_t_batch = [d[2] for d in batch_return]
        trace_t0_batch = [d[3] for d in batch_return]
        trace_t1_batch = [d[4] for d in batch_return]
        action_id_t0 = [d[5] for d in batch_return]
        action_id_t1 = [d[6] for d in batch_return]
        team_id_t0_batch = [d[7] for d in batch_return]
        team_id_t1_batch = [d[8] for d in batch_return]
        player_id_t0_batch = [d[9] for d in batch_return]
        player_id_t1_batch = [d[10] for d in batch_return]
        win_id_t_batch = [d[11] for d in batch_return]
        terminal_batch = [d[-2] for d in batch_return]
        cut_batch = [d[-1] for d in batch_return]
        score_diff_t_batch = [d[11] for d in batch_return]
        score_diff_base_t0_batch = [d[12] for d in batch_return]
        outcome_data = score_diff_t_batch
        score_diff_base_t0 = score_diff_base_t0_batch
        # (player_id, state ,action flag)

        if training_flag:
            train_mask = np.asarray([1] * len(s_t0_batch))
        else:
            train_mask = np.asarray([0] * len(s_t0_batch))

        for i in range(0, len(terminal_batch)):
            terminal = terminal_batch[i]
            cut = cut_batch[i]
        if config.Learn.apply_lstm:
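            # with the LSTM encoder, the target player id is the one-hot at the last valid
            # step of each padded trace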
            target_data_t0 = []
            for trace_index in range(0, len(trace_t0_batch)):
                trace_length = trace_t0_batch[trace_index]
                trace_length = trace_length - 1
                target_data_t0.append(player_id_t0_batch[trace_index][trace_length, :])
            target_data_t0 = np.asarray(target_data_t0)

            target_data_t1 = []
            for trace_index in range(0, len(trace_t1_batch)):
                trace_length = trace_t1_batch[trace_index]
                trace_length = trace_length - 1
                target_data_t1.append(player_id_t1_batch[trace_index][trace_length, :])
            target_data_t1 = np.asarray(target_data_t1)

        else:
            target_data_t0 = np.asarray(player_id_t0_batch)
            target_data_t1 = np.asarray(player_id_t1_batch)
        if config.Learn.predict_target == 'PlayerLocalId':
            input_data_t0 = np.concatenate([np.asarray(player_id_t0_batch),
                                            np.asarray(team_id_t0_batch),
                                            np.asarray(s_t0_batch),
                                            np.asarray(action_id_t0)], axis=axis_concat)
            # target_data_t0 = np.asarray(np.asarray(player_id_t0_batch))
            input_data_t1 = np.concatenate([np.asarray(player_id_t1_batch),
                                            np.asarray(team_id_t1_batch),
                                            np.asarray(s_t1_batch),
                                            np.asarray(action_id_t1)], axis=axis_concat)
            # target_data_t1 = np.asarray(np.asarray(player_id_t1_batch))
            trace_length_t0 = np.asarray(trace_t0_batch)
            trace_length_t1 = np.asarray(trace_t1_batch)
        else:
            input_data_t0 = np.concatenate([np.asarray(player_id_t0_batch), np.asarray(s_t0_batch),
                                            np.asarray(action_id_t0)], axis=axis_concat)
            # target_data_t0 = np.asarray(np.asarray(player_id_t0_batch))
            input_data_t1 = np.concatenate([np.asarray(player_id_t1_batch), np.asarray(s_t1_batch),
                                            np.asarray(action_id_t1)], axis=axis_concat)
            # target_data_t1 = np.asarray(np.asarray(player_id_t1_batch))
            trace_length_t0 = np.asarray(trace_t0_batch)
            trace_length_t1 = np.asarray(trace_t1_batch)

        if training_flag:
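            # training path: update the encoder on (optionally replay-sampled) data, then the
            # TD value head and the score-difference head on the in-order batch; the prediction
            # head trains from its own replay buffer when add_pred_flag is set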

            if config.Learn.apply_stochastic:
                for i in range(len(input_data_t0)):
                    MemoryBuffer.push([input_data_t0[i], target_data_t0[i], trace_length_t0[i],
                                       input_data_t1[i], target_data_t1[i], trace_length_t1[i],
                                       r_t_batch[i], win_id_t_batch[i],
                                       terminal_batch[i], cut_batch[i]
                                       ])
                sampled_data = MemoryBuffer.sample(batch_size=config.Learn.batch_size)
                sample_input_data_t0 = np.asarray([sampled_data[j][0] for j in range(len(sampled_data))])
                sample_target_data_t0 = np.asarray([sampled_data[j][1] for j in range(len(sampled_data))])
                sample_trace_length_t0 = np.asarray([sampled_data[j][2] for j in range(len(sampled_data))])
                if training_flag:
                    train_mask = np.asarray([1] * len(sample_input_data_t0))
                else:
                    train_mask = np.asarray([0] * len(sample_input_data_t0))

                for i in range(0, len(terminal_batch)):
                    batch_terminal = terminal_batch[i]
                    batch_cut = cut_batch[i]
            else:
                sample_input_data_t0 = input_data_t0
                sample_target_data_t0 = target_data_t0
                sample_trace_length_t0 = trace_length_t0

            likelihood_loss = train_encoder_model(model, sess, config, sample_input_data_t0,
                                                  sample_target_data_t0, sample_trace_length_t0,
                                                  train_mask)
            if terminal or cut:
                print('\nlikelihood loss is {0}'.format(str(likelihood_loss)))

            """we skip sampling for TD learning"""
            train_td_model(model, sess, config,
                           input_data_t0, target_data_t0, trace_length_t0,
                           input_data_t1, target_data_t1, trace_length_t1,
                           r_t_batch, train_mask, terminal, cut)

            train_score_diff(model, sess, config,
                             input_data_t0, target_data_t0, trace_length_t0,
                             input_data_t1, target_data_t1, trace_length_t1,
                             r_t_batch, score_diff_base_t0, outcome_data,
                             train_mask, terminal, cut)

            if add_pred_flag:
                sampled_data = Prediction_MemoryBuffer.sample(batch_size=config.Learn.batch_size)
                sample_pred_input_data = np.asarray([sampled_data[j][0] for j in range(len(sampled_data))])
                sample_pred_target_data = np.asarray([sampled_data[j][1] for j in range(len(sampled_data))])
                sample_pred_trace_lengths = np.asarray([sampled_data[j][2] for j in range(len(sampled_data))])
                # sample_pred_train_mask = np.asarray([sampled_data[j][2] for j in range(len(sampled_data))])

                train_prediction(model, sess, config,
                                 sample_pred_input_data,
                                 sample_pred_target_data,
                                 sample_pred_trace_lengths)

        else:
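            # validation path: accumulate decoder outputs, Q values and score-difference
            # predictions across batches so the caller can evaluate the whole game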
            # for i in range(0, len(r_t_batch)):
            #     if i == len(r_t_batch) - 1:
            #         if terminal or cut:
            #             print(r_t_batch[i])
            if validate_cvrnn_flag:
                output_decoder = encoder_validation(sess, model, input_data_t0, target_data_t0,
                                                    trace_length_t0, train_mask, config)

                if output_decoder_all is None:
                    output_decoder_all = output_decoder
                    target_data_all = target_data_t0
                else:
                    # try:
                    output_decoder_all = np.concatenate([output_decoder_all, output_decoder], axis=0)
                    target_data_all = np.concatenate([target_data_all, target_data_t0], axis=0)

            if validate_td_flag:
                # validate_variance_flag = validate_variance_flag if train_number <= 500 else False
                q_values = td_validation(sess, model, input_data_t0, target_data_t0, trace_length_t0, train_mask,
                                         config)

                if q_values_all is None:
                    q_values_all = q_values
                else:
                    q_values_all = np.concatenate([q_values_all, q_values], axis=0)

            if validate_diff_flag:
                output_label, real_label = diff_validation(sess, model, input_data_t0, target_data_t0,
                                                           trace_length_t0, train_mask,
                                                           score_diff_base_t0,
                                                           config, outcome_data)
                if real_label_all is None:
                    real_label_all = real_label
                else:
                    real_label_all = np.concatenate([real_label_all, real_label], axis=0)

                if output_label_all is None:
                    output_label_all = output_label
                else:
                    output_label_all = np.concatenate([output_label_all, output_label], axis=0)

        s_t0 = s_tl
        if terminal:
            break

    if validate_predict_flag and add_pred_flag:
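        # evaluate the ActionGoal prediction head on all shot events collected for this game
        # and accumulate the results across games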
        pred_target_data, pred_output_prob = validate_prediction(model, sess, config,
                                                                 pred_input_data,
                                                                 pred_target_data,
                                                                 pred_trace_lengths)
        if pred_target_data_all is None:
            pred_target_data_all = pred_target_data
        else:
            pred_target_data_all = np.concatenate([pred_target_data_all, pred_target_data], axis=0)

        if pred_output_prob_all is None:
            pred_output_prob_all = pred_output_prob
        else:
            pred_output_prob_all = np.concatenate([pred_output_prob_all, pred_output_prob], axis=0)
    return [output_decoder_all, target_data_all,
            q_values_all, real_label_all, output_label_all, pred_target_data_all, pred_output_prob_all]