def test_rollout_adjustment(key):
    """
    Train the agent on a state abstraction with fatal errors. Then generate a roll-out, detach the last state of that
    roll-out (presumably the first state that is part of a cycle), and restart learning.

    Two identical agents are built from the same corrupted abstraction. The first explores for 10,000 steps with no
    adjustment; its roll-out is printed after 5,000 and after 10,000 steps. The second explores for 5,000 steps, has
    the last state of its roll-out detached, and explores for another 5,000 steps. Its roll-out is printed before and
    after the detachment so the two runs can be compared by eye.
    :param key: sequence of (abstraction type, abstraction epsilon, corruption type, corruption proportion, batch
        number) selecting one row of the corrupted-abstractions file
    :raises IndexError: if no row of the file matches key
    """
    def show_rollout(title, states):
        # Print the title, then the states on one comma-separated line, and terminate that line
        print(title)
        for state in states:
            print(state, end=', ')
        print()

    # Load a poorly-performing abstraction
    names = ['AbstrType', 'AbstrEps', 'CorrType', 'CorrProp', 'Batch', 'Dict']
    df = pd.read_csv('../abstr_exp/corrupted/corrupted_abstractions.csv', names=names)
    abstr_string = df.loc[(df['AbstrType'] == str(key[0]))
                        & (df['AbstrEps'] == key[1])
                        & (df['CorrType'] == str(key[2]))
                        & (df['CorrProp'] == key[3])
                        & (df['Batch'] == key[4])]['Dict'].values[0]
    # Each element is indexed as el[0] = (x, y) ground state and el[1] = the abstract state it maps to
    abstr_list = ast.literal_eval(abstr_string)
    abstr_dict = {}
    for el in abstr_list:
        # Cell (11, 11) is the one marked terminal
        is_term = el[0][0] == 11 and el[0][1] == 11
        state = GridWorldState(el[0][0], el[0][1], is_terminal=is_term)
        abstr_dict[state] = el[1]

    # Create an agent with this abstraction
    s_a = StateAbstraction(abstr_dict, abstr_type=Abstr_type.PI_STAR)
    mdp = GridWorldMDP()
    agent = AbstractionAgent(mdp, s_a=s_a)

    # Untrained copy, used below for the run with the mid-training adjustment
    agent2 = copy.deepcopy(agent)

    # Baseline: generate roll-outs from the unadjusted agent after 5,000 and after 10,000 steps
    for _ in range(5000):
        agent.explore()
    show_rollout('Roll-out for model with no adjustment, 5,000 steps', agent.generate_rollout())
    for _ in range(5000):
        agent.explore()
    show_rollout('Roll-out for model with no adjustment, 10,000 steps', agent.generate_rollout())
    # Blank line separating the baseline run from the adjusted run
    print()

    # Train an agent for 5000 steps, detach the last state of its roll-out, and train for another 5000 steps
    #  The hope is that this will get further than the 10000 step one
    for _ in range(5000):
        agent2.explore()
    rollout = agent2.generate_rollout()
    show_rollout('Roll-out for model pre-adjustment, 5,000 steps', rollout)
    print('Detaching state', rollout[-1])
    agent2.detach_state(rollout[-1])
    for _ in range(5000):
        agent2.explore()
    show_rollout('Roll-out for model post-adjustment, 10,000 steps', agent2.generate_rollout())
def iterate_detachment(mdp_key, batch_size=5000):
    """
    Load an incorrect abstraction. Train the model, generate a roll-out, detach the last state of the roll-out
    (presumably the first cycle state). Repeat until the roll-out achieves a terminal state. Save the adjusted
    abstraction and learned policy. Visualize the adjusted abstraction via the corruption visualization.
    :param mdp_key: key for incorrect (poorly performing) abstraction: a sequence of (abstraction type, abstraction
        epsilon, corruption type, corruption proportion, batch number) selecting one row of the
        corrupted-abstractions file
    :param batch_size: Number of steps to train between state detachments
    :raises IndexError: if no row of the file matches mdp_key
    """
    # Load a poorly-performing abstraction
    names = ['AbstrType', 'AbstrEps', 'CorrType', 'CorrProp', 'Batch', 'Dict']
    df = pd.read_csv('../abstr_exp/corrupted/corrupted_abstractions.csv', names=names)
    abstr_string = df.loc[(df['AbstrType'] == str(mdp_key[0]))
                        & (df['AbstrEps'] == mdp_key[1])
                        & (df['CorrType'] == str(mdp_key[2]))
                        & (df['CorrProp'] == mdp_key[3])
                        & (df['Batch'] == mdp_key[4])]['Dict'].values[0]
    # Each element is indexed as el[0] = (x, y) ground state and el[1] = the abstract state it maps to
    abstr_list = ast.literal_eval(abstr_string)
    abstr_dict = {}
    for el in abstr_list:
        # Cell (11, 11) is the one marked terminal
        is_term = el[0][0] == 11 and el[0][1] == 11
        state = GridWorldState(el[0][0], el[0][1], is_terminal=is_term)
        abstr_dict[state] = el[1]

    # Create an agent with this abstraction
    s_a = StateAbstraction(abstr_dict, abstr_type=Abstr_type.PI_STAR)
    mdp = GridWorldMDP()
    agent = AbstractionAgent(mdp, s_a=s_a)

    # Generate a roll-out from untrained model (should be random and short)
    rollout = agent.generate_rollout()
    print('Roll-out from untrained model')
    for state in rollout:
        print(state, end=', ')
    print()

    # Until roll-out leads to terminal state, explore and detach last state of roll-out. Record each of the detached
    #  states so they can be visualized later
    detached_states = []
    step_counter = 0
    while not rollout[-1].is_terminal():
        for _ in range(batch_size):
            agent.explore()
        step_counter += batch_size
        rollout = agent.generate_rollout()
        print('Roll-out after', step_counter, 'steps')
        for state in rollout:
            print(state, end=', ')
        print()

        # Show the Q-values of the state about to be detached, before and after, to see the effect of detaching
        last_state = rollout[-1]
        print('State Q-value pre-detach:')
        for action in agent.mdp.actions:
            print(last_state, action, agent.get_q_value(last_state, action))
        # As handled below: flag 0 means the state was detached, flag 1 means it was already a singleton
        detach_flag = agent.detach_state(last_state)
        if detach_flag == 0:
            print('Detaching state', last_state)
            detached_states.append(last_state)
        elif detach_flag == 1:
            print(last_state, 'already a singleton state. No change.')
        print('State Q-value post-detach:')
        for action in agent.mdp.actions:
            print(last_state, action, agent.get_q_value(last_state, action))
        print()
    # NOTE(review): this unpacks every item yielded by get_q_table() into two names. If get_q_table() returns a
    #  plain dict, this iterates over its keys only rather than (key, value) pairs -- confirm against the agent class
    for key, value in agent.get_q_table():
        print(key, value)

    # Save resulting adapted state abstraction and learned policy. The context managers guarantee the files are
    #  closed even if building or writing a row raises
    with open('../abstr_exp/adapted/adapted_abstraction.csv', 'w', newline='') as s_a_file:
        s_a_writer = csv.writer(s_a_file)
        print(mdp_key)
        s_a_writer.writerow((mdp_key[0], mdp_key[1], mdp_key[2], mdp_key[3], mdp_key[4],
                             agent.get_abstraction_as_string()))

    with open('../abstr_exp/adapted/learned_policy.csv', 'w', newline='') as policy_file:
        policy_writer = csv.writer(policy_file)
        policy_writer.writerow((mdp_key[0], mdp_key[1], mdp_key[2], mdp_key[3], mdp_key[4],
                                agent.get_learned_policy_as_string()))

    # Visualize the adapted state abstraction and learned policy, along with the original for comparison
    viz = GridWorldVisualizer()
    surface = viz.create_corruption_visualization(mdp_key,
                                                  '../abstr_exp/adapted/adapted_abstraction.csv',
                                                  error_file='../abstr_exp/corrupted/error_states.csv')
    # TODO: draw small white circles over the states that were detached; for now they are only printed
    for state in detached_states:
        print(state, end=', ')
    viz.display_surface(surface)