Example #1
    def __init__(self, rank):
        docker_client = docker.from_env()
        agent_port, partner_port = 10000 + rank, 20000 + rank
        clients = [('127.0.0.1', agent_port), ('127.0.0.1', partner_port)]
        self.agent_type = GlobalVar()

        # Assume Minecraft launched if port has listener, launch otherwise
        if not _port_has_listener(agent_port):
            self._launch_malmo(docker_client, agent_port)
        print('Malmo running on port ' + str(agent_port))
        if not _port_has_listener(partner_port):
            self._launch_malmo(docker_client, partner_port)
        print('Malmo running on port ' + str(partner_port))

        # Set up partner agent env in separate process
        p = mp.Process(target=self._run_partner, args=(clients, ))
        p.daemon = True
        p.start()
        time.sleep(3)

        # Set up agent env
        self.env = PigChaseEnvironment(clients,
                                       PigChaseTopDownStateBuilder(gray=False),
                                       role=1,
                                       randomize_positions=True)
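Example #1 relies on a _port_has_listener helper that is not shown on this page. The snippet below is a minimal sketch, assuming the helper only needs to test whether something is already listening on the given localhost port; the implementation is inferred, not taken from the original project.

import socket


def _port_has_listener(port):
    # Try a TCP connection to localhost:port; if it succeeds, a listener
    # (for example an already-running Malmo instance) is bound to the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1)
        return sock.connect_ex(('127.0.0.1', port)) == 0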
Example #2
    def _run_partner(self, clients):
        env = PigChaseEnvironment(clients,
                                  PigChaseSymbolicStateBuilder(),
                                  role=0,
                                  randomize_positions=True)
        agent = PigChaseChallengeAgent(ENV_AGENT_NAMES[0])
        self.agent_type.set(
            type(agent.current_agent) == RandomAgent
            and PigChaseEnvironment.AGENT_TYPE_1
            or PigChaseEnvironment.AGENT_TYPE_2)
        obs = env.reset(self.agent_type)
        reward = 0
        agent_done = False
        while True:
            # Select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # Reset if needed
            if env.done:
                self.agent_type.set(
                    type(agent.current_agent) == RandomAgent
                    and PigChaseEnvironment.AGENT_TYPE_1
                    or PigChaseEnvironment.AGENT_TYPE_2)
                obs = env.reset(self.agent_type)
            # Take a step
            obs, reward, agent_done = env.do(action)
Example #3
def agent_factory(name, role, kind, clients, max_episodes, max_actions, logdir,
                  quit):
    assert len(clients) >= 2, \
        'There are not enough Malmo clients in the pool (need at least 2)'

    clients = parse_clients_args(clients)
    visualizer = ConsoleVisualizer(prefix='Agent %d' % role)

    if role == 0:
        env = PigChaseEnvironment(clients,
                                  PigChaseSymbolicStateBuilder(),
                                  actions=ENV_ACTIONS,
                                  role=role,
                                  human_speed=True,
                                  randomize_positions=True)
        agent = PigChaseChallengeAgent(name)

        if type(agent.current_agent) == RandomAgent:
            agent_type = PigChaseEnvironment.AGENT_TYPE_1
        else:
            agent_type = PigChaseEnvironment.AGENT_TYPE_2
        obs = env.reset(agent_type)
        reward = 0
        rewards = []
        done = False
        episode = 0

        while True:

            # select an action
            action = agent.act(obs, reward, done, True)

            if done:
                visualizer << (episode + 1, 'Reward', sum(rewards))
                rewards = []
                episode += 1

                if type(agent.current_agent) == RandomAgent:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_1
                else:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_2
                obs = env.reset(agent_type)

            # take a step
            obs, reward, done = env.do(action)
            rewards.append(reward)

    else:
        env = PigChaseEnvironment(clients,
                                  PigChaseSymbolicStateBuilder(),
                                  actions=list(ARROW_KEYS_MAPPING.values()),
                                  role=role,
                                  randomize_positions=True)
        env.reset(PigChaseEnvironment.AGENT_TYPE_3)

        agent = PigChaseHumanAgent(name, env, list(ARROW_KEYS_MAPPING.keys()),
                                   max_episodes, max_actions, visualizer, quit)
        agent.show()
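Most examples on this page call parse_clients_args before building the environment, but its body is never shown. Below is a minimal sketch, assuming the client pool is passed either as 'host:port' strings or as ready-made (host, port) pairs; the real helper may differ.

def parse_clients_args(clients):
    # Normalise every client entry to a (host, port) tuple with an int port.
    parsed = []
    for client in clients:
        if isinstance(client, str):
            host, port = client.split(':')
        else:
            host, port = client
        parsed.append((host, int(port)))
    return parsed

Under that assumption, parse_clients_args(['127.0.0.1:10000', '127.0.0.1:10001']) would return [('127.0.0.1', 10000), ('127.0.0.1', 10001)].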
Example #4
    def run(self):
        from multiprocessing import Process

        env = PigChaseEnvironment(self._clients,
                                  self._state_builder,
                                  actions=DanishPuppet.ACTIONS.all_commands(),
                                  role=1,
                                  randomize_positions=True)
        print('==================================')
        print('Starting evaluation of Agent @100k')

        p = Process(target=run_challenge_agent, args=(self._clients, ))
        p.start()
        sleep(5)
        agent_loop(self._agent_100k, env, self._accumulators['100k'])
        p.terminate()

        print('==================================')
        print('Starting evaluation of Agent @500k')

        p = Process(target=run_challenge_agent, args=(self._clients, ))
        p.start()
        sleep(5)
        agent_loop(self._agent_500k, env, self._accumulators['500k'])
        p.terminate()
Example #5
def run_challenge_agent(clients):
    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              role=0,
                              randomize_positions=True)
    agent = PigChaseChallengeAgent(ENV_AGENT_NAMES[0])
    agent_loop(agent, env, None)
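Examples #4, #5 and #7 all delegate to an agent_loop helper whose body is not shown. The sketch below follows the same act / step / reset pattern used elsewhere on this page, with a hypothetical accumulator list collecting per-episode reward totals; Example #7 passes an extra flag that this sketch does not model.

def agent_loop(agent, env, accumulator):
    # Hypothetical sketch of the act / step / reset loop used by the examples.
    obs = env.reset()
    reward = 0
    agent_done = False
    episode_rewards = []
    while True:
        # Select an action
        action = agent.act(obs, reward, agent_done, is_training=True)
        # Reset if needed, recording the finished episode's total reward
        if env.done:
            if accumulator is not None:
                accumulator.append(sum(episode_rewards))
            episode_rewards = []
            obs = env.reset()
        # Take a step
        obs, reward, agent_done = env.do(action)
        episode_rewards.append(reward)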
Example #6
def agent_factory(name, role, type, clients, max_epochs, logdir, visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              role=role,
                              randomize_positions=True)

    if role == 0:
        agent = PigChaseChallengeAgent(name)

        obs = env.reset()
        reward = 0
        agent_done = False

        while True:
            if env.done:
                obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)

    else:

        if type == 'astar':
            agent = FocusedAgent(name, ENV_TARGET_NAMES[0])
        else:
            agent = RandomAgent(name, env.available_actions)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in range(1, max_training_steps + 1):

            # check if env needs reset
            if env.done:

                visualize_training(visualizer, step, viz_rewards)
                viz_rewards = []
                obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            agent.inject_summaries(step)
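visualize_training is likewise assumed rather than shown. One plausible sketch, reusing the add_entry(step, tag, value) interface that Example #14 demonstrates on the visualizer; the tag names here are made up for illustration.

def visualize_training(visualizer, step, rewards, tag='Training'):
    # Log simple per-episode statistics through the visualizer, if any.
    if visualizer is None or len(rewards) == 0:
        return
    visualizer.add_entry(step, '%s/reward per episode' % tag, sum(rewards))
    visualizer.add_entry(step, '%s/episode length' % tag, len(rewards))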
Example #7
    def run(self):
        from multiprocessing import Process

        env = PigChaseEnvironment(self._clients, self._state_builder,
                                  role=1, randomize_positions=True)
        print('==================================')
        print('Starting evaluation of Agent provided')
        # Launch the challenge agent in a separate process
        p = Process(target=run_challenge_agent, args=(self._clients, True))
        p.start()
        sleep(5)
        agent_loop(self._agent_100k, env, self._accumulators['100k'], True)
        # Save the neural network weights once the agent has finished its training
        self._agent_100k.save('pesosBuenos.h5')
        p.terminate()
Example #8
def agent_factory(name, role, type, clients, max_epochs, logdir, visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              role=role,
                              randomize_positions=True)

    if role == 0:
        agent = FocusedAgent(name, ENV_TARGET_NAMES[0])

        obs = env.reset()
        reward = 0
        agent_done = False

        while True:
            if env.done:
                obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)

    else:
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            agent = BayesAgent(name, ENV_TARGET_NAMES[0], 'Agent_1', False,
                               sess)
            if not agent.save:
                sess.run(tf.global_variables_initializer())
                print "Initialize"
            obs = env.reset()
            agent.reset(obs)
            reward = 0
            agent_done = False
            viz_rewards = []
            avg = []
            max_training_steps = EPOCH_SIZE * max_epochs
            epoch = 0
            for step in range(1, max_training_steps + 1):
                # check if env needs reset
                if env.done:
                    avg.append(sum(viz_rewards))
                    print "Epoch:%d, accumulative rewards: %d" % (
                        epoch, sum(viz_rewards))
                    visualize_training(visualizer, step, viz_rewards)
                    viz_rewards = []
                    obs = env.reset()
                    agent.reset(obs)
                    epoch += 1

                # select an action
                action = agent.act(obs, reward, agent_done, is_training=False)
                obs, reward, agent_done = env.do(action)
                viz_rewards.append(reward)
                #
                agent.inject_summaries(step)
            print "Average Reward: ", 1. * sum(avg) / len(avg)
Example #9
def agent_factory(name, role, clients, backend, device, max_epochs, logdir,
                  visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    if role == 0:
        builder = PigChaseSymbolicStateBuilder()
        env = PigChaseEnvironment(clients,
                                  builder,
                                  role=role,
                                  randomize_positions=True)

        agent = PigChaseChallengeAgent(name)
        if type(agent.current_agent) == RandomAgent:
            agent_type = PigChaseEnvironment.AGENT_TYPE_1
        else:
            agent_type = PigChaseEnvironment.AGENT_TYPE_2

        obs = env.reset(agent_type)
        reward = 0
        agent_done = False

        while True:
            if env.done:
                if type(agent.current_agent) == RandomAgent:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_1
                else:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_2

                obs = env.reset(agent_type)
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset(agent_type)

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
    else:
        env = PigChaseEnvironment(clients,
                                  MalmoALEStateBuilder(),
                                  role=role,
                                  randomize_positions=True)
        memory = TemporalMemory(100000, (84, 84))

        if backend == 'cntk':
            from malmopy.model.cntk import QNeuralNetwork
            model = QNeuralNetwork((memory.history_length, 84, 84),
                                   env.available_actions, device)
        else:
            from malmopy.model.chainer import QNeuralNetwork, DQNChain
            chain = DQNChain((memory.history_length, 84, 84),
                             env.available_actions)
            target_chain = DQNChain((memory.history_length, 84, 84),
                                    env.available_actions)
            model = QNeuralNetwork(chain, target_chain, device)

        explorer = LinearEpsilonGreedyExplorer(1, 0.1, 1000000)
        agent = PigChaseQLearnerAgent(name,
                                      env.available_actions,
                                      model,
                                      memory,
                                      0.99,
                                      32,
                                      50000,
                                      explorer=explorer,
                                      visualizer=visualizer)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps + 1):
            # check if env needs reset
            if env.done:
                visualize_training(visualizer, step, viz_rewards)
                agent.inject_summaries(step)
                viz_rewards = []

                obs = env.reset()
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            if (step % EPOCH_SIZE) == 0:
                if 'model' in locals():
                    model.save('pig_chase-dqn_%d.model' % (step / EPOCH_SIZE))
Example #10
class Env():
    def __init__(self, rank):
        docker_client = docker.from_env()
        agent_port, partner_port = 10000 + rank, 20000 + rank
        clients = [('127.0.0.1', agent_port), ('127.0.0.1', partner_port)]
        self.agent_type = GlobalVar()

        # Assume Minecraft launched if port has listener, launch otherwise
        if not _port_has_listener(agent_port):
            self._launch_malmo(docker_client, agent_port)
        print('Malmo running on port ' + str(agent_port))
        if not _port_has_listener(partner_port):
            self._launch_malmo(docker_client, partner_port)
        print('Malmo running on port ' + str(partner_port))

        # Set up partner agent env in separate process
        p = mp.Process(target=self._run_partner, args=(clients, ))
        p.daemon = True
        p.start()
        time.sleep(3)

        # Set up agent env
        self.env = PigChaseEnvironment(clients,
                                       PigChaseTopDownStateBuilder(gray=False),
                                       role=1,
                                       randomize_positions=True)

    def get_class_label(self):
        return self.agent_type.value() - 1

    def reset(self):
        observation = self.env.reset()
        while observation is None:  # May happen if episode ended with first action of other agent
            observation = self.env.reset()
        return _map_to_observation(observation)

    def step(self, action):
        observation, reward, done = self.env.do(action)
        # Do not return any extra info
        return _map_to_observation(observation), reward, done, None

    def close(self):
        return  # TODO: Kill processes + Docker containers

    def _launch_malmo(self, client, port):
        # Launch Docker container
        client.containers.run('malmo',
                              '-port ' + str(port),
                              detach=True,
                              network_mode='host')
        # Check for port to come up
        launched = False
        for _ in range(100):
            time.sleep(3)
            if _port_has_listener(port):
                launched = True
                break
        # Quit if Malmo could not be launched
        if not launched:
            exit(1)

    # Runs partner in separate env
    def _run_partner(self, clients):
        env = PigChaseEnvironment(clients,
                                  PigChaseSymbolicStateBuilder(),
                                  role=0,
                                  randomize_positions=True)
        agent = PigChaseChallengeAgent(ENV_AGENT_NAMES[0])
        self.agent_type.set(
            type(agent.current_agent) == RandomAgent
            and PigChaseEnvironment.AGENT_TYPE_1
            or PigChaseEnvironment.AGENT_TYPE_2)
        obs = env.reset(self.agent_type)
        reward = 0
        agent_done = False
        while True:
            # Select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # Reset if needed
            if env.done:
                self.agent_type.set(
                    type(agent.current_agent) == RandomAgent
                    and PigChaseEnvironment.AGENT_TYPE_1
                    or PigChaseEnvironment.AGENT_TYPE_2)
                obs = env.reset(self.agent_type)
            # Take a step
            obs, reward, agent_done = env.do(action)
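The _map_to_observation helper used by reset and step above is also not shown. A purely hypothetical sketch, assuming the top-down RGB frame built by PigChaseTopDownStateBuilder(gray=False) should reach the learner as a channels-first float array in [0, 1]:

import numpy as np


def _map_to_observation(observation):
    # Hypothetical: convert an (H, W, 3) uint8 frame to a float32 CHW array.
    observation = np.asarray(observation, dtype=np.float32)
    return np.transpose(observation, (2, 0, 1)) / 255.0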
Example #11
def agent_factory(name, role, type, clients, max_epochs, logdir, visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)
    
    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients, builder, role=role,
                              randomize_positions=True)

    if role == 0:
        agent1 = FocusedAgent(name, ENV_TARGET_NAMES[0])
        agent2 = RandomAgent(name, env.available_actions)
        agent3 = BadAgent(name)
        
        agent_list = [agent1, agent2, agent3]  # three types of agents
        agent = agent1
        
        obs = env.reset()
        reward = 0
        agent_done = False
        max_training_steps = EPOCH_SIZE * max_epochs
        epoch = 0
        for step in range(1, max_training_steps+1):
            if env.done:
                obs = env.reset()
                epoch += 1
                agent = agent_list[epoch/10 % 3]  # switch agent every 10 episodes
            # select an action
            action = agent.act(obs, reward, agent_done)
            # take a step
            obs, reward, agent_done = env.do(action)
                


    else:
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config = config) as sess:
            agent1 = BayesAgent(name, ENV_TARGET_NAMES[0], 'Agent_1', True, sess)
            agent2 = RandomAgent(name, env.available_actions)
            agent3 = BadAgent(name)
            agent4 = FocusedAgent(name, ENV_TARGET_NAMES[0])
            if not agent1.save:
                sess.run(tf.global_variables_initializer()) 
                print "Initialize"
            agent_list = [agent1, agent2, agent3, agent4]  # four types of agents
            agent = agent1
            obs = env.reset()
            agent1.reset(obs)
            reward = 0
            agent_done = False
            viz_rewards = []
            avg = []
            epoch = 0
            s = 1
            max_training_steps = EPOCH_SIZE * max_epochs
            for step in range(1, max_training_steps+1):
                # check if env needs reset
                if agent_done:
                    obs = env.reset()
                    agent1.reset(obs)
                    avg.append(sum(viz_rewards))
                    print "Epoch:%d, accumulative rewards: %d"%(epoch, sum(viz_rewards))
                    visualize_training(visualizer, step, viz_rewards)
                    viz_rewards = []
                    epoch += 1
                    agent = agent_list[epoch/5 % 4]  # switch agent every 5 episodes
                    if epoch%10 == 0:
                        agent1.reset_collaborator()
                    s = 1
                    
                # select an action
                action = agent.act(obs, reward, agent_done, is_training = True)
                # take a step
                next_obs, reward, agent_done = env.do(action)
                agent1.collecting(obs, action, reward, next_obs, agent_done, s)
                s += 1
                obs = next_obs
                viz_rewards.append(reward)
                    
                if step % 100 == 0:
                    agent1.save_replay_buffer()
                #
                agent1.inject_summaries(step)
                
            print "Average Reward: ", 1.*sum(avg)/len(avg)
Example #12
def agent_factory(name,
                  role,
                  clients,
                  max_epochs,
                  logdir,
                  visualizer,
                  manual=False):
    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              actions=DanishPuppet.ACTIONS.all_commands(),
                              role=role,
                              human_speed=HUMAN_SPEED,
                              randomize_positions=True)

    # Default agent (challenger)
    c_agent = ChallengerFactory(name,
                                focused=True,
                                random=True,
                                bad_guy=False,
                                standstill=False)

    # Challenger  (Agent_1)
    if role == 0:

        agent_type = ChallengerFactory.get_agent_type(c_agent.current_agent)
        state = env.reset(agent_type)
        print("Agent Factory: Assigning {}.".format(
            type(c_agent.current_agent).__name__))

        reward = 0
        agent_done = False

        while True:

            # select an action
            action = c_agent.act(state, reward, agent_done, is_training=True)

            # reset if needed
            if env.done:
                agent_type = ChallengerFactory.get_agent_type(
                    c_agent.current_agent)
                _ = env.reset(agent_type)
                print("Agent Factory: Assigning {}.".format(
                    type(c_agent.current_agent).__name__))

            # take a step
            state, reward, agent_done = env.do(action)

    # Our Agent (Agent_2)
    else:
        c_agent = DanishPuppet(name=name,
                               helmets=c_agent.get_helmets(),
                               wait_for_pig=WAIT_FOR_PIG,
                               use_markov=USE_MARKOV)

        # Manual override!
        if manual:
            c_agent.manual = True

        state = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in range(1, max_training_steps + 1):

            # check if env needs reset

            if env.done:
                try:
                    c_agent.note_game_end(reward_sequence=viz_rewards,
                                          state=state[0])
                except TypeError:
                    c_agent.note_game_end(reward_sequence=viz_rewards,
                                          state=None)
                print("")
                visualize_training(visualizer, step, viz_rewards)
                viz_rewards = []
                state = env.reset()

            # select an action
            action = None
            frame = None if not PASS_FRAME else env.frame
            while action is None:
                # for key, item in env.world_observations.items():
                #     print(key, ":", item)

                total_time = None
                if env is not None and env.world_observations is not None:
                    total_time = env.world_observations["TotalTime"]

                action = c_agent.act(state,
                                     reward,
                                     done=agent_done,
                                     total_time=total_time,
                                     is_training=True,
                                     frame=frame)

                # 'wait'
                if action == DanishPuppet.ACTIONS.wait:
                    action = None
                    sleep(4e-3)
                    state = env.state

            # take a step
            state, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            c_agent.inject_summaries(step)
Example #13
def agent_factory(name, role, baseline_agent, clients, max_epochs, logdir,
                  visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)
    batch_size = 32

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              role=role,
                              randomize_positions=True)

    if role == 0:
        agent = PigChaseChallengeAgent(name)

        if type(agent.current_agent) == RandomAgent:
            agent_type = PigChaseEnvironment.AGENT_TYPE_1
        else:
            agent_type = PigChaseEnvironment.AGENT_TYPE_2
        # Here the state has to be adapted to what the neural network needs
        state = env.reset(agent_type)

        reward = 0
        agent_done = False
        num_actions = 0
        while True:

            # reset if needed
            if env.done:
                print(agent.check_memory(batch_size))
                if type(agent.current_agent) == RandomAgent:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_1
                else:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_2
                # Here the state would have to be adapted again

                if num_actions > batch_size:
                    print('Entering replay 1')
                    agent.replay(batch_size)
                state = env.reset(agent_type)

            # select an action
            # print('Action for role 1')
            action = agent.act(state, reward, agent_done, is_training=True)
            next_state, reward, agent_done = env.do(action)
            num_actions = num_actions + 1
            next_state2 = adapt_state(next_state)
            agent.remember(state, action, reward, next_state2, agent_done)
            # Here state = obs (i.e. the previous, unadapted state)
            state = next_state
        # Not sure whether this belongs here because of the while True (unclear when it ends); it should run when a game finishes
        # Check whether replay actually runs; if it never does, move it inside the if env.done block (that means one episode has ended and the next begins, so it should be fine there)

    else:

        if baseline_agent == 'astar':
            agent = FocusedAgent(name, ENV_TARGET_NAMES[0])
        else:
            agent = RandomAgent(name, env.available_actions)

        state = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps + 1):

            # check if env needs reset
            if env.done:

                visualize_training(visualizer, step, viz_rewards)
                viz_rewards = []
                # Not sure whether this also has to be done here or not; check
                if agent.check_memory(batch_size) > batch_size:
                    print('Entering replay 2')
                    agent.replay(batch_size)
                state = env.reset()

            # select an action
            # print('Action for role 2')
            action = agent.act(state, reward, agent_done, is_training=True)
            # take a step
            next_state, reward, agent_done = env.do(action)
            next_state2 = adapt_state(next_state)
            agent.remember(state, action, reward, next_state2, agent_done)
            # Here state = obs (i.e. the previous, unadapted state)
            state = next_state
            #obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            agent.inject_summaries(step)
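adapt_state is not shown in Example #13 either; the comments only say the state must be reshaped into whatever the network expects. A purely hypothetical sketch, assuming the observation is already numeric and the network takes a flat (1, N) input:

import numpy as np


def adapt_state(state):
    # Hypothetical: flatten the observation into a single (1, N) float32 row,
    # the shape a small dense network would typically expect.
    return np.asarray(state, dtype=np.float32).reshape(1, -1)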
Example #14
def agent_factory(name, role, baseline_agent, clients, max_epochs, logdir,
                  visualizer):

    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients,
                              builder,
                              role=role,
                              randomize_positions=True)

    if role == 0:
        agent = PigChaseChallengeAgent(name)
        obs = env.reset(get_agent_type(agent))

        reward = 0
        agent_done = False

        while True:
            if env.done:
                while True:
                    obs = env.reset(get_agent_type(agent))
                    if obs:
                        break

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)

            # reset if needed
            if env.done:
                obs = env.reset(get_agent_type(agent))

            # take a step
            obs, reward, agent_done = env.do(action)

    else:

        if baseline_agent == 'tabq':
            agent = TabularQLearnerAgent(name, visualizer)
        elif baseline_agent == 'astar':
            agent = FocusedAgent(name, ENV_TARGET_NAMES[0])
        else:
            agent = RandomAgent(name, env.available_actions)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps + 1):

            # check if env needs reset
            if env.done:
                while True:
                    if len(viz_rewards) == 0:
                        viz_rewards.append(0)
                    visualize_training(visualizer, step, viz_rewards)
                    tag = "Episode End Conditions"
                    visualizer.add_entry(
                        step, '%s/timeouts per episode' % tag,
                        env.end_result == "command_quota_reached")
                    visualizer.add_entry(
                        step, '%s/agent_1 defaults per episode' % tag,
                        env.end_result == "Agent_1_defaulted")
                    visualizer.add_entry(
                        step, '%s/agent_2 defaults per episode' % tag,
                        env.end_result == "Agent_2_defaulted")
                    visualizer.add_entry(step,
                                         '%s/pig caught per episode' % tag,
                                         env.end_result == "caught_the_pig")
                    agent.inject_summaries(step)
                    viz_rewards = []
                    obs = env.reset()
                    if obs:
                        break

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)
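Example #14 relies on a get_agent_type helper that is not shown, but its intent matches the explicit branches in Examples #3 and #9. A minimal sketch along those lines:

def get_agent_type(agent):
    # Mirror Examples #3 and #9: report AGENT_TYPE_1 while the challenge
    # agent is currently the random baseline, AGENT_TYPE_2 otherwise.
    if isinstance(agent.current_agent, RandomAgent):
        return PigChaseEnvironment.AGENT_TYPE_1
    return PigChaseEnvironment.AGENT_TYPE_2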