Example #1
    def forward_pass(device_type):
        env_name = 'atari_breakout'
        cfg = default_cfg(algo='appooc', env=env_name)
        cfg.actor_critic_share_weights = True
        cfg.hidden_size = 128
        cfg.use_rnn = True
        cfg.env_framestack = 4

        env = create_env(env_name, cfg=cfg)

        torch.set_num_threads(1)
        torch.backends.cudnn.benchmark = True

        actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space)
        device = torch.device(device_type)
        actor_critic.to(device)

        timing = Timing()
        with timing.timeit('all'):
            batch = 128
            with timing.add_time('input'):
                # observation shape is hardcoded here (framestack of 4 Atari frames at 84x84); better to derive it from the env
                observations = dict(obs=torch.rand([batch, 4, 84, 84]).to(device))
                rnn_states = torch.rand([batch, get_hidden_size(cfg)]).to(device)

            n = 200
            for i in range(n):
                with timing.add_time('forward'):
                    output = actor_critic(observations, rnn_states)

                log.debug('Progress %d/%d', i, n)

        log.debug('Timing: %s', timing)
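
The Timing object used throughout these snippets is a helper from the Sample Factory codebase. As a rough illustration of just the interface exercised here (the timeit and add_time context managers), a minimal stand-in might look like the hypothetical SimpleTiming below; the real class has more features (attribute access, pretty printing, etc.).

    import time
    from contextlib import contextmanager

    class SimpleTiming(dict):
        """Minimal stand-in for the Timing helper: named timers stored as dict entries."""

        @contextmanager
        def timeit(self, key):
            # record the duration of a single block under `key`
            start = time.time()
            yield
            self[key] = time.time() - start

        @contextmanager
        def add_time(self, key):
            # accumulate durations of repeated blocks under `key`
            start = time.time()
            yield
            self[key] = self.get(key, 0.0) + (time.time() - start)
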
    def test_gumbel_trick(self):
        """
        We use a Gumbel noise which seems to be faster compared to using pytorch multinomial.
        Here we test that those are actually equivalent.
        """

        timing = Timing()

        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        with torch.no_grad():
            action_space = gym.spaces.Discrete(8)
            num_logits = calc_num_logits(action_space)
            device_type = 'cpu'
            device = torch.device(device_type)
            logits = torch.rand(self.batch_size, num_logits,
                                device=device) * 10.0 - 5.0

            if device_type == 'cuda':
                torch.cuda.synchronize(device)

            count_gumbel, count_multinomial = np.zeros([action_space.n]), np.zeros([action_space.n])

            # estimate probability mass by actually sampling both ways
            num_samples = 20000

            action_distribution = get_action_distribution(action_space, logits)
            sample_actions_log_probs(action_distribution)
            action_distribution.sample_gumbel()

            with timing.add_time('gumbel'):
                for i in range(num_samples):
                    action_distribution = get_action_distribution(
                        action_space, logits)
                    samples_gumbel = action_distribution.sample_gumbel()
                    count_gumbel[samples_gumbel[0]] += 1

            action_distribution = get_action_distribution(action_space, logits)
            action_distribution.sample()

            with timing.add_time('multinomial'):
                for i in range(num_samples):
                    action_distribution = get_action_distribution(
                        action_space, logits)
                    samples_multinomial = action_distribution.sample()
                    count_multinomial[samples_multinomial[0]] += 1

            estimated_probs_gumbel = count_gumbel / float(num_samples)
            estimated_probs_multinomial = count_multinomial / float(
                num_samples)

            log.debug('Gumbel estimated probs: %r', estimated_probs_gumbel)
            log.debug('Multinomial estimated probs: %r',
                      estimated_probs_multinomial)
            log.debug('Sampling timing: %s', timing)
            time.sleep(0.1)  # to finish logging
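
The equivalence being tested here is the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the logits and taking the argmax produces an exact sample from the categorical distribution defined by those logits. A standalone sketch of the idea (not the project's action-distribution API):

    import torch

    def sample_gumbel_max(logits):
        # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1); epsilons avoid log(0)
        u = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
        # argmax of (logits + noise) is distributed as Categorical(softmax(logits))
        return torch.argmax(logits + gumbel_noise, dim=-1)
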
Example #3
    def _run(self):
        """
        Main loop of the actor worker (rollout worker).
        Process tasks (mainly ROLLOUT_STEP) until we get the termination signal, which usually means end of training.
        Currently there is no mechanism to restart dead workers if something bad happens during training. We can only
        retry on the initial reset(). This is definitely something to work on.
        """
        log.info('Initializing vector env runner %d...', self.worker_idx)

        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        torch.multiprocessing.set_sharing_strategy('file_system')

        timing = Timing()

        last_report = time.time()
        with torch.no_grad():
            while not self.terminate:
                try:
                    try:
                        with timing.add_time('waiting'), timing.timeit('wait_actor'):
                            tasks = self.task_queue.get_many(timeout=0.1)
                    except Empty:
                        tasks = []

                    for task in tasks:
                        task_type, data = task

                        if task_type == TaskType.INIT:
                            self._init()
                            continue

                        if task_type == TaskType.TERMINATE:
                            self._terminate()
                            break

                        # handling actual workload
                        if task_type == TaskType.ROLLOUT_STEP:
                            if 'work' not in timing:
                                timing.waiting = 0  # measure waiting only after real work has started

                            with timing.add_time('work'), timing.timeit('one_step'):
                                self._advance_rollouts(data, timing)
                        elif task_type == TaskType.RESET:
                            with timing.add_time('reset'):
                                self._handle_reset()
                        elif task_type == TaskType.PBT:
                            self._process_pbt_task(data)

                    if time.time() - last_report > 5.0 and 'one_step' in timing:
                        timing_stats = dict(wait_actor=timing.wait_actor, step_actor=timing.one_step)
                        memory_mb = memory_consumption_mb()
                        stats = dict(memory_actor=memory_mb)
                        self.report_queue.put(dict(timing=timing_stats, stats=stats))
                        last_report = time.time()

                except RuntimeError as exc:
                    log.warning('Error while processing data w: %d, exception: %s', self.worker_idx, exc)
                    log.warning('Terminate process...')
                    self.terminate = True
                    self.report_queue.put(dict(critical_error=self.worker_idx))
                except KeyboardInterrupt:
                    self.terminate = True
                except:
                    log.exception('Unknown exception in rollout worker')
                    self.terminate = True

        if self.worker_idx <= 1:
            time.sleep(0.1)
            log.info(
                'Env runner %d, CPU aff. %r, rollouts %d: timing %s',
                self.worker_idx, psutil.Process().cpu_affinity(), self.num_complete_rollouts, timing,
            )
Example #4
    def _run(self):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        try:
            psutil.Process().nice(self.cfg.default_niceness)
        except psutil.AccessDenied:
            log.error('Low niceness requires sudo!')

        if self.cfg.device == 'gpu':
            cuda_envvars(self.policy_id)

        torch.multiprocessing.set_sharing_strategy('file_system')
        torch.set_num_threads(self.cfg.learner_main_loop_num_cores)

        timing = Timing()

        rollouts = []

        if self.train_in_background:
            self.training_thread.start()
        else:
            self.initialize(timing)
            log.error(
                'train_in_background set to False on learner %d! This is slow, use only for testing!',
                self.policy_id,
            )

        while not self.terminate:
            while True:
                try:
                    tasks = self.task_queue.get_many(timeout=0.005)

                    for task_type, data in tasks:
                        if task_type == TaskType.TRAIN:
                            with timing.add_time('extract'):
                                rollouts.extend(self._extract_rollouts(data))
                                # log.debug('Learner %d has %d rollouts', self.policy_id, len(rollouts))
                        elif task_type == TaskType.INIT:
                            self._init()
                        elif task_type == TaskType.TERMINATE:
                            time.sleep(0.3)
                            log.info('GPU learner timing: %s', timing)
                            self._terminate()
                            break
                        elif task_type == TaskType.PBT:
                            self._process_pbt_task(data)
                except Empty:
                    break

            if self._accumulated_too_much_experience(rollouts):
                # if we accumulated too much experience, signal the policy workers to stop experience collection
                if not self.stop_experience_collection[self.policy_id]:
                    log.debug(
                        'Learner %d accumulated too much experience, stop experience collection!',
                        self.policy_id)
                self.stop_experience_collection[self.policy_id] = True
            elif self.stop_experience_collection[self.policy_id]:
                # otherwise, resume the experience collection if it was stopped
                self.stop_experience_collection[self.policy_id] = False
                with self.resume_experience_collection_cv:
                    log.debug('Learner %d is resuming experience collection!',
                              self.policy_id)
                    self.resume_experience_collection_cv.notify_all()

            with torch.no_grad():
                rollouts = self._process_rollouts(rollouts, timing)

            if not self.train_in_background:
                while not self.experience_buffer_queue.empty():
                    training_data = self.experience_buffer_queue.get()
                    self._process_training_data(training_data, timing)

            self._experience_collection_rate_stats()

        if self.train_in_background:
            self.experience_buffer_queue.put(None)
            self.training_thread.join()
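
The stop/resume handshake around experience collection is a shared flag plus a condition variable: the learner flips the flag, and blocked workers are woken with notify_all. A minimal sketch of the pattern with hypothetical names (stop_flag, resume_cv), independent of the multi-process shared arrays used in the actual code:

    import threading

    stop_flag = False
    resume_cv = threading.Condition()

    def worker_wait_if_stopped():
        # worker side: block while experience collection is paused
        with resume_cv:
            while stop_flag:
                resume_cv.wait(timeout=0.05)

    def learner_resume_collection():
        # learner side: clear the flag and wake every waiting worker
        global stop_flag
        with resume_cv:
            stop_flag = False
            resume_cv.notify_all()
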
Example #5
    def sample(self, proc_idx):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        from threadpoolctl import threadpool_limits
        with threadpool_limits(limits=1, user_api=None):
            if self.cfg.set_workers_cpu_affinity:
                set_process_cpu_affinity(proc_idx, self.cfg.num_workers)

            initial_cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
            psutil.Process().nice(10)

            with timing.timeit('env_init'):
                envs = []
                env_key = ['env' for _ in range(self.cfg.num_envs_per_worker)]

                for env_idx in range(self.cfg.num_envs_per_worker):
                    global_env_id = proc_idx * self.cfg.num_envs_per_worker + env_idx
                    env_config = AttrDict(worker_index=proc_idx,
                                          vector_index=env_idx,
                                          env_id=global_env_id)
                    env = create_env(self.cfg.env,
                                     cfg=self.cfg,
                                     env_config=env_config)
                    log.debug(
                        'CPU affinity after create_env: %r',
                        psutil.Process().cpu_affinity()
                        if platform != 'darwin' else 'MacOS - None')
                    env.seed(global_env_id)
                    envs.append(env)

                    # this is to track the performance for individual DMLab levels
                    if hasattr(env.unwrapped, 'level_name'):
                        env_key[env_idx] = env.unwrapped.level_name

                episode_length = [0 for _ in envs]
                episode_lengths = [deque([], maxlen=20) for _ in envs]

            try:
                with timing.timeit('first_reset'):
                    for env_idx, env in enumerate(envs):
                        env.reset()
                        log.info('Process %d finished resetting %d/%d envs',
                                 proc_idx, env_idx + 1, len(envs))

                    self.report_queue.put(
                        dict(proc_idx=proc_idx, finished_reset=True))

                self.start_event.wait()

                with timing.timeit('work'):
                    last_report = last_report_frames = total_env_frames = 0
                    while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                        for env_idx, env in enumerate(envs):
                            action = env.action_space.sample()
                            with timing.add_time(f'{env_key[env_idx]}.step'):
                                obs, reward, done, info = env.step(action)

                            num_frames = info.get('num_frames', 1)
                            total_env_frames += num_frames
                            episode_length[env_idx] += num_frames

                            if done:
                                with timing.add_time(
                                        f'{env_key[env_idx]}.reset'):
                                    env.reset()

                                episode_lengths[env_idx].append(
                                    episode_length[env_idx])
                                episode_length[env_idx] = 0

                        with timing.add_time('report'):
                            now = time.time()
                            if now - last_report > self.report_every_sec:
                                last_report = now
                                frames_since_last_report = total_env_frames - last_report_frames
                                last_report_frames = total_env_frames
                                self.report_queue.put(
                                    dict(proc_idx=proc_idx,
                                         env_frames=frames_since_last_report))

                # Extra check to make sure the CPU affinity is preserved throughout the execution.
                # I observed a weird effect where some environments tried to alter the affinity of the current process,
                # leading to decreased performance.
                # This can be caused by interactions between deep learning libs, OpenCV, MKL, OpenMP, etc.
                # At least the user should know about it if this is happening.
                cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
                assert initial_cpu_affinity == cpu_affinity, \
                    f'Worker CPU affinity was changed from {initial_cpu_affinity} to {cpu_affinity}! ' \
                    f'This can significantly affect performance!'

            except:
                log.exception('Unknown exception')
                log.error('Unknown exception in worker %d, terminating...',
                          proc_idx)
                self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

            time.sleep(proc_idx * 0.01 + 0.01)
            log.info('Process %d finished sampling. Timing: %s', proc_idx,
                     timing)

            for env_idx, env in enumerate(envs):
                if len(episode_lengths[env_idx]) > 0:
                    log.warning('Level %s avg episode len %d',
                                env_key[env_idx],
                                np.mean(episode_lengths[env_idx]))

            for env in envs:
                env.close()
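
set_process_cpu_affinity is a Sample Factory helper; the general idea is to pin each worker process to a core so environments cannot silently migrate between cores. A rough sketch of the idea with psutil (round-robin assignment, hypothetical helper name, not the project's exact logic; note that cpu_affinity is unavailable on macOS):

    import psutil

    def pin_worker_to_core(worker_idx):
        # round-robin assignment of worker processes to logical cores
        num_cores = psutil.cpu_count(logical=True)
        psutil.Process().cpu_affinity([worker_idx % num_cores])
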
Example #6
    def _run(self):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        psutil.Process().nice(min(self.cfg.default_niceness + 2, 20))

        cuda_envvars(self.policy_id)
        torch.multiprocessing.set_sharing_strategy('file_system')

        timing = Timing()

        with timing.timeit('init'):
            # initialize the Torch modules
            log.info('Initializing model on the policy worker %d-%d...',
                     self.policy_id, self.worker_idx)

            torch.set_num_threads(1)

            if self.cfg.device == 'gpu':
                # we should already see only one CUDA device, because of env vars
                assert torch.cuda.device_count() == 1
                self.device = torch.device('cuda', index=0)
            else:
                self.device = torch.device('cpu')

            self.actor_critic = create_actor_critic(self.cfg, self.obs_space,
                                                    self.action_space, timing)
            self.actor_critic.model_to_device(self.device)
            for p in self.actor_critic.parameters():
                p.requires_grad = False  # we don't train anything here

            log.info('Initialized model on the policy worker %d-%d!',
                     self.policy_id, self.worker_idx)

        last_report = last_cache_cleanup = time.time()
        last_report_samples = 0
        request_count = deque(maxlen=50)

        # Very conservative limit on the minimum number of requests to wait for.
        # This will almost guarantee that the system keeps collecting experience at the maximum rate
        # even when 2/3 of the workers are stuck for some reason (e.g. doing a long env reset).
        # If your workflow involves very lengthy operations that often freeze workers, it can be beneficial
        # to set min_num_requests to 1, at the cost of potential inefficiency (the policy worker will use
        # very small batches).
        min_num_requests = self.cfg.num_workers // (
            self.cfg.num_policies * self.cfg.policy_workers_per_policy)
        min_num_requests //= 3
        min_num_requests = max(1, min_num_requests)
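        # Worked example with hypothetical numbers: num_workers=36, num_policies=1,
        # policy_workers_per_policy=3 -> 36 // (1 * 3) = 12, then 12 // 3 = 4,
        # i.e. this policy worker waits for at least 4 requests before running inference.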

        # Again, very conservative timer. Only wait a little bit, then continue operation.
        wait_for_min_requests = 0.025

        while not self.terminate:
            try:
                while self.stop_experience_collection[self.policy_id]:
                    with self.resume_experience_collection_cv:
                        self.resume_experience_collection_cv.wait(timeout=0.05)

                waiting_started = time.time()
                while len(self.requests) < min_num_requests and time.time() - waiting_started < wait_for_min_requests:
                    try:
                        with timing.timeit('wait_policy'), timing.add_time(
                                'wait_policy_total'):
                            policy_requests = self.policy_queue.get_many(
                                timeout=0.005)
                        self.requests.extend(policy_requests)
                    except Empty:
                        pass

                self._update_weights(timing)

                with timing.timeit('one_step'), timing.add_time(
                        'handle_policy_step'):
                    if self.initialized:
                        if len(self.requests) > 0:
                            request_count.append(len(self.requests))
                            self._handle_policy_steps(timing)

                try:
                    task_type, data = self.task_queue.get_nowait()

                    # task from the task_queue
                    if task_type == TaskType.INIT:
                        self._init()
                    elif task_type == TaskType.TERMINATE:
                        self.terminate = True
                        break
                    elif task_type == TaskType.INIT_MODEL:
                        self._init_model(data)

                    self.task_queue.task_done()
                except Empty:
                    pass

                if time.time() - last_report > 3.0 and 'one_step' in timing:
                    timing_stats = dict(wait_policy=timing.wait_policy,
                                        step_policy=timing.one_step)
                    samples_since_last_report = self.total_num_samples - last_report_samples

                    stats = memory_stats('policy_worker', self.device)
                    if len(request_count) > 0:
                        stats['avg_request_count'] = np.mean(request_count)

                    self.report_queue.put(
                        dict(
                            timing=timing_stats,
                            samples=samples_since_last_report,
                            policy_id=self.policy_id,
                            stats=stats,
                        ))
                    last_report = time.time()
                    last_report_samples = self.total_num_samples

                if time.time() - last_cache_cleanup > 300.0 or (
                        not self.cfg.benchmark
                        and self.total_num_samples < 1000):
                    if self.cfg.device == 'gpu':
                        torch.cuda.empty_cache()
                    last_cache_cleanup = time.time()

            except KeyboardInterrupt:
                log.warning('Keyboard interrupt detected on worker %d-%d',
                            self.policy_id, self.worker_idx)
                self.terminate = True
            except:
                log.exception('Unknown exception on policy worker')
                self.terminate = True

        time.sleep(0.2)
        log.info('Policy worker avg. requests %.2f, timing: %s',
                 np.mean(request_count), timing)
Example #7
    def train(self, buffer, env_steps, agent, timing=None):
        if timing is None:
            timing = Timing()

        params = agent.params

        batch_size = params.distance_batch_size
        summary = None
        dist_step = self.step.eval(session=agent.session)

        prev_loss = 1e10
        num_epochs = params.distance_train_epochs

        log.info('Train distance net %d pairs, batch %d, epochs %d',
                 len(buffer), batch_size, num_epochs)

        with timing.timeit('dist_epochs'):
            for epoch in range(num_epochs):
                losses = []

                with timing.add_time('shuffle'):
                    buffer.shuffle_data()

                obs_first, obs_second, labels = buffer.obs_first, buffer.obs_second, buffer.labels

                with timing.add_time('batch'):
                    for i in range(0, len(obs_first) - 1, batch_size):
                        # noinspection PyProtectedMember
                        with_summaries = agent._should_write_summaries(
                            dist_step) and summary is None
                        summaries = [self.summaries] if with_summaries else []

                        start, end = i, i + batch_size

                        result = agent.session.run(
                            [self.loss, self.train_op] + summaries,
                            feed_dict={
                                self.ph_obs_first: obs_first[start:end],
                                self.ph_obs_second: obs_second[start:end],
                                self.ph_labels: labels[start:end],
                                self.ph_is_training: True,
                            },
                        )

                        dist_step += 1
                        # noinspection PyProtectedMember
                        agent._maybe_save(dist_step, env_steps)
                        losses.append(result[0])

                        if with_summaries:
                            summary = result[-1]
                            agent.summary_writer.add_summary(
                                summary, global_step=env_steps)

                    # check loss improvement at the end of each epoch, early stop if necessary
                    avg_loss = np.mean(losses)
                    if avg_loss >= prev_loss:
                        log.info(
                            'Early stopping after %d epochs because distance net did not improve',
                            epoch + 1)
                        log.info('Was %.4f now %.4f, ratio %.3f', prev_loss,
                                 avg_loss, avg_loss / prev_loss)
                        break
                    prev_loss = avg_loss

        return dist_step
    def sample(self, proc_idx):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        psutil.Process().nice(10)

        num_envs = len(DMLAB30_LEVELS_THAT_USE_LEVEL_CACHE)
        assert self.cfg.num_workers % num_envs == 0, f'should have an integer number of workers per env, e.g. {1 * num_envs}, {2 * num_envs}, etc...'
        assert self.cfg.num_envs_per_worker == 1, 'use populate_cache with 1 env per worker'

        with timing.timeit('env_init'):
            env_key = 'env'
            env_desired_num_levels = 0

            global_env_id = proc_idx * self.cfg.num_envs_per_worker
            env_config = AttrDict(worker_index=proc_idx, vector_index=0, env_id=global_env_id)
            env = create_env(self.cfg.env, cfg=self.cfg, env_config=env_config)
            env.seed(global_env_id)

            # this is to track the performance for individual DMLab levels
            if hasattr(env.unwrapped, 'level_name'):
                env_key = env.unwrapped.level_name
                env_level = env.unwrapped.level

                approx_num_episodes_per_1b_frames = DMLAB30_APPROX_NUM_EPISODES_PER_BILLION_FRAMES[
                    env_key]
                num_billions = DESIRED_TRAINING_LENGTH / int(1e9)
                num_workers_for_env = self.cfg.num_workers // num_envs
                env_desired_num_levels = int(
                    (approx_num_episodes_per_1b_frames * num_billions) / num_workers_for_env)
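                # Worked example with hypothetical numbers: ~10000 episodes per 1B frames for this level,
                # 10B frames of planned training and 12 workers assigned to the level gives
                # int(10000 * 10 / 12) ~= 8333 unique levels to pre-generate on this worker.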

                env_num_levels_generated = len(
                    dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE[0].all_seeds[env_level]) // num_workers_for_env

                log.warning('Worker %d (env %s) generated %d/%d levels!',
                            proc_idx,
                            env_key,
                            env_num_levels_generated,
                            env_desired_num_levels)
                time.sleep(4)

            env.reset()
            env_uses_level_cache = env.unwrapped.env_uses_level_cache

            self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

        self.start_event.wait()

        try:
            with timing.timeit('work'):
                last_report = last_report_frames = total_env_frames = 0
                while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                    action = env.action_space.sample()
                    with timing.add_time(f'{env_key}.step'):
                        env.step(action)

                    total_env_frames += 1

                    with timing.add_time(f'{env_key}.reset'):
                        env.reset()
                        env_num_levels_generated += 1
                        log.debug('Env %s done %d/%d resets',
                                  env_key,
                                  env_num_levels_generated,
                                  env_desired_num_levels)

                    if env_num_levels_generated >= env_desired_num_levels:
                        log.debug('%s finished %d/%d resets, sleeping...',
                                  env_key,
                                  env_num_levels_generated,
                                  env_desired_num_levels)
                        time.sleep(30)  # free up CPU time for other envs

                    # if the env does not use the level cache, there is no need to run it;
                    # let the other workers proceed
                    if not env_uses_level_cache:
                        log.debug('Env %s does not require cache, sleeping...', env_key)
                        time.sleep(200)

                    with timing.add_time('report'):
                        now = time.time()
                        if now - last_report > self.report_every_sec:
                            last_report = now
                            frames_since_last_report = total_env_frames - last_report_frames
                            last_report_frames = total_env_frames
                            self.report_queue.put(
                                dict(proc_idx=proc_idx, env_frames=frames_since_last_report))

                            if get_free_disk_space_mb(self.cfg) < 3 * 1024:
                                log.error('Not enough disk space! %d',
                                          get_free_disk_space_mb(self.cfg))
                                time.sleep(200)
        except:
            log.exception('Unknown exception')
            log.error('Unknown exception in worker %d, terminating...', proc_idx)
            self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

        time.sleep(proc_idx * 0.1 + 0.1)
        log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)

        env.close()
Example #9
    def localize(
        self,
        session,
        obs,
        info,
        maps,
        distance_net,
        frames=None,
        on_new_landmark=None,
        on_new_edge=None,
        timing=None,
    ):
        num_envs = len(obs)
        closest_landmark_idx = [-1] * num_envs
        # closest distance to the landmark in the existing graph (excluding new landmarks)
        closest_landmark_dist = [math.inf] * num_envs

        if all(m is None for m in maps):
            return closest_landmark_dist

        if timing is None:
            timing = Timing()

        # create a batch of all neighborhood observations from all envs for fast processing on GPU
        neighborhood_obs, neighborhood_hashes, current_obs, current_obs_hashes = [], [], [], []
        neighborhood_infos, current_infos = [], []
        total_num_neighbors = 0
        neighborhood_sizes = [0] * len(maps)
        for env_i, m in enumerate(maps):
            if m is None:
                continue

            neighbor_indices = m.neighborhood()
            neighborhood_sizes[env_i] = len(neighbor_indices)
            neighborhood_obs.extend(
                [m.get_observation(i) for i in neighbor_indices])
            neighborhood_infos.extend(
                [m.get_info(i) for i in neighbor_indices])
            neighborhood_hashes.extend(
                [m.get_hash(i) for i in neighbor_indices])
            current_obs.extend([obs[env_i]] * len(neighbor_indices))
            current_obs_hashes.extend([hash_observation(obs[env_i])] *
                                      len(neighbor_indices))
            current_infos.extend([info[env_i]] * len(neighbor_indices))
            total_num_neighbors += len(neighbor_indices)

        assert len(neighborhood_obs) == len(current_obs)
        assert len(neighborhood_obs) == len(neighborhood_hashes)
        assert len(current_obs) == total_num_neighbors
        assert len(neighborhood_infos) == len(current_infos)

        with timing.add_time('neighbor_dist'):
            distances = distance_net.distances_from_obs(
                session,
                obs_first=neighborhood_obs,
                obs_second=current_obs,
                hashes_first=neighborhood_hashes,
                hashes_second=current_obs_hashes,
                infos_first=neighborhood_infos,
                infos_second=current_infos,
            )

        assert len(distances) == total_num_neighbors

        new_landmark_candidates = []

        j = 0
        for env_i, m in enumerate(maps):
            if m is None:
                continue

            neighbor_indices = m.neighborhood()
            j_next = j + len(neighbor_indices)
            distance = distances[j:j_next]

            if len(neighbor_indices) != neighborhood_sizes[env_i]:
                log.warning(
                    'For env %d neighbors size expected %d, actual %d',
                    env_i,
                    neighborhood_sizes[env_i],
                    len(neighbor_indices),
                )

            assert len(neighbor_indices) == neighborhood_sizes[env_i]

            self._log_distances(env_i, neighbor_indices, distance)

            j = j_next

            # check if we're far enough from all landmarks in the neighborhood
            min_d, min_d_idx = min_with_idx(distance)
            closest_landmark_idx[env_i] = neighbor_indices[min_d_idx]
            closest_landmark_dist[env_i] = min_d

            if min_d >= self.new_landmark_threshold:
                # we're far enough from all obs in the neighborhood, might have found something new!
                new_landmark_candidates.append(env_i)
            else:
                # we're still sufficiently close to our neighborhood, but maybe "current landmark" has changed
                m.new_landmark_candidate_frames = 0
                m.loop_closure_candidate_frames = 0

                # crude localization
                if all(lm == closest_landmark_idx[env_i]
                       for lm in m.closest_landmarks[-self.localize_frames:]):
                    if closest_landmark_idx[env_i] != m.curr_landmark_idx:
                        m.set_curr_landmark(closest_landmark_idx[env_i])

        del neighborhood_obs
        del neighborhood_infos
        del neighborhood_hashes
        del current_obs
        del current_infos

        # Agents in some environments discovered landmarks that are far away from all landmarks in the immediate
        # vicinity. There are two possibilities:
        # 1) a new landmark should be created and added to the graph
        # 2) we're close to some other vertex in the graph, i.e. we've found a "loop closure", a new edge in the graph

        non_neighborhood_obs, non_neighborhood_hashes = [], []
        non_neighborhoods = {}
        current_obs, current_obs_hashes = [], []
        non_neighborhood_infos, current_infos = [], []
        for env_i in new_landmark_candidates:
            m = maps[env_i]
            if m is None:
                continue

            non_neighbor_indices = m.curr_non_neighbors()
            non_neighborhoods[env_i] = non_neighbor_indices
            non_neighborhood_obs.extend(
                [m.get_observation(i) for i in non_neighbor_indices])
            non_neighborhood_infos.extend(
                [m.get_info(i) for i in non_neighbor_indices])
            non_neighborhood_hashes.extend(
                [m.get_hash(i) for i in non_neighbor_indices])
            current_obs.extend([obs[env_i]] * len(non_neighbor_indices))
            current_obs_hashes.extend([hash_observation(obs[env_i])] *
                                      len(non_neighbor_indices))
            current_infos.extend([info[env_i]] * len(non_neighbor_indices))

        assert len(non_neighborhood_obs) == len(current_obs)
        assert len(non_neighborhood_obs) == len(non_neighborhood_hashes)

        with timing.add_time('non_neigh'):
            # calculate distance for all non-neighbors
            distances = []
            batch_size = 1024
            for i in range(0, len(non_neighborhood_obs), batch_size):
                start, end = i, i + batch_size

                distances_batch = distance_net.distances_from_obs(
                    session,
                    obs_first=non_neighborhood_obs[start:end],
                    obs_second=current_obs[start:end],
                    hashes_first=non_neighborhood_hashes[start:end],
                    hashes_second=current_obs_hashes[start:end],
                    infos_first=non_neighborhood_infos[start:end],
                    infos_second=current_infos[start:end],
                )
                distances.extend(distances_batch)

        j = 0
        for env_i in new_landmark_candidates:
            m = maps[env_i]
            if m is None:
                continue

            non_neighbor_indices = non_neighborhoods[env_i]
            j_next = j + len(non_neighbor_indices)
            distance = distances[j:j_next]
            j = j_next

            min_d, min_d_idx = math.inf, math.inf
            if len(distance) > 0:
                min_d, min_d_idx = min_with_idx(distance)
                closest_landmark_dist[env_i] = min(
                    closest_landmark_dist[env_i], min_d)

            if min_d < self.loop_closure_threshold:
                # current observation is close to some other landmark, "close the loop" by creating a new edge
                m.new_landmark_candidate_frames = 0
                m.loop_closure_candidate_frames += 1

                closest_landmark_idx[env_i] = non_neighbor_indices[min_d_idx]

                # crude localization
                if m.loop_closure_candidate_frames >= self.localize_frames:
                    if all(lm == closest_landmark_idx[env_i] for lm in
                           m.closest_landmarks[-self.localize_frames:]):
                        # we found a new edge! Cool!
                        m.loop_closure_candidate_frames = 0
                        m.set_curr_landmark(closest_landmark_idx[env_i])

                        if on_new_edge is not None:
                            on_new_edge(env_i)

            elif min_d >= self.new_landmark_threshold:
                m.loop_closure_candidate_frames = 0
                m.new_landmark_candidate_frames += 1

                # vertex is relatively far away from all vertices in the graph, we've found a new landmark!
                if m.new_landmark_candidate_frames >= self.localize_frames:
                    new_landmark_idx = m.add_landmark(
                        obs[env_i], info[env_i], update_curr_landmark=True)

                    if frames is not None:
                        m.graph.nodes[new_landmark_idx]['added_at'] = frames[
                            env_i]

                    closest_landmark_idx[env_i] = new_landmark_idx
                    m.new_landmark_candidate_frames = 0

                    if on_new_landmark is not None:
                        on_new_landmark(env_i, new_landmark_idx)
            else:
                m.new_landmark_candidate_frames = 0
                m.loop_closure_candidate_frames = 0

        # update localization info
        for env_i in range(num_envs):
            m = maps[env_i]
            if m is None:
                continue

            assert closest_landmark_idx[env_i] >= 0
            m.closest_landmarks.append(closest_landmark_idx[env_i])

        # # visualize "closest" landmark
        # import cv2
        # closest_lm = maps[0].closest_landmarks[-1]
        # closest_obs = maps[0].get_observation(closest_lm)
        # cv2.imshow('closest_obs', cv2.resize(cv2.cvtColor(closest_obs, cv2.COLOR_RGB2BGR), (420, 420)))
        # cv2.waitKey(1)

        return closest_landmark_dist
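
Stripped of the batching and bookkeeping, the localization logic above is a two-threshold decision per step, confirmed over several consecutive frames before the map is modified. A condensed sketch with hypothetical inputs (lists of distances to the current neighborhood and to the rest of the graph), not the actual map/graph API:

    def classify_observation(dist_to_neighbors, dist_to_rest,
                             new_landmark_threshold, loop_closure_threshold):
        # close to the current neighborhood: simply (re-)localize to the nearest landmark
        if min(dist_to_neighbors) < new_landmark_threshold:
            return 'localized'
        # far from the neighborhood but close to some other vertex: loop closure (new edge)
        if dist_to_rest and min(dist_to_rest) < loop_closure_threshold:
            return 'loop_closure'
        # far from every vertex in the graph: candidate for a brand new landmark
        if not dist_to_rest or min(dist_to_rest) >= new_landmark_threshold:
            return 'new_landmark'
        # in between the two thresholds: not decisive, reset the candidate counters
        return 'undecided'
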
Example #10
    def _learn_loop(self, multi_env):
        """Main training loop."""
        # env_steps used in tensorboard (and thus, our results)
        # actor_step used as global step for training
        step, env_steps = self.session.run(
            [self.actor_step, self.total_env_steps])

        env_obs = multi_env.reset()
        obs, goals = main_observation(env_obs), goal_observation(env_obs)

        buffer = CuriousPPOBuffer()
        trajectory_buffer = TrajectoryBuffer(self.params.num_envs)
        self.curiosity.set_trajectory_buffer(trajectory_buffer)

        def end_of_training(s, es):
            return s >= self.params.train_for_steps or es > self.params.train_for_env_steps

        while not end_of_training(step, env_steps):
            timing = Timing()
            num_steps = 0
            batch_start = time.time()

            buffer.reset()

            with timing.timeit('experience'):
                # collecting experience
                for rollout_step in range(self.params.rollout):
                    actions, action_probs, values = self._policy_step(
                        obs, goals)

                    # wait for all the workers to complete an environment step
                    env_obs, rewards, dones, infos = multi_env.step(actions)

                    if self.params.graceful_episode_termination:
                        rewards = list(rewards)
                        for i in range(self.params.num_envs):
                            if dones[i] and infos[i].get('prev') is not None:
                                if infos[i]['prev'].get(
                                        'terminated_by_timer', False):
                                    log.info('Env %d terminated by timer', i)
                                    rewards[i] += values[i]

                    if not self.params.random_exploration:
                        trajectory_buffer.add(obs, actions, infos, dones)

                    next_obs, new_goals = main_observation(env_obs), goal_observation(env_obs)

                    # calculate curiosity bonus
                    with timing.add_time('curiosity'):
                        if not self.params.random_exploration:
                            bonuses = self.curiosity.generate_bonus_rewards(
                                self.session,
                                obs,
                                next_obs,
                                actions,
                                dones,
                                infos,
                            )
                            rewards = self.params.extrinsic_reward_coeff * np.array(rewards) + bonuses

                    # add experience from environment to the current buffer
                    buffer.add(obs, next_obs, actions, action_probs, rewards,
                               dones, values, goals)

                    obs, goals = next_obs, new_goals
                    self.process_infos(infos)
                    num_steps += num_env_steps(infos)

                # last step values are required for TD-return calculation
                _, _, values = self._policy_step(obs, goals)
                buffer.values.append(values)

            env_steps += num_steps

            # calculate discounted returns and GAE
            buffer.finalize_batch(self.params.gamma, self.params.gae_lambda)

            # update actor and critic and CM
            with timing.timeit('train'):
                step = self._train_with_curiosity(step, buffer, env_steps,
                                                  timing)

            avg_reward = multi_env.calc_avg_rewards(
                n=self.params.stats_episodes)
            avg_length = multi_env.calc_avg_episode_lengths(
                n=self.params.stats_episodes)

            self._maybe_update_avg_reward(avg_reward,
                                          multi_env.stats_num_episodes())
            self._maybe_trajectory_summaries(trajectory_buffer, env_steps)
            self._maybe_coverage_summaries(env_steps)
            self.curiosity.additional_summaries(
                env_steps,
                self.summary_writer,
                self.params.stats_episodes,
                map_img=self.map_img,
                coord_limits=self.coord_limits,
            )

            trajectory_buffer.reset_trajectories()

            fps = num_steps / (time.time() - batch_start)
            self._maybe_print(step, env_steps, avg_reward, avg_length, fps,
                              timing)
            self._maybe_aux_summaries(env_steps, avg_reward, avg_length, fps)
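
finalize_batch computes the discounted returns and GAE advantages used for training. A minimal sketch of the GAE recursion for a single environment, assuming per-step rewards, value estimates with the bootstrap value appended (as done above via buffer.values.append(values)), and done flags:

    import numpy as np

    def gae_advantages(rewards, values, dones, gamma, gae_lambda):
        # `values` has len(rewards) + 1 entries; the extra one is the bootstrap value
        values = np.asarray(values, dtype=np.float32)
        advantages = np.zeros(len(rewards), dtype=np.float32)
        last_adv = 0.0
        for t in reversed(range(len(rewards))):
            nonterminal = 1.0 - float(dones[t])
            delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
            last_adv = delta + gamma * gae_lambda * nonterminal * last_adv
            advantages[t] = last_adv
        returns = advantages + values[:-1]  # lambda-returns used as the value target
        return advantages, returns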