Example #1
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats, cumulative_steps
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, agent_state,
                          optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B
                cumulative_steps[0] = step

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
Example #2
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, flags):
    try:
        logging.info('Actor %i started.', i)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = Net.create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_output = model(env_output)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output = model(env_output)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time('write')
            full_queue.put(index)

        if i == 0:
            logging.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
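
The actor claims a free rollout slot index from free_queue, writes environment and agent outputs into the shared buffers, and hands the index to the learner through full_queue. A minimal, self-contained sketch of that handoff, assuming shared-memory tensors and torch.multiprocessing (the fill_slot body, buffer layout, and sizes are placeholders, not the code above):

import torch
import torch.multiprocessing as mp

def fill_slot(i, free_queue, full_queue, buffers, unroll_length):
    # Stand-in actor: claim a free rollout slot, fill it, hand it to the learner.
    while True:
        index = free_queue.get()
        if index is None:  # Poison pill: shut down.
            break
        for t in range(unroll_length + 1):
            buffers["obs"][index][t] = float(i)  # Placeholder for env/model outputs.
        full_queue.put(index)

if __name__ == "__main__":
    T, num_buffers, num_actors = 8, 4, 2
    # share_memory_() lets actor processes write into storage the learner reads.
    buffers = {"obs": [torch.zeros(T + 1).share_memory_() for _ in range(num_buffers)]}
    ctx = mp.get_context("spawn")
    free_queue, full_queue = ctx.SimpleQueue(), ctx.SimpleQueue()
    actors = [ctx.Process(target=fill_slot, args=(i, free_queue, full_queue, buffers, T))
              for i in range(num_actors)]
    for p in actors:
        p.start()
    for m in range(num_buffers):
        free_queue.put(m)  # Prime the pool of free rollout slots.
    for _ in range(3):  # Learner side: consume a few filled slots.
        index = full_queue.get()
        batch = buffers["obs"][index].clone()
        free_queue.put(index)  # Recycle the slot.
    for _ in actors:
        free_queue.put(None)  # Shut the actors down.
    for p in actors:
        p.join()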
Example #3
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, stats
        timings = prof.Timings()
        while frames < flags.total_frames:
            timings.reset()
            batch = get_batch(free_queue, full_queue, buffers, flags, timings)

            stats = learn(model, learner_model, batch, optimizer, scheduler,
                          flags)
            timings.time('learn')
            with lock:
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B

        if i == 0:
            logging.info('Batch and learn: %s', timings.summary())
Example #4
    def __init__(self, flags, actor_index):
        self.flags = flags
        self.actor_index = actor_index
        self.logger = logging.getLogger(__name__)
        self.logger.error("Actor %i started.", self.actor_index)

        self.timings = prof.Timings()  # Keep track of how fast things are.

        self.gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        self.gym_env.seed(seed)
        self.env = environment.Environment(self.gym_env)
        self.env_output = self.env.initial()

        self.net = AtariNet(self.gym_env.observation_space.shape[0],
                            self.gym_env.action_space.n, self.flags.use_lstm,
                            self.flags.use_resnet)
        self.agent_state = self.net.initial_state(batch_size=1)
        self.agent_output, _ = self.net(self.env_output, self.agent_state)
        self.params_idx = -1
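
The seed line, shared with the act functions above, XORs the actor index into four random bytes so that each actor gets a distinct, non-deterministic environment seed. A tiny illustration (actor_seed is our name for it, not a helper from the example):

import os

def actor_seed(actor_index: int) -> int:
    # XOR the actor index into 4 random bytes: a distinct seed per actor,
    # different on every run.
    return actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")

print([actor_seed(i) for i in range(4)])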
Example #5
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()

        agent_state = model.initialize(env_output, batch_size=1)
        agent_output, unused_state = model.act(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state["core_state"]):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model.act(
                        env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
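
Compared with Example #2, this actor also stores the recurrent core state at the start of every rollout in initial_agent_state_buffers. The example does not show how those buffers are allocated; below is a minimal sketch, assuming an LSTM-style (h, c) core state and shared-memory tensors (field names, shapes, and the default core size are assumptions, not the project's create_buffers):

import torch

def make_rollout_buffers(unroll_length, num_buffers, obs_shape, num_actions,
                         core_output_size=256):
    # One slot per rollout; a leading time dimension of T + 1 leaves room at
    # index 0 for the end of the previous rollout, as written by the actor above.
    specs = dict(
        frame=dict(size=(unroll_length + 1, *obs_shape), dtype=torch.uint8),
        reward=dict(size=(unroll_length + 1,), dtype=torch.float32),
        done=dict(size=(unroll_length + 1,), dtype=torch.bool),
        action=dict(size=(unroll_length + 1,), dtype=torch.int64),
        policy_logits=dict(size=(unroll_length + 1, num_actions), dtype=torch.float32),
        baseline=dict(size=(unroll_length + 1,), dtype=torch.float32),
    )
    buffers = {key: [] for key in specs}
    for _ in range(num_buffers):
        for key, spec in specs.items():
            buffers[key].append(torch.empty(**spec).share_memory_())
    # Per-slot LSTM state (h, c), filled by the actor at the start of each rollout.
    initial_agent_state_buffers = [
        tuple(torch.zeros(1, 1, core_output_size).share_memory_() for _ in range(2))
        for _ in range(num_buffers)
    ]
    return buffers, initial_agent_state_buffers

buffers, state_buffers = make_rollout_buffers(
    unroll_length=80, num_buffers=4, obs_shape=(4, 84, 84), num_actions=6)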
Example #6
        def batch_and_learn(learner_idx):
            """Thread target for the learning process."""
            nonlocal step, stats, params_idx, params_id, rollouts
            nonlocal trajectories, last_log_time, last_time, start_step
            timings = prof.Timings()
            batch = []

            while step < self.flags.total_steps:
                # Obtain batch of data. Learners don't do this in parallel to ensure maximal throughput.
                with batch_lock:
                    while len(batch) < self.flags.batch_size:
                        done_id, rollouts = ray.wait(rollouts)

                        # get the results of the task from the object store
                        rollout, actor_id = ray.get(done_id)[0]
                        batch.append(rollout)
                        # start a new task on the same actor object
                        rollouts.extend([self.actors[actor_id].act.remote(params_id, params_idx)])

                        if self.replay_memory is not None:
                            # add trajectory to replay memory
                            for idx in range(self.flags.unroll_length + 1):
                                transition = {}
                                for key in rollout:
                                    if key not in ['initial_agent_state', 'mask']:
                                        transition[key] = rollout[key][idx]
                                trajectories[actor_id].append(transition)
                                if transition["done"]:
                                    self.replay_memory.add_trajectory(trajectories[actor_id])
                                    trajectories[actor_id] = []

                    timings.time("dequeue")

                batch, agent_state, used_replay = get_batch(learner_idx, batch, self.replay_memory, timings, self.flags, step)

                # Perform a learning step and update the network parameters.
                with learn_lock:
                    tmp_stats, agent_state, mask = learn(
                        self.flags, self.net, batch, agent_state, self.optimizer, self.scheduler
                    )
                    params_idx += 1
                    params_id = ray.put({k: v.cpu() for k, v in self.net.state_dict().items()})
                    timings.time("learn")
                batch = []

                # For LASER (https://arxiv.org/abs/1909.11583) update the replay status.
                # As the memory treats each learner individually, learners can access
                # its data structures in parallel.
                if used_replay:
                    self.replay_memory.update_state_and_status(
                        learner_idx,
                        tuple(t[:, self.flags.batch_size:].cpu().detach() for t in agent_state),
                        mask[-1, self.flags.batch_size:].cpu())

                timings.time("update replay")

                # Logging updates the step counter AND prints to the console,
                # which requires locking for concurrency.
                with log_lock:
                    step += self.flags.unroll_length * self.flags.batch_size

                    if 'episode_returns' in stats.keys():
                        tmp_stats['episode_returns'] += stats['episode_returns']
                    stats = tmp_stats

                    if self.timer() - last_log_time > 5:
                        last_log_time = self.timer()
                        print_tuple = get_stats(self.flags, stats, self.timer, last_time, step, start_step)
                        start_step = step
                        last_time = self.timer()
                        log(self.flags, print_tuple, self.logger, step)
                        timings.time("log")

            if learner_idx == 0:
                self.logger.error("Batch and learn: %s", timings.summary())
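
The learner pulls finished rollouts with ray.wait / ray.get and immediately resubmits a new act task to the same actor, so every actor stays busy. A minimal, self-contained Ray sketch of that pattern (the Worker class and its act method are placeholders for the real actors):

import ray

@ray.remote
class Worker:
    def __init__(self, worker_id):
        self.worker_id = worker_id

    def act(self, step):
        # Stand-in for a rollout: return some data plus our own id so the
        # driver knows which actor to hand the next task to.
        return [step] * 4, self.worker_id

if __name__ == "__main__":
    ray.init()
    actors = [Worker.remote(i) for i in range(2)]
    rollouts = [a.act.remote(0) for a in actors]
    for step in range(1, 6):
        # Wait for any one in-flight rollout, fetch its result, and
        # immediately queue a replacement task on the same actor.
        done_ids, rollouts = ray.wait(rollouts)
        rollout, worker_id = ray.get(done_ids)[0]
        rollouts.append(actors[worker_id].act.remote(step))
        print("step", step, "rollout from worker", worker_id)
    ray.shutdown()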
Example #7
def test(flags,
         env,
         cumulative_steps,
         model: torch.nn.Module,
         split='test',
         logger=None):
    try:
        logging.info("Tester started.")
        timings = prof.Timings()  # Keep track of how fast things are.
        if flags.seed is None:
            seed = int.from_bytes(os.urandom(4), byteorder="little")
        else:
            seed = flags.seed
        utils_seed(seed)

        env.eval(split=split)

        last_cumulative_steps = -1

        greedy = True

        while True:
            cumulative_steps_now = cumulative_steps.item()
            if last_cumulative_steps < 0 \
                or (cumulative_steps_now - last_cumulative_steps >= flags.test_interval) \
                or cumulative_steps_now >= flags.total_steps:

                logging.info(f"tester: {cumulative_steps.item()} steps")
                env_output = env.initial()
                agent_state = model.initial_state(batch_size=1)
                agent_output, unused_state = model(env_output,
                                                   agent_state,
                                                   greedy=greedy)

                episode_count = 0
                mean_episode_return = 0
                while episode_count < flags.num_test_episodes:
                    # Do new rollout.
                    for t in range(flags.unroll_length):
                        timings.reset()

                        with torch.no_grad():
                            agent_output, agent_state = model(env_output,
                                                              agent_state,
                                                              greedy=greedy)
                        timings.time("model")

                        env_output = env.step(agent_output["action"])
                        timings.time("step")

                        if env_output["done"][0]:
                            episode_count += 1
                            mean_episode_return += env_output[
                                'episode_return'].item()

                mean_episode_return /= flags.num_test_episodes

                last_cumulative_steps = cumulative_steps_now

                to_log = {
                    'step': cumulative_steps_now,
                    'mean_episode_return': mean_episode_return
                }
                if split == 'test':
                    logger.log_test_test(to_log)
                else:
                    logger.log_test_train(to_log)
                logging.info(f"Tester {split}: %s", timings.summary())

            if cumulative_steps_now >= flags.total_steps:
                logging.info("Tester shutting down.")
                break

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process: tester")
        traceback.print_exc()
        print()
        raise e
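
The tester polls cumulative_steps, the counter the learner threads advance (cumulative_steps[0] = step in Example #1), and shuts down once flags.total_steps is reached. A minimal sketch of sharing such a counter across processes, assuming a one-element shared-memory tensor (poll_steps and both loop bodies are placeholders):

import time

import torch
import torch.multiprocessing as mp

def poll_steps(cumulative_steps, total_steps):
    # Tester side: poll the shared counter and stop once training is done.
    while True:
        now = cumulative_steps.item()
        print("tester sees", now, "steps")
        if now >= total_steps:
            break
        time.sleep(0.05)

if __name__ == "__main__":
    total_steps = 1000
    # A one-element shared-memory tensor plays the role of cumulative_steps.
    cumulative_steps = torch.zeros(1, dtype=torch.int64).share_memory_()
    ctx = mp.get_context("spawn")
    tester = ctx.Process(target=poll_steps, args=(cumulative_steps, total_steps))
    tester.start()
    for step in range(0, total_steps + 1, 100):
        cumulative_steps[0] = step  # What the learner thread does in Example #1.
        time.sleep(0.02)
    tester.join()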
Example #8
def act(flags, actor_index: int, channel):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        flags.num_actions = gym_env.action_space.n

        flags.device = torch.device("cpu")

        model = Net(gym_env.observation_space.shape,
                    num_actions=flags.num_actions).to(device=flags.device)
        #model.eval()

        buffers = create_buffers(flags, gym_env.observation_space.shape,
                                 model.num_actions)

        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        # pull index for model update

        pull_index = 0

        while True:
            # Write old rollout end.
            for key in env_output:
                buffers[key][0][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][0][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    if (flags.cut_layer < model.total_cut_layers - 1):
                        inter_tensors, inter_T, inter_B = model(
                            env_output, agent_state, cut_layer=flags.cut_layer)
                        agent_output, agent_state = rpcenv.inference_send(
                            inter_tensors, agent_state, flags.cut_layer,
                            inter_T, inter_B, env_output["reward"], channel)
                    else:
                        agent_output, agent_state = model(
                            env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][0][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][0][t + 1, ...] = agent_output[key]

                timings.time("write")

            # rpc send buffers to learner
            rpcenv.upload_trajectory(actor_index, buffers, channel)

            pull_index = pull_index + 1

            if (pull_index == flags.batch_size):
                parameters = rpcenv.pull_model(actor_index, channel)
                logging.info("update model !!")
                model.load_state_dict(parameters)
                logging.info("update model from learner in %i steps",
                             env_output["episode_step"])
                logging.info("model return in %f",
                             env_output["episode_return"])

                # pull index for model update

                pull_index = 0

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
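
The actor refreshes its local copy of the network from the learner once every flags.batch_size rollouts, driven by the pull_index counter. A small sketch of that bookkeeping with a plain state_dict hand-off (maybe_refresh and the tiny Linear models are illustrative; the real transfer goes through rpcenv.pull_model):

import torch.nn as nn

def maybe_refresh(actor_model, learner_state_dict, pull_index, refresh_every):
    # Copy the learner's parameters into the actor model every
    # `refresh_every` rollouts, then reset the counter.
    pull_index += 1
    if pull_index == refresh_every:
        actor_model.load_state_dict(learner_state_dict)
        pull_index = 0
    return pull_index

actor_model, learner_model = nn.Linear(4, 2), nn.Linear(4, 2)
pull_index = 0
for rollout in range(10):
    pull_index = maybe_refresh(actor_model, learner_model.state_dict(),
                               pull_index, refresh_every=5)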