Example #1
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, optimizer,
                          scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T  # not T * B: just count the parallel steps

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # end batch_and_learn
Example #2
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        steps = flags.total_steps
        if 'train_steps' in flags:
            steps = flags.train_steps

        while step < steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)

                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
def act(flags, actor_index: int, free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue, model: torch.nn.Module, buffers: Buffers,
        initial_agent_state_buffers, level_name):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        # changed: create_env now also takes level_name and seed
        gym_env = create_env(flags, level_name, seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
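
Every act variant above writes into buffers[key][index][t, ...], so the rollout storage is a per-key list of shared-memory tensors with one slot per rollout and T + 1 time steps. A sketch of the Buffers type and a create_buffers helper in the style of torchbeast's monobeast.py; the exact spec keys and shapes depend on the environment wrapper and should be treated as assumptions.

import typing

import torch

Buffers = typing.Dict[str, typing.List[torch.Tensor]]


def create_buffers(flags, obs_shape, num_actions) -> Buffers:
    T = flags.unroll_length
    specs = dict(
        frame=dict(size=(T + 1, *obs_shape), dtype=torch.uint8),
        reward=dict(size=(T + 1,), dtype=torch.float32),
        done=dict(size=(T + 1,), dtype=torch.bool),
        episode_return=dict(size=(T + 1,), dtype=torch.float32),
        episode_step=dict(size=(T + 1,), dtype=torch.int32),
        policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32),
        baseline=dict(size=(T + 1,), dtype=torch.float32),
        last_action=dict(size=(T + 1,), dtype=torch.int64),
        action=dict(size=(T + 1,), dtype=torch.int64),
    )
    buffers: Buffers = {key: [] for key in specs}
    # Shared-memory tensors so actor processes can write rollouts and learner
    # threads can read them without serializing the data through queues.
    for _ in range(flags.num_buffers):
        for key in buffers:
            buffers[key].append(torch.empty(**specs[key]).share_memory_())
    return buffers
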
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        # step in particular needs to be from the outside scope, since all learner threads can update
        # it and all learners should stop once the total number of steps/frames has been processed
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            learn(flags,
                  model,
                  learner_model,
                  batch,
                  agent_state,
                  optimizer,
                  scheduler,
                  stats,
                  envs=environments)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update(
                    {k: stats[k]
                     for k in stat_keys if "_step" not in k})
                for e in stats["env_step"]:
                    to_log["{}_step".format(e)] = stats["env_step"][e]
                plogger.log(to_log)
                step += T * B  # so this counts the number of frames, not e.g. trajectories/rollouts

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
Example #5
def act(
    flags,
    game_params,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        sc_env = init_game(game_params['env'],
                           flags.map_name,
                           random_seed=seed)
        obs_processer = IMPALA_ObsProcesser(action_table=model.action_table,
                                            **game_params['obs_processer'])
        env = environment.Environment(sc_env, obs_processer, seed)
        # initial rollout starts here
        env_output = env.initial()
        with torch.no_grad():
            agent_output = model.actor_step(env_output)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                if key not in ['sc_env_action']:  # no need to save this key on buffers
                    buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                env_output = env.step(agent_output["sc_env_action"])

                timings.time("step")

                with torch.no_grad():
                    agent_output = model.actor_step(env_output)

                timings.time("model")

                #env_output = env.step(agent_output["sc_env_action"])

                #timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    if key not in ['sc_env_action']:  # no need to save this key on buffers
                        buffers[key][index][t + 1, ...] = agent_output[key]
                # env_output will be like
                # s_{0}, ..., s_{T}
                # act_mask_{0}, ..., act_mask_{T}
                # discount_{0}, ..., discount_{T}
                # r_{-1}, ..., r_{T-1}
                # agent_output will be like
                # a_0, ..., a_T with a_t ~ pi(.|s_t)
                # log_pi(a_0|s_0), ..., log_pi(a_T|s_T)
                # so the learner can use (s_i, act_mask_i) to predict log_pi_i
                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
def act(
    flags,
    env: str,
    task: int,
    full_action_space: bool,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        # create the environment from command line parameters
        # => could also create a special one which operates on a list of games (which we need)
        gym_env = create_env(
            env,
            frame_height=flags.frame_height,
            frame_width=flags.frame_width,
            gray_scale=(flags.aaa_input_format == "gray_stack"),
            full_action_space=full_action_space,
            task=task)

        # generate a seed for the environment (NO HUMAN STARTS HERE!), could just
        # use this for all games wrapped by the environment for our application
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        # wrap the environment, this is actually probably the point where we could
        # use multiple games, because the other environment is still one from Gym
        env = environment.Environment(gym_env)

        # get the initial frame, reward, done, return, step, last_action
        env_output = env.initial()

        # perform the first step
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            # get a buffer index from the queue for free buffers (?)
            index = free_queue.get()
            # termination signal (?) for breaking out of this loop
            if index is None:
                break

            # Write old rollout end.
            # the keys here are (frame, reward, done, episode_return, episode_step, last_action)
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            # here the keys are (policy_logits, baseline, action)
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # I think the agent_state is just the RNN/LSTM state (which will be the "initial" state for the next step)
            # not sure why it's needed though because it really just seems to be the initial state before starting to
            # act; however, it might be randomly initialised, which is why we might want it...
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                # forward pass without keeping track of gradients to get the agent action
                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                # agent acting in the environment
                env_output = env.step(agent_output["action"])

                timings.time("step")

                # writing the respective outputs of the current step (see above for the list of keys)
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")

            # after finishing a trajectory put the index in the "full queue",
            # presumably so that the data can be processed/sent to the learner
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
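
On the learner side, learn consumes the batch produced by get_batch, runs the learner model over the whole unroll, and applies a V-trace actor-critic update before copying the new weights back into the actor model. A condensed sketch consistent with the Example #2 call site and with torchbeast's monobeast.py; vtrace.from_logits and the compute_*_loss helpers are torchbeast utilities and are assumed to be importable here.

import threading

import torch
from torch import nn


def learn(flags, actor_model, model, batch, initial_agent_state,
          optimizer, scheduler, lock=threading.Lock()):
    """Performs one optimization step on a batch of rollouts."""
    with lock:
        learner_outputs, unused_state = model(batch, initial_agent_state)

        # The last baseline value is the bootstrap target for V-trace.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Shift so that obs[t] lines up with the action taken at time t.
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=(~batch["done"]).float() * flags.discounting,
            rewards=batch["reward"],
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"], batch["action"],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"])
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"])
        total_loss = pg_loss + baseline_loss + entropy_loss

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        # Actors keep acting with these (slightly stale) weights until the next sync.
        actor_model.load_state_dict(model.state_dict())

        episode_returns = batch["episode_return"][batch["done"]]
        return {
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }
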
Example #7
def act(
    flags,
    game_params,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        sc_env = init_game(game_params['env'], flags.map_name, random_seed=seed)
        obs_processer = IMPALA_ObsProcesser_v2(env=sc_env, action_table=model.action_table, **game_params['obs_processer'])
        env = environment.Environment_v2(sc_env, obs_processer, seed)
        # initial rollout starts here
        env_output = env.initial() 
        new_res = model.spatial_processing_block.new_res
        agent_state = model.spatial_processing_block.conv_lstm._init_hidden(
            batch_size=1, image_size=(new_res, new_res)
        )
        
        with torch.no_grad():
            agent_output, new_agent_state = model.actor_step(env_output, *agent_state[0]) 

        agent_state = agent_state[0]  # _init_hidden yields [(h, c)], whereas actor_step takes only (h, c)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end. 
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                if key not in ['sc_env_action']: # no need to save this key on buffers
                    buffers[key][index][0, ...] = agent_output[key]
            
            # LSTM state kept in sync with the environment / input to the agent;
            # that's why agent_state = new_agent_state is executed afterwards.
            initial_agent_state_buffers[index][0][...] = agent_state[0]
            initial_agent_state_buffers[index][1][...] = agent_state[1]
            
            
            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                env_output = env.step(agent_output["sc_env_action"])
                
                timings.time("step")
                
                # update state
                agent_state = new_agent_state 
            
                with torch.no_grad():
                    agent_output, new_agent_state = model.actor_step(env_output, *agent_state)
                
                timings.time("model")
                
                #env_output = env.step(agent_output["sc_env_action"])

                #timings.time("step")

                for key in env_output:
                    buffers[key][index][t+1, ...] = env_output[key] 
                for key in agent_output:
                    if key not in ['sc_env_action']: # no need to save this key on buffers
                        buffers[key][index][t+1, ...] = agent_output[key] 
                # env_output will be like
                # s_{0}, ..., s_{T}
                # act_mask_{0}, ..., act_mask_{T}
                # discount_{0}, ..., discount_{T}
                # r_{-1}, ..., r_{T-1}
                # agent_output will be like
                # a_0, ..., a_T with a_t ~ pi(.|s_t)
                # log_pi(a_0|s_0), ..., log_pi(a_T|s_T)
                # so the learner can use (s_i, act_mask_i) to predict log_pi_i
                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
Example #8
def act(flags, gym_env, actor_index: int, free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue, buffers: Buffers, actor_buffers: Buffers,
        actor_model_queues: List[mp.SimpleQueue],
        actor_env_queues: List[mp.SimpleQueue]):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        #seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        #gym_env.seed(seed)
        if flags.agent in ["CNN"]:
            env = environment.Environment(gym_env, "image")
        elif flags.agent in ["NLM", "KBMLP", "GCN"]:
            if flags.state in ["relative", "integer", "block"]:
                env = environment.Environment(gym_env, "VKB")
            elif flags.state == "absolute":
                env = environment.Environment(gym_env, "absVKB")
        env_output = env.initial()
        for key in env_output:
            actor_buffers[key][actor_index][0] = env_output[key]
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in actor_buffers:
                buffers[key][index][0] = actor_buffers[key][actor_index][0]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                actor_model_queues[actor_index].put(actor_index)
                env_info = actor_env_queues[actor_index].get()
                if env_info == "exit":
                    return

                timings.time("model")

                env_output = env.step(actor_buffers["action"][actor_index][0])

                timings.time("step")

                for key in actor_buffers:
                    buffers[key][index][t + 1] = actor_buffers[key][actor_index][0]
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in env_output:
                    actor_buffers[key][actor_index][0] = env_output[key]

                timings.time("write")

            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
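
Example #8 is different from the others in that the actor never runs the model itself: it announces its index on actor_model_queues[actor_index], expects some central process to write the chosen action into actor_buffers["action"], and resumes once anything other than "exit" arrives on actor_env_queues[actor_index]. The serving side is not shown in the source; the following is a purely hypothetical sketch of what such a central inference loop could look like, with the batching scheme, the model's output format, and every name below assumed rather than taken from the original code.

import torch


def inference_loop(flags, model, actor_buffers, actor_model_queues, actor_env_queues):
    # Hypothetical server: wait until every actor reports that its latest
    # observation sits in actor_buffers, run one batched forward pass, write
    # each actor's action back, then release the actors to step their envs.
    while True:
        ready = [q.get() for q in actor_model_queues]
        obs = {
            key: torch.stack([actor_buffers[key][i][0] for i in ready])
            for key in actor_buffers if key != "action"
        }
        with torch.no_grad():
            agent_output = model(obs)  # assumed to return a dict with an "action" entry
        for slot, i in enumerate(ready):
            actor_buffers["action"][i][0] = agent_output["action"][slot]
            actor_env_queues[i].put("act")
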